Add balanced doclift claim track
This commit is contained in:
parent
169500369f
commit
86a4db01a4
|
|
@ -107,10 +107,11 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str |
|
||||||
adapter = DocliftBundleSourceAdapter()
|
adapter = DocliftBundleSourceAdapter()
|
||||||
documents = {str(item.get("document_id")): item for item in manifest.get("documents", []) if isinstance(item, dict)}
|
documents = {str(item.get("document_id")): item for item in manifest.get("documents", []) if isinstance(item, dict)}
|
||||||
|
|
||||||
|
strategies = ("conservative", "balanced", "broad")
|
||||||
per_document: list[dict[str, Any]] = []
|
per_document: list[dict[str, Any]] = []
|
||||||
aggregate: dict[str, dict[str, float]] = {
|
aggregate: dict[str, dict[str, float]] = {
|
||||||
"conservative": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0},
|
strategy: {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}
|
||||||
"broad": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0},
|
for strategy in strategies
|
||||||
}
|
}
|
||||||
|
|
||||||
for entry in benchmark.get("documents", []):
|
for entry in benchmark.get("documents", []):
|
||||||
|
|
@ -118,7 +119,7 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str |
|
||||||
document = documents[document_id]
|
document = documents[document_id]
|
||||||
gold_claims = [str(item).strip() for item in entry.get("gold_claims", []) if str(item).strip()]
|
gold_claims = [str(item).strip() for item in entry.get("gold_claims", []) if str(item).strip()]
|
||||||
track_scores = []
|
track_scores = []
|
||||||
for strategy in ("conservative", "broad"):
|
for strategy in strategies:
|
||||||
predicted_claims = adapter.extract_document_claims(base, document, strategy=strategy, limit=6)
|
predicted_claims = adapter.extract_document_claims(base, document, strategy=strategy, limit=6)
|
||||||
score = _score_track(predicted_claims, gold_claims, strategy)
|
score = _score_track(predicted_claims, gold_claims, strategy)
|
||||||
track_scores.append(score)
|
track_scores.append(score)
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,21 @@ class DocliftBundleSourceAdapter:
|
||||||
"[last update",
|
"[last update",
|
||||||
"this essay has been transferred here",
|
"this essay has been transferred here",
|
||||||
)
|
)
|
||||||
|
# Lowercase substrings that _claim_priority looks for in the lowercased
# candidate sentence; each cue found adds one to that sentence's cue-hit count.
_CLAIM_CUES = (
|
||||||
|
" is ",
|
||||||
|
" are ",
|
||||||
|
" can ",
|
||||||
|
" do ",
|
||||||
|
" does ",
|
||||||
|
" means ",
|
||||||
|
" requires ",
|
||||||
|
" due to ",
|
||||||
|
" part of evolution",
|
||||||
|
" fixed in a population",
|
||||||
|
" by natural selection",
|
||||||
|
" by random genetic drift",
|
||||||
|
" over time",
|
||||||
|
)
|
||||||
|
|
||||||
def _resolve_bundle_path(self, base: Path, value: str | Path | None) -> Path:
|
def _resolve_bundle_path(self, base: Path, value: str | Path | None) -> Path:
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
@ -88,11 +103,18 @@ class DocliftBundleSourceAdapter:
|
||||||
def _is_claim_candidate(self, cleaned: str, *, title: str = "", strategy: str = "conservative") -> bool:
|
def _is_claim_candidate(self, cleaned: str, *, title: str = "", strategy: str = "conservative") -> bool:
|
||||||
lowered = cleaned.lower()
|
lowered = cleaned.lower()
|
||||||
normalized_title = self._normalize_inline_text(title).lower()
|
normalized_title = self._normalize_inline_text(title).lower()
|
||||||
min_length = 70 if strategy == "conservative" else 40
|
if strategy == "conservative":
|
||||||
|
min_length = 70
|
||||||
|
elif strategy == "balanced":
|
||||||
|
min_length = 50
|
||||||
|
else:
|
||||||
|
min_length = 40
|
||||||
if len(cleaned) < min_length:
|
if len(cleaned) < min_length:
|
||||||
return False
|
return False
|
||||||
if strategy == "conservative" and len(cleaned) > 360:
|
if strategy == "conservative" and len(cleaned) > 360:
|
||||||
return False
|
return False
|
||||||
|
if strategy == "balanced" and len(cleaned) > 320:
|
||||||
|
return False
|
||||||
if strategy == "broad" and len(cleaned) > 520:
|
if strategy == "broad" and len(cleaned) > 520:
|
||||||
return False
|
return False
|
||||||
if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
|
if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
|
||||||
|
|
@ -101,8 +123,22 @@ class DocliftBundleSourceAdapter:
|
||||||
return False
|
return False
|
||||||
if cleaned.count(" ") < 8:
|
if cleaned.count(" ") < 8:
|
||||||
return False
|
return False
|
||||||
|
if strategy in {"balanced", "conservative"} and cleaned[:1].islower():
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _claim_priority(self, cleaned: str, *, strategy: str = "conservative") -> tuple[int, int]:
|
||||||
|
lowered = cleaned.lower()
|
||||||
|
cue_hits = sum(1 for cue in self._CLAIM_CUES if cue in lowered)
|
||||||
|
penalties = 0
|
||||||
|
if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
|
||||||
|
penalties += 1
|
||||||
|
if '"' in cleaned:
|
||||||
|
penalties += 1
|
||||||
|
if strategy == "balanced":
|
||||||
|
return (cue_hits - penalties, -abs(len(cleaned) - 140))
|
||||||
|
return (cue_hits - penalties, -len(cleaned))
|
||||||
|
|
||||||
def _extract_claim_sentences_from_paragraphs(
|
def _extract_claim_sentences_from_paragraphs(
|
||||||
self,
|
self,
|
||||||
paragraphs: list[str],
|
paragraphs: list[str],
|
||||||
|
|
@ -113,6 +149,7 @@ class DocliftBundleSourceAdapter:
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
claims: list[str] = []
|
claims: list[str] = []
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
|
candidates: list[tuple[str, tuple[int, int]]] = []
|
||||||
for paragraph in paragraphs:
|
for paragraph in paragraphs:
|
||||||
normalized_paragraph = self._normalize_inline_text(paragraph)
|
normalized_paragraph = self._normalize_inline_text(paragraph)
|
||||||
if len(normalized_paragraph) < 80:
|
if len(normalized_paragraph) < 80:
|
||||||
|
|
@ -132,9 +169,15 @@ class DocliftBundleSourceAdapter:
|
||||||
if lowered in seen:
|
if lowered in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(lowered)
|
seen.add(lowered)
|
||||||
claims.append(cleaned)
|
if strategy == "balanced":
|
||||||
if len(claims) >= limit:
|
candidates.append((cleaned, self._claim_priority(cleaned, strategy=strategy)))
|
||||||
return claims
|
else:
|
||||||
|
claims.append(cleaned)
|
||||||
|
if len(claims) >= limit:
|
||||||
|
return claims
|
||||||
|
if strategy == "balanced":
|
||||||
|
ranked = sorted(candidates, key=lambda item: item[1], reverse=True)
|
||||||
|
return [item[0] for item in ranked[:limit]]
|
||||||
return claims
|
return claims
|
||||||
|
|
||||||
def _extract_claim_sentences(self, markdown_text: str, *, title: str = "", limit: int = 4, strategy: str = "conservative") -> list[str]:
|
def _extract_claim_sentences(self, markdown_text: str, *, title: str = "", limit: int = 4, strategy: str = "conservative") -> list[str]:
|
||||||
|
|
@ -170,7 +213,7 @@ class DocliftBundleSourceAdapter:
|
||||||
paragraphs,
|
paragraphs,
|
||||||
title=title,
|
title=title,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
strategy="conservative",
|
strategy=strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def extract_document_claims(
|
def extract_document_claims(
|
||||||
|
|
|
||||||
|
|
@ -17,12 +17,14 @@ def test_doclift_claim_tournament_scores_two_tracks() -> None:
|
||||||
root = _fixture_root()
|
root = _fixture_root()
|
||||||
result = evaluate_doclift_claim_tracks(root, root / "benchmark.json")
|
result = evaluate_doclift_claim_tracks(root, root / "benchmark.json")
|
||||||
|
|
||||||
assert result["judge_summary"]["winner"] in {"conservative", "broad"}
|
assert result["judge_summary"]["winner"] in {"conservative", "balanced", "broad"}
|
||||||
assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "broad"}
|
assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "balanced", "broad"}
|
||||||
assert len(result["per_document"]) == 2
|
assert len(result["per_document"]) == 2
|
||||||
intro = next(item for item in result["per_document"] if item["document_id"] == "intro-essay")
|
intro = next(item for item in result["per_document"] if item["document_id"] == "intro-essay")
|
||||||
|
assert len(intro["tracks"]) == 3
|
||||||
assert intro["tracks"][0]["predicted_claims"]
|
assert intro["tracks"][0]["predicted_claims"]
|
||||||
assert intro["tracks"][1]["predicted_claims"]
|
assert intro["tracks"][1]["predicted_claims"]
|
||||||
|
assert intro["tracks"][2]["predicted_claims"]
|
||||||
|
|
||||||
|
|
||||||
def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> None:
|
def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> None:
|
||||||
|
|
@ -32,6 +34,7 @@ def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> No
|
||||||
|
|
||||||
assert tracks["broad"]["recall"] >= tracks["conservative"]["recall"]
|
assert tracks["broad"]["recall"] >= tracks["conservative"]["recall"]
|
||||||
assert tracks["broad"]["matches"] >= tracks["conservative"]["matches"]
|
assert tracks["broad"]["matches"] >= tracks["conservative"]["matches"]
|
||||||
|
assert tracks["balanced"]["precision"] >= tracks["conservative"]["precision"]
|
||||||
|
|
||||||
|
|
||||||
def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
|
def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
|
||||||
|
|
@ -41,5 +44,9 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
|
||||||
|
|
||||||
assert len(result["per_document"]) == 2
|
assert len(result["per_document"]) == 2
|
||||||
assert tracks["conservative"]["gold_claims"] == 4
|
assert tracks["conservative"]["gold_claims"] == 4
|
||||||
|
assert tracks["balanced"]["gold_claims"] == 4
|
||||||
assert tracks["broad"]["gold_claims"] == 4
|
assert tracks["broad"]["gold_claims"] == 4
|
||||||
assert tracks["broad"]["matches"] >= 1
|
assert tracks["broad"]["matches"] >= 1
|
||||||
|
assert tracks["balanced"]["matches"] >= 1
|
||||||
|
assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
|
||||||
|
assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue