Add balanced doclift claim track

This commit is contained in:
welsberr 2026-05-08 02:26:54 -04:00
parent 169500369f
commit 86a4db01a4
3 changed files with 61 additions and 10 deletions

View File

@ -107,10 +107,11 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str |
adapter = DocliftBundleSourceAdapter() adapter = DocliftBundleSourceAdapter()
documents = {str(item.get("document_id")): item for item in manifest.get("documents", []) if isinstance(item, dict)} documents = {str(item.get("document_id")): item for item in manifest.get("documents", []) if isinstance(item, dict)}
strategies = ("conservative", "balanced", "broad")
per_document: list[dict[str, Any]] = [] per_document: list[dict[str, Any]] = []
aggregate: dict[str, dict[str, float]] = { aggregate: dict[str, dict[str, float]] = {
"conservative": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}, strategy: {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}
"broad": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}, for strategy in strategies
} }
for entry in benchmark.get("documents", []): for entry in benchmark.get("documents", []):
@ -118,7 +119,7 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str |
document = documents[document_id] document = documents[document_id]
gold_claims = [str(item).strip() for item in entry.get("gold_claims", []) if str(item).strip()] gold_claims = [str(item).strip() for item in entry.get("gold_claims", []) if str(item).strip()]
track_scores = [] track_scores = []
for strategy in ("conservative", "broad"): for strategy in strategies:
predicted_claims = adapter.extract_document_claims(base, document, strategy=strategy, limit=6) predicted_claims = adapter.extract_document_claims(base, document, strategy=strategy, limit=6)
score = _score_track(predicted_claims, gold_claims, strategy) score = _score_track(predicted_claims, gold_claims, strategy)
track_scores.append(score) track_scores.append(score)

View File

