Files
AbaqusSubroutineDev/scripts/compare_extracted_csv.py
T
2026-06-11 11:08:27 +09:00

328 lines
12 KiB
Python

#!/usr/bin/env python3
"""Compare externally generated Abaqus ODB-extracted CSV results."""
from __future__ import annotations
import argparse
import csv
import json
import math
import sys
from pathlib import Path
try:
from validate_reference_artifacts import validate_metadata
except ImportError:
from scripts.validate_reference_artifacts import validate_metadata
def load_csv_rows(path: Path) -> tuple[list[str], list[dict[str, str]]]:
    """Read a CSV file into its header names and a list of row dicts.

    The file is decoded as UTF-8. A leading UTF-8 byte-order mark, if present,
    is stripped so it does not become part of the first header name. Rows
    with fewer fields than the header are padded with ``""`` so every value
    is a ``str``.

    Args:
        path: CSV file to read; its first row is treated as the header.

    Returns:
        ``(headers, rows)``; ``headers`` is empty for an empty file.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    # "utf-8-sig" decodes exactly like "utf-8" but drops a leading BOM.
    with path.open(newline="", encoding="utf-8-sig") as handle:
        # restval="" replaces DictReader's default of None for short rows, so
        # downstream parsing and comparisons only ever see strings.
        reader = csv.DictReader(handle, restval="")
        return list(reader.fieldnames or []), list(reader)
def make_key(row: dict[str, str], key_columns: list[str]) -> tuple[str, ...]:
    """Build the row-matching key from the ``key_columns`` values of *row*.

    A column absent from *row* contributes an empty string to the key.
    """
    values = [row.get(name, "") for name in key_columns]
    return tuple(values)
def _key_text(key: tuple[str, ...]) -> str:
return "|".join(key)
def _parse_finite(value: str) -> float:
parsed = float(value)
if not math.isfinite(parsed):
raise ValueError(f"nonfinite value: {value}")
return parsed
def failed_quantity(quantity: str, classification: str, message: str) -> dict:
    """Build a failing per-quantity result whose statistics are all empty.

    The dict has the same keys, in the same order, as a successful comparison
    result. Callers may overwrite fields such as ``missing_rows`` or
    ``worst_key`` afterwards to add detail.
    """
    return dict(
        quantity=quantity,
        result="fail",
        classification=classification,
        message=message,
        compared_rows=0,
        missing_rows=0,
        extra_rows=0,
        max_abs_error=None,
        max_rel_error=None,
        rms_error=None,
        worst_key=None,
        worst_component=None,
    )
def validate_columns(headers: list[str], required_columns: list[str]) -> list[str]:
    """Return the entries of *required_columns* that are absent from *headers*.

    Order and repeats follow *required_columns*.
    """
    missing: list[str] = []
    for column in required_columns:
        if column not in headers:
            missing.append(column)
    return missing
def duplicate_columns(headers: list[str]) -> list[str]:
    """Return header names that occur more than once, each listed a single time.

    Names are ordered by where their second occurrence appears.
    """
    seen: set[str] = set()
    repeated: dict[str, None] = {}  # insertion-ordered set
    for name in headers:
        if name in seen:
            repeated.setdefault(name, None)
        else:
            seen.add(name)
    return list(repeated)
def _rows_by_key(rows: list[dict[str, str]], key_columns: list[str]) -> tuple[dict[tuple[str, ...], dict[str, str]], set[tuple[str, ...]]]:
    """Index *rows* by the key built from ``key_columns``.

    Returns:
        ``(keyed, duplicates)``. ``keyed`` maps each key to the LAST row that
        carries it, and ``duplicates`` holds every key that appears on more
        than one row.
    """
    labelled = [(make_key(row, key_columns), row) for row in rows]
    # dict() keeps first-seen key order and lets later rows overwrite earlier ones.
    keyed: dict[tuple[str, ...], dict[str, str]] = dict(labelled)
    seen: set[tuple[str, ...]] = set()
    duplicates: set[tuple[str, ...]] = set()
    for key, _row in labelled:
        if key in seen:
            duplicates.add(key)
        seen.add(key)
    return keyed, duplicates
def validate_contract(contract: dict) -> list[str]:
    """Return the sorted names of comparison-contract keys that are missing or malformed.

    A contract needs ``reference_csv``, ``actual_csv``, ``required_columns``,
    ``key_columns``, ``value_column`` and ``tolerance``. Beyond presence, this
    also checks the shapes that the comparison code relies on:

    * ``tolerance`` must be a dict.
    * ``required_columns`` and ``key_columns`` must be lists/tuples of strings.
      A bare string would otherwise be split into single characters by ``list()``.
    * ``value_column`` must be a string.
    * The optional ``unit_column`` / ``coordinate_system_column`` must be a
      string or ``None``.

    Returns:
        An empty list for a usable contract; otherwise the offending key names.
        A non-dict contract reports every required key.
    """
    required_keys = [
        "reference_csv",
        "actual_csv",
        "required_columns",
        "key_columns",
        "value_column",
        "tolerance",
    ]
    if not isinstance(contract, dict):
        # e.g. a JSON list/string in place of the contract object.
        return sorted(required_keys)
    problems = [key for key in required_keys if key not in contract]
    if not isinstance(contract.get("tolerance", {}), dict):
        problems.append("tolerance")
    for key in ("required_columns", "key_columns"):
        columns = contract.get(key, [])
        if not isinstance(columns, (list, tuple)) or not all(isinstance(column, str) for column in columns):
            problems.append(key)
    if not isinstance(contract.get("value_column", ""), str):
        problems.append("value_column")
    for key in ("unit_column", "coordinate_system_column"):
        value = contract.get(key)
        if value is not None and not isinstance(value, str):
            problems.append(key)
    return sorted(set(problems))
def compare_quantity(quantity: str, contract: dict, reference_root: Path, actual_root: Path) -> dict:
    """Compare one quantity's reference CSV against the externally extracted CSV.

    Rows are matched on the tuple of ``contract["key_columns"]`` values. The
    numeric ``contract["value_column"]`` of every matched pair is checked
    against ``absolute + relative * max(|reference|, relative_floor)``, using
    the values in ``contract["tolerance"]``. Optional ``unit_column`` /
    ``coordinate_system_column`` values must match exactly, row by row.

    Structural checks run first, and the first failure is returned. In order:

    1. contract keys
    2. CSV existence
    3. duplicate header names
    4. missing columns
    5. duplicate keys
    6. key-set mismatch

    In the per-row loop, parse and unit/coordinate failures stop at the first
    offending row. Tolerance violations are recorded but the loop continues,
    so the error statistics cover every row.

    Args:
        quantity: Quantity name echoed into the result.
        contract: Comparison contract for this quantity.
        reference_root: Directory that ``contract["reference_csv"]`` is resolved against.
        actual_root: Directory that ``contract["actual_csv"]`` is resolved against.

    Returns:
        A result dict with the same keys as ``failed_quantity``. ``result`` is
        ``"pass"`` only when every matched row is within tolerance.

    Raises:
        OSError, UnicodeDecodeError: if a CSV cannot be read.
        ValueError, TypeError: if a tolerance entry cannot be converted by ``float()``.
    """
    contract_errors = validate_contract(contract)
    if contract_errors:
        return failed_quantity(
            quantity,
            "upstream-contract",
            f"missing comparison contract keys: {', '.join(contract_errors)}",
        )
    reference_csv = reference_root / contract["reference_csv"]
    actual_csv = actual_root / contract["actual_csv"]
    if not reference_csv.exists():
        return failed_quantity(quantity, "missing-reference-artifact", f"missing reference CSV: {reference_csv}")
    if not actual_csv.exists():
        return failed_quantity(quantity, "missing-generated-output", f"missing actual CSV: {actual_csv}")
    reference_headers, reference_rows = load_csv_rows(reference_csv)
    actual_headers, actual_rows = load_csv_rows(actual_csv)
    key_columns = list(contract["key_columns"])
    value_column = contract["value_column"]
    unit_column = contract.get("unit_column")
    coordinate_system_column = contract.get("coordinate_system_column")
    # Columns read with row[...] below must exist in both files. Otherwise a
    # missing unit/coordinate column raises an uncaught KeyError, and a missing
    # value column is misreported as a nonfinite result instead of a schema
    # mismatch.
    required_columns = list(contract["required_columns"])
    referenced_columns = [value_column] + [
        column for column in (unit_column, coordinate_system_column) if column
    ]
    for column in referenced_columns:
        if column not in required_columns:
            required_columns.append(column)
    repeated_columns = duplicate_columns(reference_headers) + duplicate_columns(actual_headers)
    if repeated_columns:
        return failed_quantity(
            quantity,
            "schema-mismatch",
            f"duplicate CSV header columns: {', '.join(sorted(set(repeated_columns)))}",
        )
    reference_missing_columns = validate_columns(reference_headers, required_columns)
    actual_missing_columns = validate_columns(actual_headers, required_columns)
    if reference_missing_columns or actual_missing_columns:
        missing = sorted(set(reference_missing_columns + actual_missing_columns))
        return failed_quantity(quantity, "schema-mismatch", f"missing required columns: {', '.join(missing)}")
    reference_by_key, reference_duplicates = _rows_by_key(reference_rows, key_columns)
    actual_by_key, actual_duplicates = _rows_by_key(actual_rows, key_columns)
    duplicate_keys = reference_duplicates | actual_duplicates
    if duplicate_keys:
        return failed_quantity(quantity, "schema-mismatch", f"duplicate key rows: {_key_text(sorted(duplicate_keys)[0])}")
    reference_keys = set(reference_by_key)
    actual_keys = set(actual_by_key)
    missing_keys = sorted(reference_keys - actual_keys)
    extra_keys = sorted(actual_keys - reference_keys)
    if missing_keys or extra_keys:
        result = failed_quantity(quantity, "id-mismatch", "reference and actual row keys do not match")
        result["missing_rows"] = len(missing_keys)
        result["extra_rows"] = len(extra_keys)
        result["worst_key"] = _key_text((missing_keys or extra_keys)[0])
        return result
    tolerance = contract.get("tolerance", {})
    absolute = float(tolerance.get("absolute", 0.0))
    relative = float(tolerance.get("relative", 0.0))
    relative_floor = float(tolerance.get("relative_floor", 0.0))
    compared_rows = 0
    max_abs_error = 0.0
    max_rel_error = 0.0
    sum_square_error = 0.0
    worst_key: tuple[str, ...] | None = None
    tolerance_failed = False
    for key in sorted(reference_keys):
        reference_row = reference_by_key[key]
        actual_row = actual_by_key[key]
        try:
            reference_value = _parse_finite(reference_row[value_column])
            actual_value = _parse_finite(actual_row[value_column])
        except (KeyError, TypeError, ValueError) as exc:
            # TypeError covers None values, which csv.DictReader yields for
            # fields missing from short rows.
            result = failed_quantity(quantity, "nonfinite-result", str(exc))
            result["worst_key"] = _key_text(key)
            return result
        if unit_column and reference_row[unit_column] != actual_row[unit_column]:
            result = failed_quantity(quantity, "unit-or-coordinate-mismatch", f"unit mismatch at {_key_text(key)}")
            result["worst_key"] = _key_text(key)
            return result
        if coordinate_system_column and reference_row[coordinate_system_column] != actual_row[coordinate_system_column]:
            result = failed_quantity(
                quantity,
                "unit-or-coordinate-mismatch",
                f"coordinate system mismatch at {_key_text(key)}",
            )
            result["worst_key"] = _key_text(key)
            return result
        abs_error = abs(actual_value - reference_value)
        # relative_floor keeps the relative measure bounded for near-zero references.
        rel_denominator = max(abs(reference_value), relative_floor)
        rel_error = abs_error / rel_denominator if rel_denominator else 0.0
        allowed_error = absolute + relative * rel_denominator
        compared_rows += 1
        sum_square_error += abs_error * abs_error
        max_rel_error = max(max_rel_error, rel_error)
        if worst_key is None or abs_error > max_abs_error:
            max_abs_error = abs_error
            worst_key = key
        if abs_error > allowed_error:
            tolerance_failed = True
    rms_error = math.sqrt(sum_square_error / compared_rows) if compared_rows else 0.0
    worst_key_text = _key_text(worst_key) if worst_key is not None else None
    # The last key column is reported as the "component" of the worst row.
    worst_component = worst_key[-1] if worst_key else None
    result = "fail" if tolerance_failed else "pass"
    classification = "tolerance-failure" if tolerance_failed else "N/A"
    return {
        "quantity": quantity,
        "result": result,
        "classification": classification,
        "message": "",
        "compared_rows": compared_rows,
        "missing_rows": 0,
        "extra_rows": 0,
        "max_abs_error": max_abs_error,
        "max_rel_error": max_rel_error,
        "rms_error": rms_error,
        "worst_key": worst_key_text,
        "worst_component": worst_component,
    }
def compare_metadata(
metadata_path: Path,
actual_root: Path,
*,
quantities: list[str] | None = None,
validate_artifacts: bool = True,
) -> dict:
payload = json.loads(metadata_path.read_text(encoding="utf-8"))
comparisons = payload.get("comparisons", {})
if not isinstance(comparisons, dict) or (not comparisons and quantities is None):
quantity_names = quantities if quantities is not None else ["metadata"]
results = [
failed_quantity(quantity, "upstream-contract", "missing comparison contracts")
for quantity in quantity_names
]
return {
"metadata": str(metadata_path),
"actual_root": str(actual_root),
"overall_result": "fail",
"quantities": results,
}
selected_quantities = quantities if quantities is not None else sorted(comparisons)
if validate_artifacts:
validation_errors = validate_metadata(metadata_path, _project_root_from_metadata(metadata_path))
if validation_errors:
message = "; ".join(validation_errors)
results = [
failed_quantity(quantity, "missing-reference-artifact", message)
for quantity in (selected_quantities or ["metadata"])
]
return {
"metadata": str(metadata_path),
"actual_root": str(actual_root),
"overall_result": "fail",
"quantities": results,
}
results = []
for quantity in selected_quantities:
contract = comparisons.get(quantity)
if contract is None:
results.append(failed_quantity(quantity, "upstream-contract", f"missing comparison contract: {quantity}"))
continue
results.append(compare_quantity(quantity, contract, metadata_path.parent, actual_root))
overall = "pass" if all(result["result"] == "pass" for result in results) else "fail"
return {
"metadata": str(metadata_path),
"actual_root": str(actual_root),
"overall_result": overall,
"quantities": results,
}
class _ArgumentParser(argparse.ArgumentParser):
def error(self, message: str) -> None:
raise ValueError(message)
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Argument errors are routed through ``_ArgumentParser.error``, which raises
    ``ValueError`` instead of exiting.
    """
    parser = _ArgumentParser(description="Compare externally generated ODB-extracted CSV outputs.")
    add = parser.add_argument
    add("--metadata", required=True, type=Path, help="Reference metadata.json path.")
    add("--actual-root", required=True, type=Path, help="Root directory containing actual extracted CSVs.")
    add("--quantity", action="append", default=None, help="Quantity key to compare. May be repeated.")
    add("--report-json", type=Path, default=None, help="Optional JSON report output path.")
    return parser
