Add adaptive doclift claim selection

This commit is contained in:
welsberr 2026-05-08 02:35:30 -04:00
parent 86a4db01a4
commit a9974110e2
3 changed files with 145 additions and 1 deletions

View File

@ -10,6 +10,7 @@ from .base import DiscoveredImportSource, StructuredImportRows, register_source_
class DocliftBundleSourceAdapter: class DocliftBundleSourceAdapter:
name = "doclift_bundle" name = "doclift_bundle"
_TRACK_STRATEGIES = ("conservative", "balanced", "broad")
_PROSE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+") _PROSE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
_METADATA_PREFIXES = ( _METADATA_PREFIXES = (
@ -139,6 +140,81 @@ class DocliftBundleSourceAdapter:
return (cue_hits - penalties, -abs(len(cleaned) - 140)) return (cue_hits - penalties, -abs(len(cleaned) - 140))
return (cue_hits - penalties, -len(cleaned)) return (cue_hits - penalties, -len(cleaned))
def _meta_noise_count(self, claims: list[str]) -> int:
count = 0
for claim in claims:
lowered = claim.lower()
if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
count += 1
if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
count += 1
return count
def _truncated_claim_count(self, claims: list[str]) -> int:
truncated = 0
bad_endings = (" and", " of", " to", " in", " by", " with", " that", " this")
for claim in claims:
stripped = claim.strip()
lowered = stripped.lower()
if not stripped.endswith((".", "!", "?", '"')):
truncated += 1
continue
if any(lowered.endswith(ending) for ending in bad_endings):
truncated += 1
return truncated
def _redundancy_penalty(self, claims: list[str]) -> float:
penalty = 0.0
normalized = [set(re.findall(r"[a-z0-9]+", claim.lower())) for claim in claims]
for i, left in enumerate(normalized):
for right in normalized[i + 1 :]:
if not left or not right:
continue
overlap = len(left & right) / len(left | right)
if overlap >= 0.6:
penalty += overlap
return penalty
def _score_claim_batch(self, claims: list[str], *, strategy: str) -> tuple[float, float, float, float, float]:
if not claims:
return (-999.0, -999.0, -999.0, -999.0, -999.0)
cue_score = 0.0
for claim in claims:
cue_score += self._claim_priority(claim, strategy=strategy)[0]
avg_cue_score = cue_score / len(claims)
meta_penalty = float(self._meta_noise_count(claims))
truncated_penalty = float(self._truncated_claim_count(claims))
redundancy_penalty = self._redundancy_penalty(claims)
avg_length = sum(len(claim) for claim in claims) / len(claims)
target_count = {
"conservative": 3,
"balanced": 5,
"broad": 6,
}.get(strategy, 4)
target_length = {
"conservative": 150,
"balanced": 100,
"broad": 180,
}.get(strategy, 140)
count_penalty = abs(len(claims) - target_count)
length_penalty = abs(avg_length - target_length) / 100.0
strategy_bias = {
"balanced": 0.35,
"conservative": 0.0,
"broad": -0.15,
}.get(strategy, 0.0)
return (
avg_cue_score + strategy_bias - meta_penalty - truncated_penalty - redundancy_penalty,
-count_penalty,
-(length_penalty + truncated_penalty + redundancy_penalty),
-meta_penalty,
-abs(avg_length - target_length),
)
def _strategy_total_score(self, score: tuple[float, float, float, float, float]) -> float:
primary, count_fit, quality_fit, meta_fit, length_fit = score
return (0.5 * primary) + (0.8 * count_fit) + (1.2 * quality_fit) + (0.5 * meta_fit) + (0.03 * length_fit)
def _extract_claim_sentences_from_paragraphs( def _extract_claim_sentences_from_paragraphs(
self, self,
paragraphs: list[str], paragraphs: list[str],
@ -224,10 +300,46 @@ class DocliftBundleSourceAdapter:
strategy: str = "conservative", strategy: str = "conservative",
limit: int = 4, limit: int = 4,
) -> list[str]: ) -> list[str]:
if strategy == "auto":
selected = self.select_document_claim_strategy(base, document, limit=limit)
return self.extract_document_claims(base, document, strategy=selected, limit=limit)
markdown_text = self._load_markdown_text(base, document) markdown_text = self._load_markdown_text(base, document)
title = str(document.get("title") or "") title = str(document.get("title") or "")
return self._extract_claim_sentences(markdown_text, title=title, limit=limit, strategy=strategy) return self._extract_claim_sentences(markdown_text, title=title, limit=limit, strategy=strategy)
def select_document_claim_strategy(self, base: Path, document: dict, *, limit: int = 4) -> str:
candidates: list[tuple[str, tuple[float, float, float, float, float]]] = []
for strategy in self._TRACK_STRATEGIES:
claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
candidates.append((strategy, self._score_claim_batch(claims, strategy=strategy)))
return max(candidates, key=lambda item: self._strategy_total_score(item[1]))[0]
def select_bundle_claim_strategy(self, base: Path, documents: list[dict], *, limit: int = 4) -> str:
candidate_docs = [
document
for document in documents
if str(document.get("document_kind") or "").strip() in {"web_article", "document"}
]
if not candidate_docs:
return "conservative"
scored: list[tuple[str, tuple[float, float, float, float, float]]] = []
for strategy in self._TRACK_STRATEGIES:
total = [0.0, 0.0, 0.0, 0.0, 0.0]
used = 0
for document in candidate_docs[:6]:
claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
score = self._score_claim_batch(claims, strategy=strategy)
if score[0] <= -999.0:
continue
total = [left + right for left, right in zip(total, score)]
used += 1
if used:
averaged = tuple(value / used for value in total)
else:
averaged = (-999.0, -999.0, -999.0, -999.0, -999.0)
scored.append((strategy, averaged))
return max(scored, key=lambda item: self._strategy_total_score(item[1]))[0]
def discover(self, root: str | Path) -> list[DiscoveredImportSource]: def discover(self, root: str | Path) -> list[DiscoveredImportSource]:
base = Path(root) base = Path(root)
rows: list[DiscoveredImportSource] = [] rows: list[DiscoveredImportSource] = []
@ -285,6 +397,7 @@ class DocliftBundleSourceAdapter:
artifact_by_path[source.relative_path] = artifact_id artifact_by_path[source.relative_path] = artifact_id
documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)] documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)]
bundle_claim_strategy = self.select_bundle_claim_strategy(base, documents)
previous_concept_id: str | None = None previous_concept_id: str | None = None
for index, document in enumerate(documents, start=1): for index, document in enumerate(documents, start=1):
title = str(document.get("title") or f"Document {index}") title = str(document.get("title") or f"Document {index}")
@ -402,7 +515,8 @@ class DocliftBundleSourceAdapter:
) )
document_claim_ids.append(claim_id) document_claim_ids.append(claim_id)
if not document_claim_ids and str(document.get("document_kind") or "").strip() in {"web_article", "document"}: if not document_claim_ids and str(document.get("document_kind") or "").strip() in {"web_article", "document"}:
for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy="conservative"), start=1): selected_strategy = bundle_claim_strategy
for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy=selected_strategy), start=1):
derived_observation_id = f"obs_doclift_{index}_derived_{derived_index}" derived_observation_id = f"obs_doclift_{index}_derived_{derived_index}"
claim_id = f"clm_doclift_{index}_derived_{derived_index}" claim_id = f"clm_doclift_{index}_derived_{derived_index}"
observation_rows.append( observation_rows.append(
@ -420,6 +534,7 @@ class DocliftBundleSourceAdapter:
"metadata": { "metadata": {
"source_path_kind": source_path_kind, "source_path_kind": source_path_kind,
"derived_from": "markdown_sentence", "derived_from": "markdown_sentence",
"claim_strategy": selected_strategy,
}, },
"grounding_status": "grounded", "grounding_status": "grounded",
"support_kind": "direct_source", "support_kind": "direct_source",

View File

@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import json
from pathlib import Path from pathlib import Path
from groundrecall.doclift_claim_tournament import evaluate_doclift_claim_tracks from groundrecall.doclift_claim_tournament import evaluate_doclift_claim_tracks
from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter
def _fixture_root() -> Path: def _fixture_root() -> Path:
@ -50,3 +52,27 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
assert tracks["balanced"]["matches"] >= 1 assert tracks["balanced"]["matches"] >= 1
assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"] assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"] assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]
def test_doclift_auto_bundle_strategy_prefers_balanced_on_real_corpus_fixture() -> None:
root = _pilot_fixture_root()
manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
adapter = DocliftBundleSourceAdapter()
documents = [item for item in manifest["documents"]]
strategy = adapter.select_bundle_claim_strategy(
root,
documents,
limit=6,
)
assert strategy == "balanced"
def test_doclift_auto_strategy_returns_available_track_on_synthetic_fixture() -> None:
root = _fixture_root()
manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
adapter = DocliftBundleSourceAdapter()
documents = {str(item["document_id"]): item for item in manifest["documents"]}
strategy = adapter.select_document_claim_strategy(root, documents["drift-essay"], limit=6)
assert strategy in {"conservative", "balanced", "broad"}

View File

@ -350,3 +350,6 @@ def test_doclift_bundle_import_derives_claims_from_prose_when_chunks_are_body_on
assert any("Random genetic drift can dominate allele-frequency change in small populations." in text for text in claim_texts) assert any("Random genetic drift can dominate allele-frequency change in small populations." in text for text in claim_texts)
assert not any(text == "Drift Essay is a web_article in the imported doclift bundle." for text in claim_texts) assert not any(text == "Drift Essay is a web_article in the imported doclift bundle." for text in claim_texts)
derived_observations = [item for item in result.observations if item["observation_id"].startswith("obs_doclift_1_derived_")]
assert derived_observations
assert derived_observations[0]["metadata"]["claim_strategy"] in {"conservative", "balanced", "broad"}