@ -26,6 +26,21 @@ class DocliftBundleSourceAdapter:
"[last update", "[last update",
"this essay has been transferred here", "this essay has been transferred here",
) )
_CLAIM_CUES = (
" is ",
" are ",
" can ",
" do ",
" does ",
" means ",
" requires ",
" due to ",
" part of evolution",
" fixed in a population",
" by natural selection",
" by random genetic drift",
" over time",
)
def _resolve_bundle_path(self, base: Path, value: str | Path | None) -> Path: def _resolve_bundle_path(self, base: Path, value: str | Path | None) -> Path:
if value is None: if value is None:
@ -88,11 +103,18 @@ class DocliftBundleSourceAdapter:
def _is_claim_candidate(self, cleaned: str, *, title: str = "", strategy: str = "conservative") -> bool: def _is_claim_candidate(self, cleaned: str, *, title: str = "", strategy: str = "conservative") -> bool:
lowered = cleaned.lower() lowered = cleaned.lower()
normalized_title = self._normalize_inline_text(title).lower() normalized_title = self._normalize_inline_text(title).lower()
min_length = 70 if strategy == "conservative" else 40 if strategy == "conservative":
min_length = 70
elif strategy == "balanced":
min_length = 50
else:
min_length = 40
if len(cleaned) < min_length: if len(cleaned) < min_length:
return False return False
if strategy == "conservative" and len(cleaned) > 360: if strategy == "conservative" and len(cleaned) > 360:
return False return False
if strategy == "balanced" and len(cleaned) > 320:
return False
if strategy == "broad" and len(cleaned) > 520: if strategy == "broad" and len(cleaned) > 520:
return False return False
if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES): if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
@ -101,8 +123,22 @@ class DocliftBundleSourceAdapter:
return False return False
if cleaned.count(" ") < 8: if cleaned.count(" ") < 8:
return False return False
if strategy in {"balanced", "conservative"} and cleaned[:1].islower():
return False
return True return True
def _claim_priority(self, cleaned: str, *, strategy: str = "conservative") -> tuple[int, int]:
lowered = cleaned.lower()
cue_hits = sum(1 for cue in self._CLAIM_CUES if cue in lowered)
penalties = 0
if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
penalties += 1
if '"' in cleaned:
penalties += 1
if strategy == "balanced":
return (cue_hits - penalties, -abs(len(cleaned) - 140))
return (cue_hits - penalties, -len(cleaned))
def _extract_claim_sentences_from_paragraphs( def _extract_claim_sentences_from_paragraphs(
self, self,
paragraphs: list[str], paragraphs: list[str],
@ -113,6 +149,7 @@ class DocliftBundleSourceAdapter:
) -> list[str]: ) -> list[str]:
claims: list[str] = [] claims: list[str] = []
seen: set[str] = set() seen: set[str] = set()
candidates: list[tuple[str, tuple[int, int]]] = []
for paragraph in paragraphs: for paragraph in paragraphs:
normalized_paragraph = self._normalize_inline_text(paragraph) normalized_paragraph = self._normalize_inline_text(paragraph)
if len(normalized_paragraph) < 80: if len(normalized_paragraph) < 80:
@ -132,9 +169,15 @@ class DocliftBundleSourceAdapter:
if lowered in seen: if lowered in seen:
continue continue
seen.add(lowered) seen.add(lowered)
if strategy == "balanced":
candidates.append((cleaned, self._claim_priority(cleaned, strategy=strategy)))
else:
claims.append(cleaned) claims.append(cleaned)
if len(claims) >= limit: if len(claims) >= limit:
return claims return claims
if strategy == "balanced":
ranked = sorted(candidates, key=lambda item: item[1], reverse=True)
return [item[0] for item in ranked[:limit]]
return claims return claims
def _extract_claim_sentences(self, markdown_text: str, *, title: str = "", limit: int = 4, strategy: str = "conservative") -> list[str]: def _extract_claim_sentences(self, markdown_text: str, *, title: str = "", limit: int = 4, strategy: str = "conservative") -> list[str]:
@ -170,7 +213,7 @@ class DocliftBundleSourceAdapter:
paragraphs, paragraphs,
title=title, title=title,
limit=limit, limit=limit,
strategy="conservative", strategy=strategy,
) )
def extract_document_claims( def extract_document_claims(

View File

@ -17,12 +17,14 @@ def test_doclift_claim_tournament_scores_two_tracks() -> None:
root = _fixture_root() root = _fixture_root()
result = evaluate_doclift_claim_tracks(root, root / "benchmark.json") result = evaluate_doclift_claim_tracks(root, root / "benchmark.json")
assert result["judge_summary"]["winner"] in {"conservative", "broad"} assert result["judge_summary"]["winner"] in {"conservative", "balanced", "broad"}
assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "broad"} assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "balanced", "broad"}
assert len(result["per_document"]) == 2 assert len(result["per_document"]) == 2
intro = next(item for item in result["per_document"] if item["document_id"] == "intro-essay") intro = next(item for item in result["per_document"] if item["document_id"] == "intro-essay")
assert len(intro["tracks"]) == 3
assert intro["tracks"][0]["predicted_claims"] assert intro["tracks"][0]["predicted_claims"]
assert intro["tracks"][1]["predicted_claims"] assert intro["tracks"][1]["predicted_claims"]
assert intro["tracks"][2]["predicted_claims"]
def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> None: def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> None:
@ -32,6 +34,7 @@ def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> No
assert tracks["broad"]["recall"] >= tracks["conservative"]["recall"] assert tracks["broad"]["recall"] >= tracks["conservative"]["recall"]
assert tracks["broad"]["matches"] >= tracks["conservative"]["matches"] assert tracks["broad"]["matches"] >= tracks["conservative"]["matches"]
assert tracks["balanced"]["precision"] >= tracks["conservative"]["precision"]
def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None: def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
@ -41,5 +44,9 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
assert len(result["per_document"]) == 2 assert len(result["per_document"]) == 2
assert tracks["conservative"]["gold_claims"] == 4 assert tracks["conservative"]["gold_claims"] == 4
assert tracks["balanced"]["gold_claims"] == 4
assert tracks["broad"]["gold_claims"] == 4 assert tracks["broad"]["gold_claims"] == 4
assert tracks["broad"]["matches"] >= 1 assert tracks["broad"]["matches"] >= 1
assert tracks["balanced"]["matches"] >= 1
assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]