update result
scripts/pipeline_compare.py (new file, 254 lines)
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
Compare hallucination across three pipelines over five preprocessing methods:
1) Raw: all images
2) DeQA-filtered: images with DeQA score >= threshold (default 2.6); the rest are
   kept but scored as hallucination=0
3) Human-filtered: images labeled High in the CSV labels; the rest are kept but
   scored as hallucination=0

Inputs:
- One or more per_sample_eval.json files (or per_image_anls.csv already generated)
- DeQA score file (txt): lines like "3.9 - image (9)_0.png"
- Human labels CSV with columns: filename,label where label in {High,Low}

Outputs:
- Combined means CSV: method vs mean hallucination for each pipeline
- Line chart (3 lines): mean hallucination per method across the three pipelines
"""

from __future__ import annotations

import argparse
import csv
import json
import re
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
import matplotlib.pyplot as plt


def canonical_key(name: str) -> str:
    """Map various filenames to the canonical key used by the per_sample_eval 'image' field.

    Examples:
    - "image (9)_0.png" -> "image (9)"
    - "image (22)" -> "image (22)"
    - "foo/bar/image (15)_3.jpg" -> "image (15)"
    - other names -> stem without extension
    """
    if not name:
        return name
    # Keep only the basename
    base = Path(name).name
    # Try the pattern "image (N)"
    m = re.search(r"(image \(\d+\))", base, flags=re.IGNORECASE)
    if m:
        return m.group(1)
    # Fallback: remove the extension
    return Path(base).stem


def read_deqa_scores(txt_path: Path) -> Dict[str, float]:
    """Parse DeQA score lines into {canonical image key: score}."""
    scores: Dict[str, float] = {}
    with txt_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Accept formats like "3.9 - filename" or "filename,3.9"
            m = re.match(r"\s*([0-9]+(?:\.[0-9]+)?)\s*[-,:]?\s*(.+)$", line)
            if m:
                score = float(m.group(1))
                filename = m.group(2)
            else:
                parts = re.split(r"[,\t]", line)
                if len(parts) >= 2:
                    try:
                        score = float(parts[1])
                        filename = parts[0]
                    except ValueError:
                        continue
                else:
                    continue
            key = canonical_key(filename)
            scores[key] = score
    return scores
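
# Example: with the format above, the line "3.9 - image (9)_0.png" yields
# the entry {"image (9)": 3.9} after canonical_key() strips the "_0.png" suffix.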


def read_human_labels(csv_path: Path) -> Dict[str, str]:
    """Read the human-label CSV into {canonical image key: label}."""
    labels: Dict[str, str] = {}
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            filename = (row.get("filename") or row.get("file") or "").strip()
            label = (row.get("label") or row.get("Label") or "").strip()
            if not filename:
                continue
            key = canonical_key(filename)
            if label:
                labels[key] = label
    return labels
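
# Example labels CSV (header row required by csv.DictReader; rows are illustrative):
#   filename,label
#   image (9)_0.png,High
#   image (22)_0.png,Low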


def levenshtein_distance(a: str, b: str) -> int:
    """Classic two-row dynamic-programming edit distance (insert/delete/substitute)."""
    if a == b:
        return 0
    if len(a) == 0:
        return len(b)
    if len(b) == 0:
        return len(a)
    previous_row = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current_row = [i]
        for j, cb in enumerate(b, start=1):
            insertions = previous_row[j] + 1
            deletions = current_row[j - 1] + 1
            substitutions = previous_row[j - 1] + (0 if ca == cb else 1)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]
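
# Example: levenshtein_distance("kitten", "sitting") == 3
# (two substitutions and one insertion).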


def normalized_similarity(pred: str, gt: str) -> float:
    """Normalized Levenshtein similarity in [0, 1]; the per-field term used for ANLS."""
    pred = pred or ""
    gt = gt or ""
    max_len = max(len(pred), len(gt))
    if max_len == 0:
        return 1.0
    dist = levenshtein_distance(pred, gt)
    sim = 1.0 - (dist / max_len)
    if sim < 0.0:
        return 0.0
    if sim > 1.0:
        return 1.0
    return sim
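
# Example: normalized_similarity("1234", "1235") == 1 - 1/4 == 0.75.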


def compute_anls_for_record(record: Dict) -> Tuple[float, int]:
    """Mean per-field similarity (ANLS) for one record, plus the field count."""
    fields = record.get("fields") or []
    if not isinstance(fields, list) or len(fields) == 0:
        return 0.0, 0
    sims: List[float] = []
    for f in fields:
        pred = str(f.get("pred", ""))
        gt = str(f.get("gt", ""))
        sims.append(normalized_similarity(pred, gt))
    anls = float(sum(sims) / len(sims)) if sims else 0.0
    return anls, len(sims)
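
# Assumed per_sample_eval.json record shape, inferred from the lookups above:
#   {"image": "image (9)", "fields": [{"pred": "...", "gt": "..."}, ...]}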


def load_per_image_anls(input_json: Path) -> pd.DataFrame:
    # Prefer an existing per_image_anls.csv; otherwise compute it quickly
    per_image_csv = input_json.parent / "per_image_anls.csv"
    if per_image_csv.exists():
        return pd.read_csv(per_image_csv)
    # Fallback: compute minimal ANLS as in the other script
    with input_json.open("r", encoding="utf-8") as f:
        data = json.load(f)
    rows = []
    for rec in data:
        anls, num_fields = compute_anls_for_record(rec)
        rows.append({
            "image": rec.get("image"),
            "anls": anls,
            "hallucination_score": 1.0 - anls,
            "num_fields": int(num_fields),
        })
    return pd.DataFrame(rows)


def main() -> None:
    p = argparse.ArgumentParser(
        description="Compare hallucination across raw/DeQA/Human pipelines over methods"
    )
    p.add_argument("inputs", nargs="+", help="per_sample_eval.json files for each method")
    p.add_argument("--deqa_txt", required=True, help="Path to DeQA scores txt (e.g., cni.txt)")
    p.add_argument("--human_csv", required=True, help="Path to human labels CSV")
    p.add_argument("--deqa_threshold", type=float, default=2.6, help="DeQA threshold (>=)")
    args = p.parse_args()

    # Load filters
    deqa_scores = read_deqa_scores(Path(args.deqa_txt))
    human_labels = read_human_labels(Path(args.human_csv))

    # Aggregate per method
    method_to_df: Dict[str, pd.DataFrame] = {}
    for ip in args.inputs:
        path = Path(ip)
        df = load_per_image_anls(path)
        df["method"] = path.parent.name
        df["image_key"] = df["image"].apply(canonical_key)
        method_to_df[path.parent.name] = df
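
    # Note: the parent directory name of each input doubles as the method name,
    # so inputs are expected to live one per method directory,
    # e.g., <method>/per_sample_eval.json (hypothetical layout).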

    # Compute means per pipeline (fair comparison: set excluded images to hallucination=0)
    records = []
    for method, df in method_to_df.items():
        raw_mean = float(df["hallucination_score"].mean()) if len(df) else float("nan")

        # DeQA filter: set hallucination=0 for images with DeQA < threshold
        # (or no DeQA score at all); keep all images
        df_deqa = df.copy()
        mask_deqa = df_deqa["image_key"].map(deqa_scores)  # missing keys become NaN
        df_deqa.loc[mask_deqa.isna() | (mask_deqa < args.deqa_threshold), "hallucination_score"] = 0.0
        deqa_mean = float(df_deqa["hallucination_score"].mean()) if len(df_deqa) else float("nan")

        # Human filter: set hallucination=0 for images labeled Low (or unlabeled); keep all images
        df_human = df.copy()
        mask_human = df_human["image_key"].map(lambda k: human_labels.get(k, "").lower())
        df_human.loc[mask_human != "high", "hallucination_score"] = 0.0
        human_mean = float(df_human["hallucination_score"].mean()) if len(df_human) else float("nan")

        records.append({
            "method": method,
            "raw_mean": raw_mean,
            "deqa_mean": deqa_mean,
            "human_mean": human_mean,
            "raw_count": int(len(df)),
            "deqa_count": int(len(df_deqa)),  # equal to raw_count, since rows are kept
            "human_count": int(len(df_human)),  # equal to raw_count, since rows are kept
        })
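
    # Design note: zeroing excluded images instead of dropping them keeps the
    # image count identical across pipelines, so the three means share one denominator.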

    outdir = Path(args.inputs[0]).parent.parent / "combined_anls" / "pipeline"
    outdir.mkdir(parents=True, exist_ok=True)
    out_csv = outdir / "pipeline_means.csv"
    means_df = pd.DataFrame(records).sort_values("method")
    means_df.to_csv(out_csv, index=False)

    # 3-line comparison plot over methods (narrow figure with score annotations)
    x = range(len(means_df))
    plt.figure(figsize=(7, 5))

    # Plot lines and add score annotations
    raw_vals = means_df["raw_mean"].values
    deqa_vals = means_df["deqa_mean"].values
    human_vals = means_df["human_mean"].values

    plt.plot(x, raw_vals, marker="o", label="Raw", linewidth=2, markersize=6)
    plt.plot(x, deqa_vals, marker="s", label=f"DeQA >= {args.deqa_threshold}", linewidth=2, markersize=6)
    plt.plot(x, human_vals, marker="^", label="Human High", linewidth=2, markersize=6)

    # Annotate each point with its score
    for i, (r, d, h) in enumerate(zip(raw_vals, deqa_vals, human_vals)):
        plt.annotate(f"{r:.3f}", (i, r), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)
        plt.annotate(f"{d:.3f}", (i, d), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)
        plt.annotate(f"{h:.3f}", (i, h), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)

    plt.xticks(list(x), means_df["method"].tolist(), rotation=25, ha="right")
    plt.ylabel("Mean hallucination (1 - ANLS)")
    plt.title("Pipeline comparison over preprocessing methods")
    plt.grid(axis="y", linestyle="--", alpha=0.3)
    plt.legend()
    plt.tight_layout()
    out_png = outdir / "pipeline_comparison.png"
    plt.savefig(out_png, dpi=160)
    plt.close()

    print(f"Saved: {out_csv}")
    print(f"Saved: {out_png}")


if __name__ == "__main__":
    main()