diff --git a/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py b/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py
index 051155c..ceb2692 100755
--- a/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py
+++ b/src/groundrecall/groundrecall_source_adapters/doclift_bundle.py
@@ -10,6 +10,7 @@ from .base import DiscoveredImportSource, StructuredImportRows, register_source_
 
 class DocliftBundleSourceAdapter:
     name = "doclift_bundle"
+    _TRACK_STRATEGIES = ("conservative", "balanced", "broad")
 
     _PROSE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
     _METADATA_PREFIXES = (
@@ -139,6 +140,81 @@ class DocliftBundleSourceAdapter:
             return (cue_hits - penalties, -abs(len(cleaned) - 140))
         return (cue_hits - penalties, -len(cleaned))
 
+    def _meta_noise_count(self, claims: list[str]) -> int:
+        count = 0
+        for claim in claims:
+            lowered = claim.lower()
+            if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
+                count += 1
+            if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
+                count += 1
+        return count
+
+    def _truncated_claim_count(self, claims: list[str]) -> int:
+        truncated = 0
+        bad_endings = (" and", " of", " to", " in", " by", " with", " that", " this")
+        for claim in claims:
+            stripped = claim.strip()
+            lowered = stripped.lower().rstrip('.!?"')
+            if not stripped.endswith((".", "!", "?", '"')):
+                truncated += 1
+                continue
+            if any(lowered.endswith(ending) for ending in bad_endings):
+                truncated += 1
+        return truncated
+
+    def _redundancy_penalty(self, claims: list[str]) -> float:
+        penalty = 0.0
+        normalized = [set(re.findall(r"[a-z0-9]+", claim.lower())) for claim in claims]
+        for i, left in enumerate(normalized):
+            for right in normalized[i + 1 :]:
+                if not left or not right:
+                    continue
+                overlap = len(left & right) / len(left | right)
+                if overlap >= 0.6:
+                    penalty += overlap
+        return penalty
+
+    def _score_claim_batch(self, claims: list[str], *, strategy: str) -> tuple[float, float, float, float, float]:
+        if not claims:
+            return (-999.0, -999.0, -999.0, -999.0, -999.0)
+        cue_score = 0.0
+        for claim in claims:
+            cue_score += self._claim_priority(claim, strategy=strategy)[0]
+        avg_cue_score = cue_score / len(claims)
+        meta_penalty = float(self._meta_noise_count(claims))
+        truncated_penalty = float(self._truncated_claim_count(claims))
+        redundancy_penalty = self._redundancy_penalty(claims)
+        avg_length = sum(len(claim) for claim in claims) / len(claims)
+        target_count = {
+            "conservative": 3,
+            "balanced": 5,
+            "broad": 6,
+        }.get(strategy, 4)
+        target_length = {
+            "conservative": 150,
+            "balanced": 100,
+            "broad": 180,
+        }.get(strategy, 140)
+        count_penalty = abs(len(claims) - target_count)
+        length_penalty = abs(avg_length - target_length) / 100.0
+        strategy_bias = {
+            "balanced": 0.35,
+            "conservative": 0.0,
+            "broad": -0.15,
+        }.get(strategy, 0.0)
+        return (
+            avg_cue_score + strategy_bias - meta_penalty - truncated_penalty - redundancy_penalty,
+            -count_penalty,
+            -(length_penalty + truncated_penalty + redundancy_penalty),
+            -meta_penalty,
+            -abs(avg_length - target_length),
+        )
+
+    def _strategy_total_score(self, score: tuple[float, float, float, float, float]) -> float:
+        primary, count_fit, quality_fit, meta_fit, length_fit = score
+        return (0.5 * primary) + (0.8 * count_fit) + (1.2 * quality_fit) + (0.5 * meta_fit) + (0.03 * length_fit)
+
     def _extract_claim_sentences_from_paragraphs(
         self,
         paragraphs: list[str],
@@ -224,10 +300,46 @@ class DocliftBundleSourceAdapter:
         strategy: str = "conservative",
         limit: int = 4,
     ) -> list[str]:
+        if strategy == "auto":
+            selected = self.select_document_claim_strategy(base, document, limit=limit)
+            return self.extract_document_claims(base, document, strategy=selected, limit=limit)
         markdown_text = self._load_markdown_text(base, document)
         title = str(document.get("title") or "")
         return self._extract_claim_sentences(markdown_text, title=title, limit=limit, strategy=strategy)
 
+    def select_document_claim_strategy(self, base: Path, document: dict, *, limit: int = 4) -> str:
+        candidates: list[tuple[str, tuple[float, float, float, float, float]]] = []
+        for strategy in self._TRACK_STRATEGIES:
+            claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
+            candidates.append((strategy, self._score_claim_batch(claims, strategy=strategy)))
+        return max(candidates, key=lambda item: self._strategy_total_score(item[1]))[0]
+
+    def select_bundle_claim_strategy(self, base: Path, documents: list[dict], *, limit: int = 4) -> str:
+        candidate_docs = [
+            document
+            for document in documents
+            if str(document.get("document_kind") or "").strip() in {"web_article", "document"}
+        ]
+        if not candidate_docs:
+            return "conservative"
+        scored: list[tuple[str, tuple[float, float, float, float, float]]] = []
+        for strategy in self._TRACK_STRATEGIES:
+            total = [0.0, 0.0, 0.0, 0.0, 0.0]
+            used = 0
+            for document in candidate_docs[:6]:
+                claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
+                score = self._score_claim_batch(claims, strategy=strategy)
+                if score[0] <= -999.0:
+                    continue
+                total = [left + right for left, right in zip(total, score)]
+                used += 1
+            if used:
+                averaged = tuple(value / used for value in total)
+            else:
+                averaged = (-999.0, -999.0, -999.0, -999.0, -999.0)
+            scored.append((strategy, averaged))
+        return max(scored, key=lambda item: self._strategy_total_score(item[1]))[0]
+
     def discover(self, root: str | Path) -> list[DiscoveredImportSource]:
         base = Path(root)
         rows: list[DiscoveredImportSource] = []
@@ -285,6 +397,7 @@ class DocliftBundleSourceAdapter:
             artifact_by_path[source.relative_path] = artifact_id
 
         documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)]
+        bundle_claim_strategy = self.select_bundle_claim_strategy(base, documents)
         previous_concept_id: str | None = None
         for index, document in enumerate(documents, start=1):
             title = str(document.get("title") or f"Document {index}")
@@ -402,7 +515,8 @@ class DocliftBundleSourceAdapter:
                 )
                 document_claim_ids.append(claim_id)
             if not document_claim_ids and str(document.get("document_kind") or "").strip() in {"web_article", "document"}:
-                for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy="conservative"), start=1):
+                selected_strategy = bundle_claim_strategy
+                for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy=selected_strategy), start=1):
                     derived_observation_id = f"obs_doclift_{index}_derived_{derived_index}"
                     claim_id = f"clm_doclift_{index}_derived_{derived_index}"
                     observation_rows.append(
@@ -420,6 +534,7 @@ class DocliftBundleSourceAdapter:
                         "metadata": {
                             "source_path_kind": source_path_kind,
                             "derived_from": "markdown_sentence",
+                            "claim_strategy": selected_strategy,
                         },
                         "grounding_status": "grounded",
                         "support_kind": "direct_source",
diff --git a/tests/test_doclift_claim_tournament.py b/tests/test_doclift_claim_tournament.py
index 957c93a..0a7740e 100644
--- a/tests/test_doclift_claim_tournament.py
+++ b/tests/test_doclift_claim_tournament.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import json
 from pathlib import Path
 
 from groundrecall.doclift_claim_tournament import evaluate_doclift_claim_tracks
+from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter
 
 
 def _fixture_root() -> Path:
@@ -50,3 +52,27 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
     assert tracks["balanced"]["matches"] >= 1
     assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
     assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]
+
+
+def test_doclift_auto_bundle_strategy_prefers_balanced_on_real_corpus_fixture() -> None:
+    root = _pilot_fixture_root()
+    manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
+    adapter = DocliftBundleSourceAdapter()
+    documents = list(manifest["documents"])
+
+    strategy = adapter.select_bundle_claim_strategy(
+        root,
+        documents,
+        limit=6,
+    )
+    assert strategy == "balanced"
+
+
+def test_doclift_auto_strategy_returns_available_track_on_synthetic_fixture() -> None:
+    root = _fixture_root()
+    manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
+    adapter = DocliftBundleSourceAdapter()
+    documents = {str(item["document_id"]): item for item in manifest["documents"]}
+
+    strategy = adapter.select_document_claim_strategy(root, documents["drift-essay"], limit=6)
+    assert strategy in {"conservative", "balanced", "broad"}
diff --git a/tests/test_groundrecall_source_adapters.py b/tests/test_groundrecall_source_adapters.py
index a3dc2a9..7a8d347 100644
--- a/tests/test_groundrecall_source_adapters.py
+++ b/tests/test_groundrecall_source_adapters.py
@@ -350,3 +350,6 @@ def test_doclift_bundle_import_derives_claims_from_prose_when_chunks_are_body_on
 
     assert any("Random genetic drift can dominate allele-frequency change in small populations." in text for text in claim_texts)
     assert not any(text == "Drift Essay is a web_article in the imported doclift bundle." for text in claim_texts)
+    derived_observations = [item for item in result.observations if item["observation_id"].startswith("obs_doclift_1_derived_")]
+    assert derived_observations
+    assert derived_observations[0]["metadata"]["claim_strategy"] in {"conservative", "balanced", "broad"}
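
Usage note: a minimal sketch of how the strategy-selection entry points added in this diff compose, assuming only a doclift bundle directory containing a manifest.json shaped like the test fixtures above. The bundle_root path is hypothetical; the adapter methods, signatures, and the "auto" strategy value are the ones introduced in the patch.

    import json
    from pathlib import Path

    from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter

    # Hypothetical bundle location; any directory with a doclift manifest.json works.
    bundle_root = Path("fixtures/doclift_bundle")
    manifest = json.loads((bundle_root / "manifest.json").read_text(encoding="utf-8"))
    documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)]

    adapter = DocliftBundleSourceAdapter()

    # Bundle-wide selection: scores each track's claim batches over up to six
    # web_article/document entries and returns the strategy with the best
    # averaged total score.
    bundle_strategy = adapter.select_bundle_claim_strategy(bundle_root, documents, limit=4)

    # Per-document selection: strategy="auto" runs the same tournament for a
    # single document, then re-extracts claims with the winning track.
    claims = adapter.extract_document_claims(bundle_root, documents[0], strategy="auto", limit=4)

The import path itself calls select_bundle_claim_strategy once per bundle, so every claim derived during an import carries the same claim_strategy value in its observation metadata.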