PySpark Test Generator

by fusionet24

Generate comprehensive PySpark-based data quality validation tests for Databricks tables. Use when creating automated tests for data completeness, accuracy, consistency, and conformity, or when user mentions test generation, data validation, quality monitoring, or PySpark test frameworks.


---
name: pyspark-test-generator
description: Generate comprehensive PySpark-based data quality validation tests for Databricks tables. Use when creating automated tests for data completeness, accuracy, consistency, and conformity, or when user mentions test generation, data validation, quality monitoring, or PySpark test frameworks.
version: 1.0.0
---

PySpark Test Generator Skill

Overview

This skill enables AI agents to automatically generate comprehensive PySpark-based data quality validation tests for Databricks tables. It creates executable test suites that validate data completeness, accuracy, consistency, and conformity.

Purpose

  • Generate PySpark validation tests based on data profiling results
  • Create reusable test frameworks for data quality monitoring
  • Implement custom validation rules using PySpark SQL and DataFrame operations
  • Produce detailed test reports with pass/fail metrics
  • Support continuous data quality monitoring in production pipelines

When to Use This Skill

Use this skill when you need to:

  • Create automated data quality tests after ingestion
  • Validate data against business rules and constraints
  • Monitor data quality over time with repeatable tests
  • Generate test code from profiling metadata
  • Implement custom validation logic beyond simple assertions

Test Categories

1. Completeness Tests

Validate that required data is present and non-null.

Example: Check for null values

from pyspark.sql import functions as F

def test_completeness_customer_id(spark, table_name):
    """
    Test: customer_id column should have no null values
    Severity: CRITICAL
    """
    df = spark.table(table_name)
    total_rows = df.count()
    null_count = df.filter(F.col("customer_id").isNull()).count()

    null_percentage = (null_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "completeness_customer_id",
        "column": "customer_id",
        "passed": null_count == 0,
        "total_rows": total_rows,
        "null_count": null_count,
        "null_percentage": null_percentage,
        "severity": "CRITICAL",
        "message": f"Found {null_count} null values ({null_percentage:.2f}%)" if null_count > 0
                   else "No null values found"
    }

    return result

2. Format/Pattern Tests

Validate data conforms to expected patterns (email, phone, UUID, etc.).

Example: Email format validation

def test_format_email(spark, table_name, column_name="email"):
    """
    Test: Email addresses should match valid email pattern
    Severity: HIGH
    """
    df = spark.table(table_name)

    # Email regex pattern
    email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'

    total_rows = df.count()
    invalid_count = df.filter(
        ~F.col(column_name).rlike(email_pattern) & F.col(column_name).isNotNull()
    ).count()

    invalid_percentage = (invalid_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": f"format_{column_name}",
        "column": column_name,
        "passed": invalid_count == 0,
        "total_rows": total_rows,
        "invalid_count": invalid_count,
        "invalid_percentage": invalid_percentage,
        "severity": "HIGH",
        "message": f"Found {invalid_count} invalid email addresses ({invalid_percentage:.2f}%)"
                   if invalid_count > 0 else "All email addresses are valid"
    }

    return result

3. Range/Boundary Tests

Validate numeric values fall within expected ranges.

Example: Age range validation

def test_range_age(spark, table_name, min_value=0, max_value=120):
    """
    Test: Age should be between 0 and 120
    Severity: MEDIUM
    """
    df = spark.table(table_name)

    total_rows = df.count()
    out_of_range = df.filter(
        (F.col("age") < min_value) | (F.col("age") > max_value)
    ).count()

    out_of_range_percentage = (out_of_range / total_rows * 100) if total_rows > 0 else 0

    # Get min and max actual values
    stats = df.agg(
        F.min("age").alias("min_age"),
        F.max("age").alias("max_age")
    ).collect()[0]

    result = {
        "test_name": "range_age",
        "column": "age",
        "passed": out_of_range == 0,
        "total_rows": total_rows,
        "out_of_range_count": out_of_range,
        "out_of_range_percentage": out_of_range_percentage,
        "expected_range": f"{min_value}-{max_value}",
        "actual_range": f"{stats['min_age']}-{stats['max_age']}",
        "severity": "MEDIUM",
        "message": f"Found {out_of_range} values outside range {min_value}-{max_value}"
                   if out_of_range > 0 else f"All values within range {min_value}-{max_value}"
    }

    return result

4. Uniqueness Tests

Validate columns that should have unique values (IDs, keys).

Example: Primary key uniqueness

def test_uniqueness_customer_id(spark, table_name):
    """
    Test: customer_id should be unique
    Severity: CRITICAL
    """
    df = spark.table(table_name)

    total_rows = df.count()
    distinct_count = df.select("customer_id").distinct().count()
    duplicate_count = total_rows - distinct_count

    duplicate_percentage = (duplicate_count / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "uniqueness_customer_id",
        "column": "customer_id",
        "passed": duplicate_count == 0,
        "total_rows": total_rows,
        "distinct_count": distinct_count,
        "duplicate_count": duplicate_count,
        "duplicate_percentage": duplicate_percentage,
        "severity": "CRITICAL",
        "message": f"Found {duplicate_count} duplicate values ({duplicate_percentage:.2f}%)"
                   if duplicate_count > 0 else "All values are unique"
    }

    return result

5. Referential Integrity Tests

Validate foreign key relationships between tables.

Example: Foreign key validation

def test_referential_integrity_customer_id(spark, child_table, parent_table):
    """
    Test: All customer_ids in orders should exist in customers table
    Severity: HIGH
    """
    child_df = spark.table(child_table)
    parent_df = spark.table(parent_table)

    # Left anti join to find orphaned records
    orphaned = child_df.join(
        parent_df,
        child_df.customer_id == parent_df.customer_id,
        "left_anti"
    )

    total_child_rows = child_df.count()
    orphaned_count = orphaned.count()
    orphaned_percentage = (orphaned_count / total_child_rows * 100) if total_child_rows > 0 else 0

    result = {
        "test_name": "referential_integrity_customer_id",
        "column": "customer_id",
        "child_table": child_table,
        "parent_table": parent_table,
        "passed": orphaned_count == 0,
        "total_rows": total_child_rows,
        "orphaned_count": orphaned_count,
        "orphaned_percentage": orphaned_percentage,
        "severity": "HIGH",
        "message": f"Found {orphaned_count} orphaned records ({orphaned_percentage:.2f}%)"
                   if orphaned_count > 0 else "All foreign keys are valid"
    }

    return result

6. Statistical Tests

Validate data distributions and statistical properties.

Example: Standard deviation check

def test_statistical_amount(spark, table_name, column_name="amount"):
    """
    Test: Amount should be within 3 standard deviations of mean
    Severity: MEDIUM
    """
    df = spark.table(table_name)

    # Calculate statistics
    stats = df.select(
        F.mean(column_name).alias("mean"),
        F.stddev(column_name).alias("stddev")
    ).collect()[0]

    mean_val = stats["mean"]
    stddev_val = stats["stddev"]

    # Find outliers (beyond 3 standard deviations)
    lower_bound = mean_val - (3 * stddev_val)
    upper_bound = mean_val + (3 * stddev_val)

    total_rows = df.count()
    outliers = df.filter(
        (F.col(column_name) < lower_bound) | (F.col(column_name) > upper_bound)
    ).count()

    outlier_percentage = (outliers / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": f"statistical_{column_name}",
        "column": column_name,
        "passed": outlier_percentage < 1.0,  # Pass if less than 1% outliers
        "total_rows": total_rows,
        "outlier_count": outliers,
        "outlier_percentage": outlier_percentage,
        "mean": mean_val,
        "stddev": stddev_val,
        "bounds": f"{lower_bound:.2f} to {upper_bound:.2f}",
        "severity": "MEDIUM",
        "message": f"Found {outliers} outliers ({outlier_percentage:.2f}%)"
                   if outliers > 0 else "Statistical distribution is normal"
    }

    return result

7. Custom Business Rule Tests

Validate domain-specific business logic.

Example: Order total validation

def test_business_rule_order_total(spark, table_name):
    """
    Test: Order total should equal sum of line items
    Severity: HIGH
    """
    df = spark.table(table_name)

    # Calculate discrepancies
    with_calculated = df.withColumn(
        "calculated_total",
        F.col("quantity") * F.col("unit_price")
    ).withColumn(
        "discrepancy",
        F.abs(F.col("order_total") - F.col("calculated_total"))
    )

    total_rows = with_calculated.count()
    discrepancies = with_calculated.filter(F.col("discrepancy") > 0.01).count()  # Allow 1 cent rounding

    discrepancy_percentage = (discrepancies / total_rows * 100) if total_rows > 0 else 0

    result = {
        "test_name": "business_rule_order_total",
        "columns": ["order_total", "quantity", "unit_price"],
        "passed": discrepancies == 0,
        "total_rows": total_rows,
        "discrepancy_count": discrepancies,
        "discrepancy_percentage": discrepancy_percentage,
        "severity": "HIGH",
        "message": f"Found {discrepancies} orders with incorrect totals ({discrepancy_percentage:.2f}%)"
                   if discrepancies > 0 else "All order totals are correct"
    }

    return result

Complete Test Suite Generator

Generate a complete test suite from profiling results:

from datetime import datetime

def generate_test_suite(table_name, profile_results):
    """
    Generate complete test suite based on profiling results.

    Args:
        table_name: Full table name (catalog.schema.table)
        profile_results: Dictionary from data-profiler skill

    Returns:
        Complete test suite as Python code string
    """

    tests = []

    for column_name, column_profile in profile_results["columns"].items():
        # Completeness test for non-nullable columns
        if not column_profile.get("nullable", True):
            tests.append(f"""
def test_completeness_{column_name}(spark):
    '''Test: {column_name} should have no null values'''
    df = spark.table("{table_name}")
    null_count = df.filter(F.col("{column_name}").isNull()).count()
    return {{"test": "completeness_{column_name}", "passed": null_count == 0, "null_count": null_count}}
""")

        # Pattern tests based on detected patterns
        patterns = column_profile.get("patterns", [])
        if "EMAIL" in patterns:
            tests.append(f"""
def test_format_{column_name}_email(spark):
    '''Test: {column_name} should contain valid email addresses'''
    df = spark.table("{table_name}")
    email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{{2,}}$'
    invalid = df.filter(~F.col("{column_name}").rlike(email_pattern) & F.col("{column_name}").isNotNull()).count()
    return {{"test": "format_{column_name}_email", "passed": invalid == 0, "invalid_count": invalid}}
""")

        # Uniqueness test for primary keys
        if column_profile.get("is_unique", False):
            tests.append(f"""
def test_uniqueness_{column_name}(spark):
    '''Test: {column_name} should contain unique values'''
    df = spark.table("{table_name}")
    total = df.count()
    distinct = df.select("{column_name}").distinct().count()
    return {{"test": "uniqueness_{column_name}", "passed": total == distinct, "duplicates": total - distinct}}
""")

        # Range test for numeric columns
        if column_profile.get("data_type") in ["int", "float", "double", "decimal"]:
            min_val = column_profile.get("min", 0)
            max_val = column_profile.get("max", 0)
            # Add 10% buffer
            buffer = (max_val - min_val) * 0.1
            tests.append(f"""
def test_range_{column_name}(spark):
    '''Test: {column_name} should be within expected range'''
    df = spark.table("{table_name}")
    out_of_range = df.filter((F.col("{column_name}") < {min_val - buffer}) | (F.col("{column_name}") > {max_val + buffer})).count()
    return {{"test": "range_{column_name}", "passed": out_of_range == 0, "out_of_range": out_of_range}}
""")

    # Generate complete test file
    test_suite = f'''
"""
Auto-generated Data Quality Tests for {table_name}
Generated: {datetime.now().isoformat()}

This test suite validates data quality for the {table_name} table.
Tests are generated based on data profiling results.
"""

from pyspark.sql import SparkSession, functions as F
from datetime import datetime
import json

# Test functions
{"".join(tests)}

def run_all_tests(spark):
    """Run all data quality tests and return results."""
    results = []

    test_functions = [
        {", ".join([f"test_{t.split('def test_')[1].split('(')[0]}" for t in tests if t.strip()])}
    ]

    for test_func in test_functions:
        try:
            result = test_func(spark)
            result["status"] = "SUCCESS"
            results.append(result)
        except Exception as e:
            results.append({{
                "test": test_func.__name__,
                "status": "ERROR",
                "error": str(e)
            }})

    return results

def generate_report(results):
    """Generate test report summary."""
    total_tests = len(results)
    passed_tests = sum(1 for r in results if r.get("passed", False))
    failed_tests = total_tests - passed_tests

    report = {{
        "table": "{table_name}",
        "timestamp": datetime.now().isoformat(),
        "total_tests": total_tests,
        "passed": passed_tests,
        "failed": failed_tests,
        "pass_rate": (passed_tests / total_tests * 100) if total_tests > 0 else 0,
        "results": results
    }}

    return report

if __name__ == "__main__":
    spark = SparkSession.builder.appName("DataQualityTests").getOrCreate()
    results = run_all_tests(spark)
    report = generate_report(results)

    print(json.dumps(report, indent=2))
'''

    return test_suite

Usage Example

# 1. Get profiling results
from data_profiler import profile_table
profile = profile_table("main.bronze.customers")

# 2. Generate test suite
test_suite_code = generate_test_suite("main.bronze.customers", profile)

# 3. Save to file
with open("tests/test_customers_quality.py", "w") as f:
    f.write(test_suite_code)

# 4. Run tests (run_all_tests is defined in the generated module)
results = run_all_tests(spark)

# 5. Generate report
report = generate_report(results)
print(f"Pass rate: {report['pass_rate']:.1f}%")

Output Format

Test results are returned in a standardized format:

{
    "table": "main.bronze.customers",
    "timestamp": "2025-12-17T10:30:00",
    "total_tests": 15,
    "passed": 13,
    "failed": 2,
    "pass_rate": 86.7,
    "results": [
        {
            "test_name": "completeness_customer_id",
            "column": "customer_id",
            "passed": true,
            "severity": "CRITICAL",
            "message": "No null values found"
        },
        {
            "test_name": "format_email",
            "column": "email",
            "passed": false,
            "invalid_count": 23,
            "severity": "HIGH",
            "message": "Found 23 invalid email addresses (0.23%)"
        }
    ]
}
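Because every result dict shares the same keys, downstream aggregation stays trivial. A sketch (the `failures_by_severity` helper is illustrative, not part of the skill) tallying failed tests per severity:

```python
from collections import Counter

def failures_by_severity(results):
    """Count failed tests per severity level from a list of
    standardized test-result dicts (keys as in the format above)."""
    return Counter(
        r.get("severity", "UNKNOWN")
        for r in results
        if not r.get("passed", False)
    )

results = [
    {"test_name": "completeness_customer_id", "passed": True, "severity": "CRITICAL"},
    {"test_name": "format_email", "passed": False, "severity": "HIGH"},
    {"test_name": "range_age", "passed": False, "severity": "MEDIUM"},
]
print(failures_by_severity(results))  # HIGH and MEDIUM each have one failure
```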

Best Practices

  1. Test Severity: Assign appropriate severity levels (CRITICAL, HIGH, MEDIUM, LOW)
  2. Tolerance Levels: Allow small percentages of failures for non-critical tests
  3. Performance: Use sampling for large tables during development
  4. Incremental Testing: Test only new data in incremental scenarios
  5. Alerting: Integrate with monitoring systems for failed tests
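Practices 1 and 2 can be combined into a small quality gate. This sketch assumes each failing test carries a `*_percentage` field, as in the examples above; the `quality_gate` helper and the tolerance values are illustrative, not prescribed by the skill:

```python
# Per-severity tolerance on the failure percentage (illustrative values)
DEFAULT_TOLERANCE = {"CRITICAL": 0.0, "HIGH": 1.0, "MEDIUM": 5.0, "LOW": 10.0}

def quality_gate(results, tolerance=DEFAULT_TOLERANCE):
    """Return True if the suite should pass: every failed test must
    stay within the tolerance assigned to its severity level."""
    for r in results:
        if r.get("passed", False):
            continue
        severity = r.get("severity", "CRITICAL")  # unknown severity treated as strict
        # Each test reports whichever percentage applies
        # (null_percentage, invalid_percentage, out_of_range_percentage, ...)
        pct = next((v for k, v in r.items() if k.endswith("_percentage")), 100.0)
        if pct > tolerance.get(severity, 0.0):
            return False
    return True

results = [
    {"test_name": "completeness_customer_id", "passed": True, "severity": "CRITICAL"},
    {"test_name": "format_email", "passed": False, "severity": "HIGH",
     "invalid_percentage": 0.23},
]
print(quality_gate(results))  # True: 0.23% invalid is within the HIGH tolerance
```

A gate like this is the natural hook for the alerting integration in practice 5: fail the job (or page) only when the gate returns False.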

Notes

  • Tests run in Databricks environment with PySpark
  • Generated code is production-ready and executable
  • Tests can be scheduled as Databricks jobs
  • Results can be stored in Delta tables for historical tracking
  • Compatible with Databricks SQL and Unity Catalog
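Historical tracking in Delta can be sketched as follows; the `main.quality.test_history` table name and both helpers are illustrative, and the write itself assumes a Databricks/Delta environment:

```python
def report_rows(report):
    """Flatten a report dict into one row per test for storage."""
    return [
        {
            "table": report["table"],
            "run_ts": report["timestamp"],
            "test": r.get("test_name") or r.get("test"),
            "passed": r.get("passed", False),
            "severity": r.get("severity"),
        }
        for r in report["results"]
    ]

def store_report(spark, report, history_table="main.quality.test_history"):
    """Append one row per test per run, so trends can be queried later."""
    spark.createDataFrame(report_rows(report)) \
         .write.format("delta").mode("append").saveAsTable(history_table)
```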


Skill Information

Category: Data
Version: 1.0.0
Last Updated: 12/17/2025