328 lines
12 KiB
Python
328 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""Compare externally generated Abaqus ODB-extracted CSV results."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import math
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
try:
|
|
from validate_reference_artifacts import validate_metadata
|
|
except ImportError:
|
|
from scripts.validate_reference_artifacts import validate_metadata
|
|
|
|
|
|
def load_csv_rows(path: Path) -> tuple[list[str], list[dict[str, str]]]:
    """Read a CSV file and return its header names and its data rows.

    Args:
        path: Location of the CSV file to read.

    Returns:
        A ``(headers, rows)`` pair. ``headers`` is the header row in file
        order (empty when the file has no header row). ``rows`` holds one
        dict per data row, keyed by header name.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" strips a leading byte-order mark when present and otherwise
    # decodes exactly like "utf-8"; without it a BOM becomes part of the first
    # header name and the column is reported as missing.
    with path.open(newline="", encoding="utf-8-sig") as handle:
        # restval="" keeps every value a string: the default (None) for rows
        # shorter than the header breaks float parsing, key joining and key
        # sorting further down the pipeline.
        reader = csv.DictReader(handle, restval="")
        headers = list(reader.fieldnames or [])
        rows = list(reader)
    return headers, rows
|
|
|
|
|
|
def make_key(row: dict[str, str], key_columns: list[str]) -> tuple[str, ...]:
    """Return the row's values for ``key_columns`` as a tuple, in column order.

    A column that is absent from ``row`` contributes an empty string.
    """
    values = [row.get(name, "") for name in key_columns]
    return tuple(values)
|
|
|
|
|
|
def _key_text(key: tuple[str, ...]) -> str:
|
|
return "|".join(key)
|
|
|
|
|
|
def _parse_finite(value: str) -> float:
|
|
parsed = float(value)
|
|
if not math.isfinite(parsed):
|
|
raise ValueError(f"nonfinite value: {value}")
|
|
return parsed
|
|
|
|
|
|
def failed_quantity(quantity: str, classification: str, message: str) -> dict:
    """Build a failing result record for ``quantity`` with no error metrics.

    The record has the same keys, in the same order, as a completed
    comparison result: row counters are zero and every metric is ``None``.
    """
    record: dict = {
        "quantity": quantity,
        "result": "fail",
        "classification": classification,
        "message": message,
    }
    # Insertion order is kept deliberately so serialized reports list the
    # fields in a stable order.
    record.update(dict.fromkeys(("compared_rows", "missing_rows", "extra_rows"), 0))
    record.update(
        dict.fromkeys(
            ("max_abs_error", "max_rel_error", "rms_error", "worst_key", "worst_component"),
            None,
        )
    )
    return record
|
|
|
|
|
|
def validate_columns(headers: list[str], required_columns: list[str]) -> list[str]:
    """Return the entries of ``required_columns`` that ``headers`` lacks.

    The result keeps the order (and any repeats) of ``required_columns``.
    """
    available = frozenset(headers)
    return [name for name in required_columns if name not in available]
|
|
|
|
|
|
def duplicate_columns(headers: list[str]) -> list[str]:
    """Return header names that occur more than once.

    Each repeated name is listed once, ordered by where its first repeat
    appears in ``headers``.
    """
    encountered: set[str] = set()
    # A dict is used as an insertion-ordered set of repeated names.
    repeated: dict[str, None] = {}
    for name in headers:
        if name in encountered:
            repeated.setdefault(name, None)
        else:
            encountered.add(name)
    return list(repeated)
|
|
|
|
|
|
def _rows_by_key(rows: list[dict[str, str]], key_columns: list[str]) -> tuple[dict[tuple[str, ...], dict[str, str]], set[tuple[str, ...]]]:
    """Index ``rows`` by their key tuple and collect keys that repeat.

    Returns:
        ``(index, repeated)``: ``index`` maps each key to the last row that
        carries it; ``repeated`` holds every key seen on more than one row.
    """
    index: dict[tuple[str, ...], dict[str, str]] = {}
    repeated: set[tuple[str, ...]] = set()
    for record in rows:
        row_key = make_key(record, key_columns)
        if row_key in index:
            repeated.add(row_key)
        # Later rows overwrite earlier ones that share the same key.
        index[row_key] = record
    return index, repeated
|
|
|
|
|
|
def validate_contract(contract: dict) -> list[str]:
    """Return the sorted names of contract keys that are missing or unusable.

    A key is reported when it is absent, or when its value has a type the
    comparison cannot use: the CSV paths and ``value_column`` must be
    strings, ``required_columns`` and ``key_columns`` must be lists (or
    tuples) of strings, and ``tolerance`` must be a dict. A contract that is
    not a dict at all reports every required key.

    Args:
        contract: One comparison contract, typically decoded from JSON and
            therefore possibly of any JSON type.

    Returns:
        Sorted, de-duplicated key names; an empty list means the contract is
        usable.
    """
    text_keys = ("reference_csv", "actual_csv", "value_column")
    column_list_keys = ("required_columns", "key_columns")
    required_keys = (*text_keys, *column_list_keys, "tolerance")

    # Contracts come from decoded JSON, so guard against null/list/str values
    # instead of failing on ``.get`` or doing substring checks on a string.
    if not isinstance(contract, dict):
        return sorted(required_keys)

    invalid = {key for key in required_keys if key not in contract}
    for key in text_keys:
        if key in contract and not isinstance(contract[key], str):
            invalid.add(key)
    for key in column_list_keys:
        if key not in contract:
            continue
        columns = contract[key]
        # A bare string would otherwise be split into single characters.
        if not isinstance(columns, (list, tuple)) or not all(isinstance(item, str) for item in columns):
            invalid.add(key)
    if not isinstance(contract.get("tolerance", {}), dict):
        invalid.add("tolerance")
    return sorted(invalid)
|
|
|
|
|
|
def compare_quantity(quantity: str, contract: dict, reference_root: Path, actual_root: Path) -> dict:
    """Compare one quantity's actual CSV against its reference CSV.

    The checks run in order and the first failure is returned: contract
    validity, file existence, duplicate headers, missing columns, duplicate
    row keys, mismatched key sets, then a per-row comparison of the value
    column (plus the optional unit and coordinate-system columns).

    Args:
        quantity: Name of the quantity, copied into the result record.
        contract: Comparison contract for this quantity.
        reference_root: Directory that ``contract["reference_csv"]`` is
            relative to.
        actual_root: Directory that ``contract["actual_csv"]`` is relative to.

    Returns:
        A result record with ``result`` set to ``"pass"`` or ``"fail"``, a
        ``classification``, row counters and error metrics.
    """
    contract_errors = validate_contract(contract)
    if contract_errors:
        return failed_quantity(
            quantity,
            "upstream-contract",
            f"missing comparison contract keys: {', '.join(contract_errors)}",
        )

    reference_csv = reference_root / contract["reference_csv"]
    actual_csv = actual_root / contract["actual_csv"]
    if not reference_csv.exists():
        return failed_quantity(quantity, "missing-reference-artifact", f"missing reference CSV: {reference_csv}")
    if not actual_csv.exists():
        return failed_quantity(quantity, "missing-generated-output", f"missing actual CSV: {actual_csv}")

    reference_headers, reference_rows = load_csv_rows(reference_csv)
    actual_headers, actual_rows = load_csv_rows(actual_csv)

    required_columns = list(contract["required_columns"])
    key_columns = list(contract["key_columns"])
    value_column = contract["value_column"]
    unit_column = contract.get("unit_column")
    coordinate_system_column = contract.get("coordinate_system_column")

    repeated_columns = duplicate_columns(reference_headers) + duplicate_columns(actual_headers)
    if repeated_columns:
        return failed_quantity(
            quantity,
            "schema-mismatch",
            f"duplicate CSV header columns: {', '.join(sorted(set(repeated_columns)))}",
        )

    # The optional unit / coordinate-system columns are indexed directly in
    # the row loop below, so a CSV lacking them must be rejected here as a
    # schema problem rather than surfacing later as an uncaught KeyError.
    optional_columns = [
        column
        for column in (unit_column, coordinate_system_column)
        if column and column not in required_columns
    ]
    checked_columns = required_columns + optional_columns
    reference_missing_columns = validate_columns(reference_headers, checked_columns)
    actual_missing_columns = validate_columns(actual_headers, checked_columns)
    if reference_missing_columns or actual_missing_columns:
        missing = sorted(set(reference_missing_columns + actual_missing_columns))
        return failed_quantity(quantity, "schema-mismatch", f"missing required columns: {', '.join(missing)}")

    reference_by_key, reference_duplicates = _rows_by_key(reference_rows, key_columns)
    actual_by_key, actual_duplicates = _rows_by_key(actual_rows, key_columns)
    duplicate_keys = reference_duplicates | actual_duplicates
    if duplicate_keys:
        return failed_quantity(quantity, "schema-mismatch", f"duplicate key rows: {_key_text(sorted(duplicate_keys)[0])}")

    reference_keys = set(reference_by_key)
    actual_keys = set(actual_by_key)
    missing_keys = sorted(reference_keys - actual_keys)
    extra_keys = sorted(actual_keys - reference_keys)
    if missing_keys or extra_keys:
        result = failed_quantity(quantity, "id-mismatch", "reference and actual row keys do not match")
        result["missing_rows"] = len(missing_keys)
        result["extra_rows"] = len(extra_keys)
        result["worst_key"] = _key_text((missing_keys or extra_keys)[0])
        return result

    tolerance = contract.get("tolerance", {})
    absolute = float(tolerance.get("absolute", 0.0))
    relative = float(tolerance.get("relative", 0.0))
    relative_floor = float(tolerance.get("relative_floor", 0.0))

    compared_rows = 0
    max_abs_error = 0.0
    max_rel_error = 0.0
    sum_square_error = 0.0
    worst_key: tuple[str, ...] | None = None
    tolerance_failed = False

    # Sorted iteration makes the reported worst key deterministic when several
    # rows share the same maximum absolute error.
    for key in sorted(reference_keys):
        reference_row = reference_by_key[key]
        actual_row = actual_by_key[key]
        try:
            reference_value = _parse_finite(reference_row[value_column])
            actual_value = _parse_finite(actual_row[value_column])
        except (KeyError, TypeError, ValueError) as exc:
            # TypeError covers a value that is None, which csv.DictReader
            # produces by default for rows shorter than the header.
            return failed_quantity(quantity, "nonfinite-result", str(exc))
        if unit_column and reference_row[unit_column] != actual_row[unit_column]:
            result = failed_quantity(quantity, "unit-or-coordinate-mismatch", f"unit mismatch at {_key_text(key)}")
            result["worst_key"] = _key_text(key)
            return result
        if coordinate_system_column and reference_row[coordinate_system_column] != actual_row[coordinate_system_column]:
            result = failed_quantity(
                quantity,
                "unit-or-coordinate-mismatch",
                f"coordinate system mismatch at {_key_text(key)}",
            )
            result["worst_key"] = _key_text(key)
            return result

        abs_error = abs(actual_value - reference_value)
        # The floor keeps the relative error bounded for near-zero references;
        # a zero denominator (zero reference, zero floor) reports 0.0.
        rel_denominator = max(abs(reference_value), relative_floor)
        rel_error = abs_error / rel_denominator if rel_denominator else 0.0
        allowed_error = absolute + relative * rel_denominator
        compared_rows += 1
        sum_square_error += abs_error * abs_error
        max_rel_error = max(max_rel_error, rel_error)
        if worst_key is None or abs_error > max_abs_error:
            max_abs_error = abs_error
            worst_key = key
        if abs_error > allowed_error:
            tolerance_failed = True

    rms_error = math.sqrt(sum_square_error / compared_rows) if compared_rows else 0.0
    worst_key_text = _key_text(worst_key) if worst_key is not None else None
    # The last key column's value is reported as the component.
    worst_component = worst_key[-1] if worst_key else None
    result = "fail" if tolerance_failed else "pass"
    classification = "tolerance-failure" if tolerance_failed else "N/A"
    return {
        "quantity": quantity,
        "result": result,
        "classification": classification,
        "message": "",
        "compared_rows": compared_rows,
        "missing_rows": 0,
        "extra_rows": 0,
        "max_abs_error": max_abs_error,
        "max_rel_error": max_rel_error,
        "rms_error": rms_error,
        "worst_key": worst_key_text,
        "worst_component": worst_component,
    }
|
|
|
|
|
|
def compare_metadata(
    metadata_path: Path,
    actual_root: Path,
    *,
    quantities: list[str] | None = None,
    validate_artifacts: bool = True,
) -> dict:
    """Run every selected comparison described by a reference metadata file.

    Args:
        metadata_path: JSON metadata file; its ``comparisons`` object maps
            quantity names to comparison contracts. Reference CSV paths are
            resolved against this file's directory.
        actual_root: Directory that actual CSV paths are resolved against.
        quantities: Quantity names to compare; ``None`` selects every
            quantity in the metadata, in sorted order.
        validate_artifacts: When true, ``validate_metadata`` is run first and
            any errors it reports fail all selected quantities.

    Returns:
        A report dict with ``metadata``, ``actual_root``, ``overall_result``
        (``"pass"`` or ``"fail"``) and ``quantities`` (the per-quantity
        result records). The overall result is ``"pass"`` only when at least
        one quantity was compared and all of them passed.

    Raises:
        OSError: If the metadata file cannot be read.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """

    def build_report(results: list[dict], overall: str = "fail") -> dict:
        # Single place that fixes the report layout for every exit path.
        return {
            "metadata": str(metadata_path),
            "actual_root": str(actual_root),
            "overall_result": overall,
            "quantities": results,
        }

    payload = json.loads(metadata_path.read_text(encoding="utf-8"))
    # A JSON document whose top level is not an object has no contracts.
    comparisons = payload.get("comparisons", {}) if isinstance(payload, dict) else None
    if not isinstance(comparisons, dict) or (not comparisons and quantities is None):
        quantity_names = quantities if quantities is not None else ["metadata"]
        return build_report(
            [
                failed_quantity(quantity, "upstream-contract", "missing comparison contracts")
                for quantity in quantity_names
            ]
        )

    selected_quantities = quantities if quantities is not None else sorted(comparisons)
    if validate_artifacts:
        validation_errors = validate_metadata(metadata_path, _project_root_from_metadata(metadata_path))
        if validation_errors:
            message = "; ".join(validation_errors)
            return build_report(
                [
                    failed_quantity(quantity, "missing-reference-artifact", message)
                    for quantity in (selected_quantities or ["metadata"])
                ]
            )

    results = []
    for quantity in selected_quantities:
        contract = comparisons.get(quantity)
        if contract is None:
            results.append(failed_quantity(quantity, "upstream-contract", f"missing comparison contract: {quantity}"))
            continue
        results.append(compare_quantity(quantity, contract, metadata_path.parent, actual_root))

    # all() is vacuously true for an empty list; a run that compared nothing
    # must not be reported as a pass.
    all_passed = bool(results) and all(result["result"] == "pass" for result in results)
    return build_report(results, "pass" if all_passed else "fail")
|
|
|
|
|
|
class _ArgumentParser(argparse.ArgumentParser):
|
|
def error(self, message: str) -> None:
|
|
raise ValueError(message)
|
|
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the CSV comparison tool."""
    cli = _ArgumentParser(description="Compare externally generated ODB-extracted CSV outputs.")
    # One entry per option, registered in the order shown in --help.
    options = (
        ("--metadata", {"required": True, "type": Path, "help": "Reference metadata.json path."}),
        (
            "--actual-root",
            {"required": True, "type": Path, "help": "Root directory containing actual extracted CSVs."},
        ),
        ("--quantity", {"action": "append", "default": None, "help": "Quantity key to compare. May be repeated."}),
        ("--report-json", {"type": Path, "default": None, "help": "Optional JSON report output path."}),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli
|
|
|
|
|
|
def _format_summary(result: dict) -> str:
|
|
status = result["result"].upper()
|
|
parts = [
|
|
f"{status} {result['quantity']}",
|
|
f"rows={result['compared_rows']}",
|
|
f"max_abs_error={result['max_abs_error']}",
|
|
f"max_rel_error={result['max_rel_error']}",
|
|
f"rms_error={result['rms_error']}",
|
|
f"worst_key={result['worst_key']}",
|
|
]
|
|
if result["classification"] != "N/A":
|
|
parts.insert(2, f"classification={result['classification']}")
|
|
return " ".join(parts)
|
|
|
|
|
|
def _project_root_from_metadata(metadata_path: Path) -> Path:
|
|
resolved = metadata_path.resolve()
|
|
for parent in resolved.parents:
|
|
if parent.name == "references":
|
|
return parent.parent
|
|
return resolved.parent
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Run the CSV comparison command line and return a process exit code.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` when every quantity passes, ``1`` when any comparison fails,
        and ``2`` when the arguments, metadata or CSV inputs cannot be read
        or the JSON report cannot be written.
    """
    parser = build_arg_parser()
    try:
        args = parser.parse_args(argv)
        report = compare_metadata(args.metadata, args.actual_root, quantities=args.quantity)
    except (OSError, ValueError, csv.Error) as exc:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError;
        # csv.Error covers input the csv module refuses to parse.
        print(f"CSV comparison configuration failed: {exc}", file=sys.stderr)
        return 2

    report_error: OSError | None = None
    if args.report_json is not None:
        try:
            args.report_json.parent.mkdir(parents=True, exist_ok=True)
            args.report_json.write_text(json.dumps(report, indent=2), encoding="utf-8")
        except OSError as exc:
            # Remember the failure but still print the summaries below, so
            # the comparison outcome is not lost along with the report.
            report_error = exc

    for result in report["quantities"]:
        print(_format_summary(result))

    if report_error is not None:
        # Exit code 2 keeps an unwritable report distinguishable from a
        # failed comparison (exit code 1).
        print(f"CSV comparison report could not be written: {report_error}", file=sys.stderr)
        return 2

    return 0 if report["overall_result"] == "pass" else 1
|
|
|
|
|
|
# Script entry point: run the CLI and exit with the status code it returns.
if __name__ == "__main__":
    sys.exit(main())
|