def _format_summary(result: dict) -> str:
status = result["result"].upper()
parts = [
f"{status} {result['quantity']}",
f"rows={result['compared_rows']}",
f"max_abs_error={result['max_abs_error']}",
f"max_rel_error={result['max_rel_error']}",
f"rms_error={result['rms_error']}",
f"worst_key={result['worst_key']}",
]
if result["classification"] != "N/A":
parts.insert(2, f"classification={result['classification']}")
return " ".join(parts)
def _project_root_from_metadata(metadata_path: Path) -> Path:
resolved = metadata_path.resolve()
for parent in resolved.parents:
if parent.name == "references":
return parent.parent
return resolved.parent
def main(argv: list[str] | None = None) -> int:
    """Run the CSV comparison from the command line.

    Parses *argv* (``sys.argv[1:]`` when ``None``) and compares the selected
    quantities. If ``--report-json`` is given, the JSON report is written
    there. One summary line per quantity is printed to stdout.

    Returns:
        ``0`` when the overall result is ``"pass"``, ``1`` when any quantity
        failed, and ``2`` for argument/configuration errors or when the
        ``--report-json`` file cannot be written.
    """
    parser = build_arg_parser()
    try:
        args = parser.parse_args(argv)
        report = compare_metadata(args.metadata, args.actual_root, quantities=args.quantity)
    except (OSError, ValueError, json.JSONDecodeError) as exc:
        print(f"CSV comparison configuration failed: {exc}", file=sys.stderr)
        return 2
    report_error: OSError | None = None
    if args.report_json is not None:
        try:
            args.report_json.parent.mkdir(parents=True, exist_ok=True)
            args.report_json.write_text(json.dumps(report, indent=2), encoding="utf-8")
        except OSError as exc:
            # Still print the summaries below so the comparison outcome is not lost.
            report_error = exc
    for result in report["quantities"]:
        print(_format_summary(result))
    if report_error is not None:
        print(f"CSV comparison report write failed: {report_error}", file=sys.stderr)
        return 2
    return 0 if report["overall_result"] == "pass" else 1
if __name__ == "__main__":
    # Use main()'s return code (0 pass, 1 comparison failure, 2 configuration error) as the exit status.
    sys.exit(main())