diff --git a/src/groundrecall/doclift_claim_tournament.py b/src/groundrecall/doclift_claim_tournament.py index 8e0226e..d581759 100644 --- a/src/groundrecall/doclift_claim_tournament.py +++ b/src/groundrecall/doclift_claim_tournament.py @@ -107,10 +107,11 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str | adapter = DocliftBundleSourceAdapter() documents = {str(item.get("document_id")): item for item in manifest.get("documents", []) if isinstance(item, dict)} + strategies = ("conservative", "balanced", "broad") per_document: list[dict[str, Any]] = [] aggregate: dict[str, dict[str, float]] = { - "conservative": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}, - "broad": {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0}, + strategy: {"matches": 0.0, "predicted": 0.0, "gold": 0.0, "meta_noise": 0.0} + for strategy in strategies } for entry in benchmark.get("documents", []): @@ -118,7 +119,7 @@ def evaluate_doclift_claim_tracks(bundle_root: str | Path, benchmark_path: str | document = documents[document_id] gold_claims = [str(item).strip() for item in entry.get("gold_claims", []) if str(item).strip()] track_scores = [] - for strategy in ("conservative", "broad"): + for strategy in strategies: predicted_claims = adapter.extract_document_claims(base, document, strategy=strategy, limit=6) score = _score_track(predicted_claims, gold_claims, strategy) track_scores.append(score) diff --git a/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py b/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py index 1d71da3..051155c 100755 --- a/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py +++ b/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py @@ -26,6 +26,21 @@ class DocliftBundleSourceAdapter: "[last update", "this essay has been transferred here", ) + _CLAIM_CUES = ( + " is ", + " are ", + " can ", + " do ", + " does ", + " means ", + " requires ", + " due to ", + " part of evolution", + " fixed in a population", + " by natural selection", + " by random genetic drift", + " over time", + ) def _resolve_bundle_path(self, base: Path, value: str | Path | None) -> Path: if value is None: @@ -88,11 +103,18 @@ class DocliftBundleSourceAdapter: def _is_claim_candidate(self, cleaned: str, *, title: str = "", strategy: str = "conservative") -> bool: lowered = cleaned.lower() normalized_title = self._normalize_inline_text(title).lower() - min_length = 70 if strategy == "conservative" else 40 + if strategy == "conservative": + min_length = 70 + elif strategy == "balanced": + min_length = 50 + else: + min_length = 40 if len(cleaned) < min_length: return False if strategy == "conservative" and len(cleaned) > 360: return False + if strategy == "balanced" and len(cleaned) > 320: + return False if strategy == "broad" and len(cleaned) > 520: return False if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES): @@ -101,8 +123,22 @@ class DocliftBundleSourceAdapter: return False if cleaned.count(" ") < 8: return False + if strategy in {"balanced", "conservative"} and cleaned[:1].islower(): + return False return True + def _claim_priority(self, cleaned: str, *, strategy: str = "conservative") -> tuple[int, int]: + lowered = cleaned.lower() + cue_hits = sum(1 for cue in self._CLAIM_CUES if cue in lowered) + penalties = 0 + if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")): + penalties += 1 + if '"' in cleaned: + penalties += 1 + if strategy == "balanced": + return (cue_hits - penalties, -abs(len(cleaned) - 140)) + return (cue_hits - penalties, -len(cleaned)) + def _extract_claim_sentences_from_paragraphs( self, paragraphs: list[str], @@ -113,6 +149,7 @@ class DocliftBundleSourceAdapter: ) -> list[str]: claims: list[str] = [] seen: set[str] = set() + candidates: list[tuple[str, tuple[int, int]]] = [] for paragraph in paragraphs: normalized_paragraph = self._normalize_inline_text(paragraph) if len(normalized_paragraph) < 80: @@ -132,9 +169,15 @@ class DocliftBundleSourceAdapter: if lowered in seen: continue seen.add(lowered) - claims.append(cleaned) - if len(claims) >= limit: - return claims + if strategy == "balanced": + candidates.append((cleaned, self._claim_priority(cleaned, strategy=strategy))) + else: + claims.append(cleaned) + if len(claims) >= limit: + return claims + if strategy == "balanced": + ranked = sorted(candidates, key=lambda item: item[1], reverse=True) + return [item[0] for item in ranked[:limit]] return claims def _extract_claim_sentences(self, markdown_text: str, *, title: str = "", limit: int = 4, strategy: str = "conservative") -> list[str]: @@ -170,7 +213,7 @@ class DocliftBundleSourceAdapter: paragraphs, title=title, limit=limit, - strategy="conservative", + strategy=strategy, ) def extract_document_claims( diff --git a/tests/test_doclift_claim_tournament.py b/tests/test_doclift_claim_tournament.py index 271aa8d..957c93a 100644 --- a/tests/test_doclift_claim_tournament.py +++ b/tests/test_doclift_claim_tournament.py @@ -17,12 +17,14 @@ def test_doclift_claim_tournament_scores_two_tracks() -> None: root = _fixture_root() result = evaluate_doclift_claim_tracks(root, root / "benchmark.json") - assert result["judge_summary"]["winner"] in {"conservative", "broad"} - assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "broad"} + assert result["judge_summary"]["winner"] in {"conservative", "balanced", "broad"} + assert set(result["judge_summary"]["tracks"].keys()) == {"conservative", "balanced", "broad"} assert len(result["per_document"]) == 2 intro = next(item for item in result["per_document"] if item["document_id"] == "intro-essay") + assert len(intro["tracks"]) == 3 assert intro["tracks"][0]["predicted_claims"] assert intro["tracks"][1]["predicted_claims"] + assert intro["tracks"][2]["predicted_claims"] def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> None: @@ -32,6 +34,7 @@ def test_doclift_claim_tournament_broad_track_improves_recall_on_fixture() -> No assert tracks["broad"]["recall"] >= tracks["conservative"]["recall"] assert tracks["broad"]["matches"] >= tracks["conservative"]["matches"] + assert tracks["balanced"]["precision"] >= tracks["conservative"]["precision"] def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None: @@ -41,5 +44,9 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None: assert len(result["per_document"]) == 2 assert tracks["conservative"]["gold_claims"] == 4 + assert tracks["balanced"]["gold_claims"] == 4 assert tracks["broad"]["gold_claims"] == 4 assert tracks["broad"]["matches"] >= 1 + assert tracks["balanced"]["matches"] >= 1 + assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"] + assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]