Add adaptive doclift claim selection
This commit is contained in:
parent
86a4db01a4
commit
a9974110e2
|
|
@ -10,6 +10,7 @@ from .base import DiscoveredImportSource, StructuredImportRows, register_source_
|
|||
|
||||
class DocliftBundleSourceAdapter:
|
||||
name = "doclift_bundle"
|
||||
_TRACK_STRATEGIES = ("conservative", "balanced", "broad")
|
||||
|
||||
_PROSE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
|
||||
_METADATA_PREFIXES = (
|
||||
|
|
@ -139,6 +140,81 @@ class DocliftBundleSourceAdapter:
|
|||
return (cue_hits - penalties, -abs(len(cleaned) - 140))
|
||||
return (cue_hits - penalties, -len(cleaned))
|
||||
|
||||
def _meta_noise_count(self, claims: list[str]) -> int:
|
||||
count = 0
|
||||
for claim in claims:
|
||||
lowered = claim.lower()
|
||||
if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
|
||||
count += 1
|
||||
if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _truncated_claim_count(self, claims: list[str]) -> int:
|
||||
truncated = 0
|
||||
bad_endings = (" and", " of", " to", " in", " by", " with", " that", " this")
|
||||
for claim in claims:
|
||||
stripped = claim.strip()
|
||||
lowered = stripped.lower()
|
||||
if not stripped.endswith((".", "!", "?", '"')):
|
||||
truncated += 1
|
||||
continue
|
||||
if any(lowered.endswith(ending) for ending in bad_endings):
|
||||
truncated += 1
|
||||
return truncated
|
||||
|
||||
def _redundancy_penalty(self, claims: list[str]) -> float:
|
||||
penalty = 0.0
|
||||
normalized = [set(re.findall(r"[a-z0-9]+", claim.lower())) for claim in claims]
|
||||
for i, left in enumerate(normalized):
|
||||
for right in normalized[i + 1 :]:
|
||||
if not left or not right:
|
||||
continue
|
||||
overlap = len(left & right) / len(left | right)
|
||||
if overlap >= 0.6:
|
||||
penalty += overlap
|
||||
return penalty
|
||||
|
||||
def _score_claim_batch(self, claims: list[str], *, strategy: str) -> tuple[float, float, float, float, float]:
|
||||
if not claims:
|
||||
return (-999.0, -999.0, -999.0, -999.0, -999.0)
|
||||
cue_score = 0.0
|
||||
for claim in claims:
|
||||
cue_score += self._claim_priority(claim, strategy=strategy)[0]
|
||||
avg_cue_score = cue_score / len(claims)
|
||||
meta_penalty = float(self._meta_noise_count(claims))
|
||||
truncated_penalty = float(self._truncated_claim_count(claims))
|
||||
redundancy_penalty = self._redundancy_penalty(claims)
|
||||
avg_length = sum(len(claim) for claim in claims) / len(claims)
|
||||
target_count = {
|
||||
"conservative": 3,
|
||||
"balanced": 5,
|
||||
"broad": 6,
|
||||
}.get(strategy, 4)
|
||||
target_length = {
|
||||
"conservative": 150,
|
||||
"balanced": 100,
|
||||
"broad": 180,
|
||||
}.get(strategy, 140)
|
||||
count_penalty = abs(len(claims) - target_count)
|
||||
length_penalty = abs(avg_length - target_length) / 100.0
|
||||
strategy_bias = {
|
||||
"balanced": 0.35,
|
||||
"conservative": 0.0,
|
||||
"broad": -0.15,
|
||||
}.get(strategy, 0.0)
|
||||
return (
|
||||
avg_cue_score + strategy_bias - meta_penalty - truncated_penalty - redundancy_penalty,
|
||||
-count_penalty,
|
||||
-(length_penalty + truncated_penalty + redundancy_penalty),
|
||||
-meta_penalty,
|
||||
-abs(avg_length - target_length),
|
||||
)
|
||||
|
||||
def _strategy_total_score(self, score: tuple[float, float, float, float, float]) -> float:
|
||||
primary, count_fit, quality_fit, meta_fit, length_fit = score
|
||||
return (0.5 * primary) + (0.8 * count_fit) + (1.2 * quality_fit) + (0.5 * meta_fit) + (0.03 * length_fit)
|
||||
|
||||
def _extract_claim_sentences_from_paragraphs(
|
||||
self,
|
||||
paragraphs: list[str],
|
||||
|
|
@ -224,10 +300,46 @@ class DocliftBundleSourceAdapter:
|
|||
strategy: str = "conservative",
|
||||
limit: int = 4,
|
||||
) -> list[str]:
|
||||
if strategy == "auto":
|
||||
selected = self.select_document_claim_strategy(base, document, limit=limit)
|
||||
return self.extract_document_claims(base, document, strategy=selected, limit=limit)
|
||||
markdown_text = self._load_markdown_text(base, document)
|
||||
title = str(document.get("title") or "")
|
||||
return self._extract_claim_sentences(markdown_text, title=title, limit=limit, strategy=strategy)
|
||||
|
||||
def select_document_claim_strategy(self, base: Path, document: dict, *, limit: int = 4) -> str:
|
||||
candidates: list[tuple[str, tuple[float, float, float, float, float]]] = []
|
||||
for strategy in self._TRACK_STRATEGIES:
|
||||
claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
|
||||
candidates.append((strategy, self._score_claim_batch(claims, strategy=strategy)))
|
||||
return max(candidates, key=lambda item: self._strategy_total_score(item[1]))[0]
|
||||
|
||||
def select_bundle_claim_strategy(self, base: Path, documents: list[dict], *, limit: int = 4) -> str:
|
||||
candidate_docs = [
|
||||
document
|
||||
for document in documents
|
||||
if str(document.get("document_kind") or "").strip() in {"web_article", "document"}
|
||||
]
|
||||
if not candidate_docs:
|
||||
return "conservative"
|
||||
scored: list[tuple[str, tuple[float, float, float, float, float]]] = []
|
||||
for strategy in self._TRACK_STRATEGIES:
|
||||
total = [0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
used = 0
|
||||
for document in candidate_docs[:6]:
|
||||
claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
|
||||
score = self._score_claim_batch(claims, strategy=strategy)
|
||||
if score[0] <= -999.0:
|
||||
continue
|
||||
total = [left + right for left, right in zip(total, score)]
|
||||
used += 1
|
||||
if used:
|
||||
averaged = tuple(value / used for value in total)
|
||||
else:
|
||||
averaged = (-999.0, -999.0, -999.0, -999.0, -999.0)
|
||||
scored.append((strategy, averaged))
|
||||
return max(scored, key=lambda item: self._strategy_total_score(item[1]))[0]
|
||||
|
||||
def discover(self, root: str | Path) -> list[DiscoveredImportSource]:
|
||||
base = Path(root)
|
||||
rows: list[DiscoveredImportSource] = []
|
||||
|
|
@ -285,6 +397,7 @@ class DocliftBundleSourceAdapter:
|
|||
artifact_by_path[source.relative_path] = artifact_id
|
||||
|
||||
documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)]
|
||||
bundle_claim_strategy = self.select_bundle_claim_strategy(base, documents)
|
||||
previous_concept_id: str | None = None
|
||||
for index, document in enumerate(documents, start=1):
|
||||
title = str(document.get("title") or f"Document {index}")
|
||||
|
|
@ -402,7 +515,8 @@ class DocliftBundleSourceAdapter:
|
|||
)
|
||||
document_claim_ids.append(claim_id)
|
||||
if not document_claim_ids and str(document.get("document_kind") or "").strip() in {"web_article", "document"}:
|
||||
for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy="conservative"), start=1):
|
||||
selected_strategy = bundle_claim_strategy
|
||||
for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy=selected_strategy), start=1):
|
||||
derived_observation_id = f"obs_doclift_{index}_derived_{derived_index}"
|
||||
claim_id = f"clm_doclift_{index}_derived_{derived_index}"
|
||||
observation_rows.append(
|
||||
|
|
@ -420,6 +534,7 @@ class DocliftBundleSourceAdapter:
|
|||
"metadata": {
|
||||
"source_path_kind": source_path_kind,
|
||||
"derived_from": "markdown_sentence",
|
||||
"claim_strategy": selected_strategy,
|
||||
},
|
||||
"grounding_status": "grounded",
|
||||
"support_kind": "direct_source",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from groundrecall.doclift_claim_tournament import evaluate_doclift_claim_tracks
|
||||
from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter
|
||||
|
||||
|
||||
def _fixture_root() -> Path:
|
||||
|
|
@ -50,3 +52,27 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
|
|||
assert tracks["balanced"]["matches"] >= 1
|
||||
assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
|
||||
assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]
|
||||
|
||||
|
||||
def test_doclift_auto_bundle_strategy_prefers_balanced_on_real_corpus_fixture() -> None:
|
||||
root = _pilot_fixture_root()
|
||||
manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
|
||||
adapter = DocliftBundleSourceAdapter()
|
||||
documents = [item for item in manifest["documents"]]
|
||||
|
||||
strategy = adapter.select_bundle_claim_strategy(
|
||||
root,
|
||||
documents,
|
||||
limit=6,
|
||||
)
|
||||
assert strategy == "balanced"
|
||||
|
||||
|
||||
def test_doclift_auto_strategy_returns_available_track_on_synthetic_fixture() -> None:
|
||||
root = _fixture_root()
|
||||
manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
|
||||
adapter = DocliftBundleSourceAdapter()
|
||||
documents = {str(item["document_id"]): item for item in manifest["documents"]}
|
||||
|
||||
strategy = adapter.select_document_claim_strategy(root, documents["drift-essay"], limit=6)
|
||||
assert strategy in {"conservative", "balanced", "broad"}
|
||||
|
|
|
|||
|
|
@ -350,3 +350,6 @@ def test_doclift_bundle_import_derives_claims_from_prose_when_chunks_are_body_on
|
|||
|
||||
assert any("Random genetic drift can dominate allele-frequency change in small populations." in text for text in claim_texts)
|
||||
assert not any(text == "Drift Essay is a web_article in the imported doclift bundle." for text in claim_texts)
|
||||
derived_observations = [item for item in result.observations if item["observation_id"].startswith("obs_doclift_1_derived_")]
|
||||
assert derived_observations
|
||||
assert derived_observations[0]["metadata"]["claim_strategy"] in {"conservative", "balanced", "broad"}
|
||||
|
|
|
|||
Loading…
Reference in New Issue