IQA-Metric-Benchmark/scripts/compute_hallucination_from_anls.py
#!/usr/bin/env python3
"""
Compute ANLS and hallucination scores from per-sample evaluation JSONs and plot results.
Inputs: one or more JSON files with schema like:
[
{
"image": "image (22)",
"num_pred_fields": 10,
"num_gt_fields": 12,
"num_correct": 9,
"all_correct": false,
"fields": [
{"field": "address", "pred": "...", "gt": "...", "correct": true},
...
]
},
...
]
ANLS definition: average normalized Levenshtein similarity over fields present in ground truth.
Here we approximate per-field similarity as:
sim = 1 - (levenshtein_distance(pred, gt) / max(len(pred), len(gt)))
clipped into [0, 1], and treat empty max length as exact match (1.0).
Per-image ANLS is the mean of field similarities for that image. Hallucination is 1 - ANLS.
Outputs:
- CSV per input JSON placed next to it: per_image_anls.csv with columns [image, anls, hallucination_score, num_fields]
- PNG bar chart per input JSON: hallucination_per_image.png with mean line and title.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
import matplotlib.pyplot as plt
def levenshtein_distance(a: str, b: str) -> int:
"""Compute Levenshtein distance between two strings (iterative DP)."""
if a == b:
return 0
if len(a) == 0:
return len(b)
if len(b) == 0:
return len(a)
previous_row = list(range(len(b) + 1))
for i, ca in enumerate(a, start=1):
current_row = [i]
for j, cb in enumerate(b, start=1):
insertions = previous_row[j] + 1
deletions = current_row[j - 1] + 1
substitutions = previous_row[j - 1] + (0 if ca == cb else 1)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
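# Example: levenshtein_distance("kitten", "sitting") returns 3 (two substitutions and one insertion).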
def normalized_similarity(pred: str, gt: str) -> float:
"""Return 1 - normalized edit distance in [0, 1]."""
pred = pred or ""
gt = gt or ""
max_len = max(len(pred), len(gt))
if max_len == 0:
return 1.0
dist = levenshtein_distance(pred, gt)
sim = 1.0 - (dist / max_len)
if sim < 0.0:
return 0.0
if sim > 1.0:
return 1.0
return sim
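# Example: normalized_similarity("12 Main St", "12 Main Street") = 1 - 4/14 ≈ 0.714.
# Note: the ANLS used in DocVQA-style benchmarks additionally applies a 0.5 threshold that zeroes
# low similarities; this script keeps the raw clipped similarity, as described in the module docstring.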
def compute_anls_for_record(record: Dict) -> Tuple[float, int]:
"""Compute ANLS and number of fields for a single record object."""
fields = record.get("fields") or []
if not isinstance(fields, list) or len(fields) == 0:
return 0.0, 0
sims: List[float] = []
for f in fields:
pred = str(f.get("pred", ""))
gt = str(f.get("gt", ""))
sims.append(normalized_similarity(pred, gt))
anls = float(sum(sims) / len(sims)) if sims else 0.0
return anls, len(sims)
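# Example: compute_anls_for_record({"fields": [{"pred": "cat", "gt": "cart"},
#                                              {"pred": "42", "gt": "42"}]}) returns (0.875, 2).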
def process_json(json_path: Path) -> Path:
    """Score one per-sample JSON: write per_image_anls.csv and a hallucination bar chart next to it, and return the CSV path."""
with json_path.open("r", encoding="utf-8") as f:
data = json.load(f)
rows = []
for rec in data:
image_name = rec.get("image")
anls, num_fields = compute_anls_for_record(rec)
hallucination = 1.0 - anls
rows.append({
"image": image_name,
"anls": anls,
"hallucination_score": hallucination,
"num_fields": int(num_fields),
})
df = pd.DataFrame(rows)
out_csv = json_path.parent / "per_image_anls.csv"
df.to_csv(out_csv, index=False)
# Plot hallucination bar chart with mean line
if len(df) > 0:
sorted_df = df.sort_values("hallucination_score", ascending=False).reset_index(drop=True)
plt.figure(figsize=(max(8, len(sorted_df) * 0.12), 5))
plt.bar(range(len(sorted_df)), sorted_df["hallucination_score"].values, color="#1f77b4")
mean_val = float(sorted_df["hallucination_score"].mean())
plt.axhline(mean_val, color="red", linestyle="--", label=f"Mean={mean_val:.3f}")
plt.xlabel("Image (sorted by hallucination)")
plt.ylabel("Hallucination = 1 - ANLS")
plt.title(f"Hallucination per image: {json_path.parent.name}")
plt.legend()
plt.tight_layout()
out_png = json_path.parent / "hallucination_per_image.png"
plt.savefig(out_png, dpi=150)
plt.close()
return out_csv
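# Example (hypothetical path): process_json(Path("results/methodA/per_sample_eval.json")) writes
# results/methodA/per_image_anls.csv (and, when there is data, hallucination_per_image.png), and
# returns the CSV path.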
def common_parent(paths: List[Path]) -> Path:
    """Return the deepest common ancestor of the given paths, or the current directory if there is none."""
    if not paths:
        return Path.cwd()
    parts = list(Path(paths[0]).resolve().parts)
for i in range(1, len(paths)):
other_parts = list(Path(paths[i]).resolve().parts)
# shrink parts to common prefix
new_parts: List[str] = []
for a, b in zip(parts, other_parts):
if a == b:
new_parts.append(a)
else:
break
parts = new_parts
if not parts:
return Path.cwd()
return Path(*parts)
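# Note: for resolved absolute paths this behaves like os.path.commonpath(), except that unrelated
# inputs (e.g. on different drives) fall back to the current working directory instead of raising.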
def main() -> None:
    """Parse arguments, process each input JSON, then build the combined CSV and comparison plots."""
parser = argparse.ArgumentParser(description="Compute ANLS and hallucination from per-sample JSONs and plot results.")
parser.add_argument("inputs", nargs="+", help="Paths to per_sample_eval.json files")
args = parser.parse_args()
any_error = False
combined_rows: List[Dict] = []
input_paths: List[Path] = []
for in_path_str in args.inputs:
path = Path(in_path_str)
if not path.exists():
print(f"[WARN] File does not exist: {path}", file=sys.stderr)
any_error = True
continue
try:
out_csv = process_json(path)
print(f"Processed: {path} -> {out_csv}")
# Load just-written CSV to aggregate and tag method
df = pd.read_csv(out_csv)
method_name = path.parent.name
df["method"] = method_name
combined_rows.extend(df.to_dict(orient="records"))
input_paths.append(path)
except Exception as exc:
print(f"[ERROR] Failed to process {path}: {exc}", file=sys.stderr)
any_error = True
    # Create combined outputs if at least one input was processed successfully
if combined_rows:
combo_df = pd.DataFrame(combined_rows)
# Reorder columns
cols = ["image", "method", "anls", "hallucination_score", "num_fields"]
combo_df = combo_df[cols]
        # Use the parent directories so a single input still resolves to its own folder
        base_outdir = common_parent([p.parent for p in input_paths])
combined_dir = base_outdir / "combined_anls"
combined_dir.mkdir(parents=True, exist_ok=True)
combined_csv = combined_dir / "combined_per_image_anls.csv"
combo_df.to_csv(combined_csv, index=False)
# Mean hallucination per method (bar chart)
means = combo_df.groupby("method")["hallucination_score"].mean().sort_values(ascending=False)
stds = combo_df.groupby("method")["hallucination_score"].std().reindex(means.index)
plt.figure(figsize=(max(6, len(means) * 1.2), 5))
plt.bar(means.index, means.values, yerr=stds.values, capsize=4, color="#2ca02c")
overall_mean = float(combo_df["hallucination_score"].mean())
plt.axhline(overall_mean, color="red", linestyle="--", label=f"Overall mean={overall_mean:.3f}")
plt.ylabel("Mean hallucination (1 - ANLS)")
plt.title("Mean hallucination by method")
plt.xticks(rotation=20, ha="right")
plt.legend()
plt.tight_layout()
bar_png = combined_dir / "mean_hallucination_by_method.png"
plt.savefig(bar_png, dpi=160)
plt.close()
# Heatmap: images x methods (hallucination)
pivot = combo_df.pivot_table(index="image", columns="method", values="hallucination_score", aggfunc="mean")
# Sort images by average hallucination descending for readability
pivot = pivot.reindex(pivot.mean(axis=1).sort_values(ascending=False).index)
plt.figure(figsize=(max(8, len(pivot.columns) * 1.0), max(6, len(pivot.index) * 0.25)))
im = plt.imshow(pivot.values, aspect="auto", cmap="viridis")
plt.colorbar(im, label="Hallucination (1 - ANLS)")
plt.xticks(range(len(pivot.columns)), pivot.columns, rotation=30, ha="right")
plt.yticks(range(len(pivot.index)), pivot.index)
plt.title("Hallucination per image across methods")
plt.tight_layout()
heatmap_png = combined_dir / "hallucination_heatmap.png"
plt.savefig(heatmap_png, dpi=160)
plt.close()
print(f"Combined CSV: {combined_csv}")
print(f"Saved: {bar_png}")
print(f"Saved: {heatmap_png}")
# Line chart: 1 line per method over images, hide image names
# Use same image order as pivot
methods = list(pivot.columns)
x = list(range(len(pivot.index)))
plt.figure(figsize=(max(10, len(x) * 0.12), 5))
colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd'])
for idx, method in enumerate(methods):
y = pivot[method].to_numpy()
plt.plot(x, y, label=method, linewidth=1.8, color=colors[idx % len(colors)])
plt.ylim(0.0, 1.0)
plt.xlabel("Images (sorted by overall hallucination)")
plt.ylabel("Hallucination (1 - ANLS)")
plt.title("Hallucination across images by method")
plt.xticks([], []) # hide image names
# Mean note box
mean_lines = []
for method in methods:
m = float(combo_df[combo_df["method"] == method]["hallucination_score"].mean())
mean_lines.append(f"{method}: {m:.3f}")
text = "\n".join(mean_lines)
plt.gca().text(0.99, 0.01, text, transform=plt.gca().transAxes,
fontsize=9, va='bottom', ha='right',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='gray'))
plt.legend(loc="upper right", ncol=min(3, len(methods)))
plt.tight_layout()
line_png = combined_dir / "hallucination_lines_by_method.png"
plt.savefig(line_png, dpi=160)
plt.close()
# Grouped-by-image interlocking line chart with image labels
# Build a consistent x position per image, with small offsets per method
base_x = list(range(len(pivot.index)))
offsets = {
m: ((i - (len(methods) - 1) / 2) * 0.12) for i, m in enumerate(methods)
}
        # Cap the figure width: scale with the number of images but limit it to avoid overly wide plots
width = min(16, max(10, len(base_x) * 0.12))
plt.figure(figsize=(width, 6))
for idx, method in enumerate(methods):
            # Fill missing values with 0.0 so every method has a point at each image position
y = pivot[method].fillna(0.0).to_numpy()
x_shifted = [bx + offsets[method] for bx in base_x]
plt.plot(x_shifted, y, label=method, linewidth=1.8, marker='o', markersize=3,
color=colors[idx % len(colors)])
plt.ylim(0.0, 1.0)
plt.xlim(-0.5, len(base_x) - 0.5)
# Hide image names; keep index ticks sparse for readability
plt.xticks([], [])
plt.xlabel("Images (index)")
plt.ylabel("Hallucination (1 - ANLS)")
plt.title("Hallucination by image (interlocked methods)")
plt.grid(axis='y', linestyle='--', alpha=0.3)
# Add box with per-method mean
text2 = "\n".join([f"{m}: {float(combo_df[combo_df['method']==m]['hallucination_score'].mean()):.3f}" for m in methods])
plt.gca().text(0.99, 0.01, text2, transform=plt.gca().transAxes,
fontsize=9, va='bottom', ha='right',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='gray'))
plt.legend(loc='upper right', ncol=min(3, len(methods)))
plt.tight_layout()
group_line_png = combined_dir / "hallucination_interlocked_by_image.png"
plt.savefig(group_line_png, dpi=160)
plt.close()
print(f"Combined CSV: {combined_csv}")
print(f"Saved: {bar_png}")
print(f"Saved: {heatmap_png}")
print(f"Saved: {line_png}")
print(f"Saved: {group_line_png}")
if any_error:
sys.exit(1)
if __name__ == "__main__":
main()