#!/usr/bin/env python3
|
|
"""
|
|
Compare hallucination across three pipelines over five preprocessing methods:
|
|
1) Raw: all images
|
|
2) DeQA-filtered: keep images with DeQA score >= threshold (default 2.6)
|
|
3) Human-filtered: keep images labeled High in CSV labels
|
|
|
|
Inputs:
|
|
- One or more per_sample_eval.json files (or per_image_anls.csv already generated)
|
|
- DeQA score file (txt): lines like "3.9 - image (9)_0.png"
|
|
- Human labels CSV with columns: filename,label where label in {High,Low}
|
|
|
|
Outputs:
|
|
- Combined means CSV: method vs mean hallucination for each pipeline
|
|
- Line chart (3 lines): hallucination mean per method across the three pipelines
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple, Optional
|
|
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def canonical_key(name: str) -> str:
    """Normalize a filename to the canonical key used by per_sample_eval's 'image' field.

    Examples:
    - "image (9)_0.png"          -> "image (9)"
    - "image (22)"               -> "image (22)"
    - "foo/bar/image (15)_3.jpg" -> "image (15)"
    - anything else              -> basename stem (extension removed)
    """
    if not name:
        return name
    # Directories never matter for matching; work on the basename only.
    basename = Path(name).name
    match = re.search(r"(image \(\d+\))", basename, flags=re.IGNORECASE)
    # Either the "image (N)" token as written, or the stem as a fallback.
    return match.group(1) if match else Path(basename).stem
|
|
|
|
|
|
def read_deqa_scores(txt_path: Path) -> Dict[str, float]:
    """Parse a DeQA score file into {canonical image key: score}.

    Accepted line formats:
    - "3.9 - image (9)_0.png"              (score first, then separator, then filename)
    - "filename,3.9" or "filename<TAB>3.9" (filename first)
    Blank and unparseable lines are skipped; later entries win on key collisions.
    """
    result: Dict[str, float] = {}
    with txt_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            # Accept formats: "3.9 - filename" or "filename,3.9" etc.
            matched = re.match(r"\s*([0-9]+(?:\.[0-9]+)?)\s*[-,:]?\s*(.+)$", text)
            if matched:
                value = float(matched.group(1))
                fname = matched.group(2)
            else:
                pieces = re.split(r"[,\t]", text)
                if len(pieces) < 2:
                    continue
                try:
                    value = float(pieces[1])
                    fname = pieces[0]
                except Exception:
                    continue
            result[canonical_key(fname)] = value
    return result
|
|
|
|
|
|
def read_human_labels(csv_path: Path) -> Dict[str, str]:
    """Load human quality labels from CSV into {canonical image key: label}.

    The header may use "filename" or "file" for the image column and
    "label" or "Label" for the label column. Rows without a filename or
    with an empty label are ignored.
    """
    out: Dict[str, str] = {}
    with csv_path.open("r", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            fname = (row.get("filename") or row.get("file") or "").strip()
            if not fname:
                continue
            tag = (row.get("label") or row.get("Label") or "").strip()
            if tag:
                out[canonical_key(fname)] = tag
    return out
|
|
|
|
|
|
def levenshtein_distance(a: str, b: str) -> int:
    """Return the edit distance between a and b (unit-cost insert/delete/substitute)."""
    if a == b:
        return 0
    if not a:
        return len(b)
    if not b:
        return len(a)
    # Single rolling row: after processing a[:i], row[j] is dist(a[:i], b[:j]).
    row = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        prev_diag = row[0]  # dist(a[:i-1], b[:j-1]) for the current column
        row[0] = i
        for j, ch_b in enumerate(b, start=1):
            # RHS reads the old row[j] (insertion) and old diagonal before
            # row[j] is overwritten with the new value.
            prev_diag, row[j] = row[j], min(
                row[j] + 1,                      # insertion
                row[j - 1] + 1,                  # deletion
                prev_diag + (ch_a != ch_b),      # substitution (0 if equal)
            )
    return row[-1]
|
|
|
|
|
|
def normalized_similarity(pred: str, gt: str) -> float:
    """Return 1 - normalized edit distance between pred and gt, clamped to [0, 1].

    None inputs are treated as empty strings; two empty strings score 1.0.
    """
    pred = pred or ""
    gt = gt or ""
    longer = max(len(pred), len(gt))
    if longer == 0:
        return 1.0
    score = 1.0 - levenshtein_distance(pred, gt) / longer
    # Defensive clamp; the distance never exceeds the longer length.
    return min(1.0, max(0.0, score))
|
|
|
|
|
|
def compute_anls_for_record(record: Dict) -> tuple[float, int]:
    """Average ANLS over a record's "fields" entries.

    Each field contributes the normalized similarity of its "pred" vs "gt"
    strings. Returns (mean similarity, field count); (0.0, 0) when the
    record has no usable fields list.
    """
    fields = record.get("fields") or []
    if not isinstance(fields, list) or not fields:
        return 0.0, 0
    sims: List[float] = [
        normalized_similarity(str(entry.get("pred", "")), str(entry.get("gt", "")))
        for entry in fields
    ]
    return float(sum(sims) / len(sims)), len(sims)
|
|
|
|
|
|
def load_per_image_anls(input_json: Path) -> pd.DataFrame:
    """Load per-image ANLS rows for one method's per_sample_eval.json.

    Prefers a sibling per_image_anls.csv when it already exists; otherwise
    recomputes ANLS and hallucination (1 - ANLS) per record from the JSON.
    """
    cached_csv = input_json.parent / "per_image_anls.csv"
    if cached_csv.exists():
        return pd.read_csv(cached_csv)
    # Fallback: compute minimal ANLS like in the other script.
    with input_json.open("r", encoding="utf-8") as handle:
        records = json.load(handle)
    rows = []
    for record in records:
        anls, field_count = compute_anls_for_record(record)
        rows.append(
            {
                "image": record.get("image"),
                "anls": anls,
                "hallucination_score": 1.0 - anls,
                "num_fields": int(field_count),
            }
        )
    return pd.DataFrame(rows)
|
|
|
|
|
|
def main() -> None:
    """CLI entry point.

    Reads one per_sample_eval.json per preprocessing method plus the DeQA
    score file and human label CSV, computes the mean hallucination score
    per method under three pipelines (raw / DeQA-filtered / human-filtered),
    then writes a combined means CSV and a 3-line comparison chart.
    """
    p = argparse.ArgumentParser(description="Compare hallucination across raw/DeQA/Human pipelines over methods")
    p.add_argument("inputs", nargs="+", help="per_sample_eval.json files for each method")
    p.add_argument("--deqa_txt", required=True, help="Path to DeQA scores txt (e.g., cni.txt)")
    p.add_argument("--human_csv", required=True, help="Path to human labels CSV")
    p.add_argument("--deqa_threshold", type=float, default=2.6, help="DeQA threshold (>=)")
    args = p.parse_args()

    # Load filters
    deqa_scores = read_deqa_scores(Path(args.deqa_txt))
    human_labels = read_human_labels(Path(args.human_csv))

    # Aggregate per method.
    # NOTE(review): the method name is the parent directory of each input;
    # two inputs that share a parent name silently overwrite each other here.
    method_to_df: Dict[str, pd.DataFrame] = {}
    for ip in args.inputs:
        path = Path(ip)
        df = load_per_image_anls(path)
        df["method"] = path.parent.name
        df["image_key"] = df["image"].apply(canonical_key)
        method_to_df[path.parent.name] = df

    # Compute means per pipeline (fair comparison: set excluded images to hallucination=0)
    records = []
    for method, df in method_to_df.items():
        raw_mean = float(df["hallucination_score"].mean()) if len(df) else float("nan")

        # DeQA filter: mark DeQA < threshold as hallucination=0, keep all images
        df_deqa = df.copy()
        mask_deqa = df_deqa["image_key"].map(lambda k: deqa_scores.get(k, None))
        # Set hallucination=0 for images with DeQA < threshold (or missing DeQA).
        # Missing scores come back as None — presumably coerced to NaN by
        # pandas, in which case isna() catches them and the < comparison is
        # False for NaN; TODO confirm on an object-dtype column.
        df_deqa.loc[mask_deqa.isna() | (mask_deqa < args.deqa_threshold), "hallucination_score"] = 0.0
        deqa_mean = float(df_deqa["hallucination_score"].mean()) if len(df_deqa) else float("nan")

        # Human filter: mark Low labels as hallucination=0, keep all images
        df_human = df.copy()
        mask_human = df_human["image_key"].map(lambda k: human_labels.get(k, "").lower())
        # Set hallucination=0 for images labeled Low (or missing label)
        df_human.loc[mask_human != "high", "hallucination_score"] = 0.0
        human_mean = float(df_human["hallucination_score"].mean()) if len(df_human) else float("nan")

        records.append({
            "method": method,
            "raw_mean": raw_mean,
            "deqa_mean": deqa_mean,
            "human_mean": human_mean,
            "raw_count": int(len(df)),
            "deqa_count": int(len(df_deqa)),  # Now equal to raw_count
            "human_count": int(len(df_human)),  # Now equal to raw_count
        })

    # Outputs go beside the first input, under <root>/combined_anls/pipeline/.
    outdir = Path(args.inputs[0]).parent.parent / "combined_anls" / "pipeline"
    outdir.mkdir(parents=True, exist_ok=True)
    out_csv = outdir / "pipeline_means.csv"
    means_df = pd.DataFrame(records).sort_values("method")
    means_df.to_csv(out_csv, index=False)

    # 3-line comparison plot over methods (narrower with score annotations)
    x = range(len(means_df))
    plt.figure(figsize=(7, 5))

    # Plot lines and add score annotations
    raw_vals = means_df["raw_mean"].values
    deqa_vals = means_df["deqa_mean"].values
    human_vals = means_df["human_mean"].values

    plt.plot(x, raw_vals, marker="o", label="Raw", linewidth=2, markersize=6)
    plt.plot(x, deqa_vals, marker="s", label=f"DeQA >= {args.deqa_threshold}", linewidth=2, markersize=6)
    plt.plot(x, human_vals, marker="^", label="Human High", linewidth=2, markersize=6)

    # Annotate each point with its score.
    # NOTE(review): all three annotations use the same (0, 8) offset, so the
    # labels overlap wherever the lines sit close together.
    for i, (r, d, h) in enumerate(zip(raw_vals, deqa_vals, human_vals)):
        plt.annotate(f"{r:.3f}", (i, r), textcoords="offset points", xytext=(0,8), ha='center', fontsize=8)
        plt.annotate(f"{d:.3f}", (i, d), textcoords="offset points", xytext=(0,8), ha='center', fontsize=8)
        plt.annotate(f"{h:.3f}", (i, h), textcoords="offset points", xytext=(0,8), ha='center', fontsize=8)

    plt.xticks(list(x), means_df["method"].tolist(), rotation=25, ha="right")
    plt.ylabel("Mean hallucination (1 - ANLS)")
    plt.title("Pipeline comparison over preprocessing methods")
    plt.grid(axis="y", linestyle="--", alpha=0.3)
    plt.legend()
    plt.tight_layout()
    out_png = outdir / "pipeline_comparison.png"
    plt.savefig(out_png, dpi=160)
    plt.close()

    print(f"Saved: {out_csv}")
    print(f"Saved: {out_png}")
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|