Add adaptive doclift claim selection
commit a9974110e2
parent 86a4db01a4
@@ -10,6 +10,7 @@ from .base import DiscoveredImportSource, StructuredImportRows, register_source_
 class DocliftBundleSourceAdapter:
     name = "doclift_bundle"
+    _TRACK_STRATEGIES = ("conservative", "balanced", "broad")
     _PROSE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
     _METADATA_PREFIXES = (
@@ -139,6 +140,81 @@ class DocliftBundleSourceAdapter:
             return (cue_hits - penalties, -abs(len(cleaned) - 140))
         return (cue_hits - penalties, -len(cleaned))
 
+    def _meta_noise_count(self, claims: list[str]) -> int:
+        count = 0
+        for claim in claims:
+            lowered = claim.lower()
+            if any(lowered.startswith(prefix) for prefix in self._METADATA_PREFIXES):
+                count += 1
+            if lowered.startswith(("the latest issue", "this is a brief introduction", "mistakes permeate")):
+                count += 1
+        return count
+
+    def _truncated_claim_count(self, claims: list[str]) -> int:
+        truncated = 0
+        bad_endings = (" and", " of", " to", " in", " by", " with", " that", " this")
+        for claim in claims:
+            stripped = claim.strip()
+            lowered = stripped.rstrip('.!?"').lower()  # drop closing punctuation so the dangling-word check below can fire
+            if not stripped.endswith((".", "!", "?", '"')):
+                truncated += 1
+                continue
+            if any(lowered.endswith(ending) for ending in bad_endings):
+                truncated += 1
+        return truncated
+
+    def _redundancy_penalty(self, claims: list[str]) -> float:
+        penalty = 0.0
+        normalized = [set(re.findall(r"[a-z0-9]+", claim.lower())) for claim in claims]
+        for i, left in enumerate(normalized):
+            for right in normalized[i + 1 :]:
+                if not left or not right:
+                    continue
+                overlap = len(left & right) / len(left | right)
+                if overlap >= 0.6:
+                    penalty += overlap
+        return penalty
+
+    def _score_claim_batch(self, claims: list[str], *, strategy: str) -> tuple[float, float, float, float, float]:
+        if not claims:
+            return (-999.0, -999.0, -999.0, -999.0, -999.0)
+        cue_score = 0.0
+        for claim in claims:
+            cue_score += self._claim_priority(claim, strategy=strategy)[0]
+        avg_cue_score = cue_score / len(claims)
+        meta_penalty = float(self._meta_noise_count(claims))
+        truncated_penalty = float(self._truncated_claim_count(claims))
+        redundancy_penalty = self._redundancy_penalty(claims)
+        avg_length = sum(len(claim) for claim in claims) / len(claims)
+        target_count = {
+            "conservative": 3,
+            "balanced": 5,
+            "broad": 6,
+        }.get(strategy, 4)
+        target_length = {
+            "conservative": 150,
+            "balanced": 100,
+            "broad": 180,
+        }.get(strategy, 140)
+        count_penalty = abs(len(claims) - target_count)
+        length_penalty = abs(avg_length - target_length) / 100.0
+        strategy_bias = {
+            "balanced": 0.35,
+            "conservative": 0.0,
+            "broad": -0.15,
+        }.get(strategy, 0.0)
+        return (
+            avg_cue_score + strategy_bias - meta_penalty - truncated_penalty - redundancy_penalty,
+            -count_penalty,
+            -(length_penalty + truncated_penalty + redundancy_penalty),
+            -meta_penalty,
+            -abs(avg_length - target_length),
+        )
+
+    def _strategy_total_score(self, score: tuple[float, float, float, float, float]) -> float:
+        primary, count_fit, quality_fit, meta_fit, length_fit = score
+        return (0.5 * primary) + (0.8 * count_fit) + (1.2 * quality_fit) + (0.5 * meta_fit) + (0.03 * length_fit)
+
     def _extract_claim_sentences_from_paragraphs(
         self,
         paragraphs: list[str],
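Note: _score_claim_batch returns a five-component tuple and _strategy_total_score collapses it with fixed weights, so the ranking arithmetic can be checked by hand. A minimal sketch with a hypothetical score tuple (the numbers are illustrative, not taken from any fixture):

# Hypothetical components: (primary, count_fit, quality_fit, meta_fit, length_fit).
score = (1.8, -1.0, -0.4, 0.0, -20.0)
# Same weights as _strategy_total_score: 0.5, 0.8, 1.2, 0.5, 0.03.
total = (0.5 * 1.8) + (0.8 * -1.0) + (1.2 * -0.4) + (0.5 * 0.0) + (0.03 * -20.0)
# 0.9 - 0.8 - 0.48 + 0.0 - 0.6 == -0.98; quality_fit carries the largest weight.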
@@ -224,10 +300,46 @@ class DocliftBundleSourceAdapter:
         strategy: str = "conservative",
         limit: int = 4,
     ) -> list[str]:
+        if strategy == "auto":
+            selected = self.select_document_claim_strategy(base, document, limit=limit)
+            return self.extract_document_claims(base, document, strategy=selected, limit=limit)
         markdown_text = self._load_markdown_text(base, document)
         title = str(document.get("title") or "")
         return self._extract_claim_sentences(markdown_text, title=title, limit=limit, strategy=strategy)
 
+    def select_document_claim_strategy(self, base: Path, document: dict, *, limit: int = 4) -> str:
+        candidates: list[tuple[str, tuple[float, float, float, float, float]]] = []
+        for strategy in self._TRACK_STRATEGIES:
+            claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
+            candidates.append((strategy, self._score_claim_batch(claims, strategy=strategy)))
+        return max(candidates, key=lambda item: self._strategy_total_score(item[1]))[0]
+
+    def select_bundle_claim_strategy(self, base: Path, documents: list[dict], *, limit: int = 4) -> str:
+        candidate_docs = [
+            document
+            for document in documents
+            if str(document.get("document_kind") or "").strip() in {"web_article", "document"}
+        ]
+        if not candidate_docs:
+            return "conservative"
+        scored: list[tuple[str, tuple[float, float, float, float, float]]] = []
+        for strategy in self._TRACK_STRATEGIES:
+            total = [0.0, 0.0, 0.0, 0.0, 0.0]
+            used = 0
+            for document in candidate_docs[:6]:
+                claims = self.extract_document_claims(base, document, strategy=strategy, limit=limit)
+                score = self._score_claim_batch(claims, strategy=strategy)
+                if score[0] <= -999.0:
+                    continue
+                total = [left + right for left, right in zip(total, score)]
+                used += 1
+            if used:
+                averaged = tuple(value / used for value in total)
+            else:
+                averaged = (-999.0, -999.0, -999.0, -999.0, -999.0)
+            scored.append((strategy, averaged))
+        return max(scored, key=lambda item: self._strategy_total_score(item[1]))[0]
+
     def discover(self, root: str | Path) -> list[DiscoveredImportSource]:
         base = Path(root)
         rows: list[DiscoveredImportSource] = []
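Note: this hunk makes strategy="auto" a valid argument: the adapter scores the three tracks for the document and delegates to the winner. A minimal usage sketch; the bundle root and manifest entry below are hypothetical, and the document would need real markdown on disk for _load_markdown_text to find:

from pathlib import Path

from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter

adapter = DocliftBundleSourceAdapter()
base = Path("/tmp/example-bundle")  # hypothetical bundle root
document = {"title": "Drift Essay", "document_kind": "web_article"}  # hypothetical manifest entry
# "auto" runs select_document_claim_strategy over the three tracks, then
# re-extracts with whichever track scored best.
claims = adapter.extract_document_claims(base, document, strategy="auto", limit=4)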
@@ -285,6 +397,7 @@ class DocliftBundleSourceAdapter:
             artifact_by_path[source.relative_path] = artifact_id
 
         documents = [item for item in manifest.get("documents", []) if isinstance(item, dict)]
+        bundle_claim_strategy = self.select_bundle_claim_strategy(base, documents)
         previous_concept_id: str | None = None
         for index, document in enumerate(documents, start=1):
             title = str(document.get("title") or f"Document {index}")
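Note: the import path now picks one track per bundle rather than per document. A sketch of calling the bundle-level selector directly, again with hypothetical manifest entries:

from pathlib import Path

from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter

adapter = DocliftBundleSourceAdapter()
base = Path("/tmp/example-bundle")  # hypothetical bundle root
documents = [
    {"title": "Drift Essay", "document_kind": "web_article"},  # hypothetical entries
    {"title": "Field Notes", "document_kind": "document"},
]
# Averages _score_claim_batch over up to six prose-bearing documents per track;
# falls back to "conservative" when no web_article/document entries exist.
strategy = adapter.select_bundle_claim_strategy(base, documents, limit=4)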
@@ -402,7 +515,8 @@ class DocliftBundleSourceAdapter:
                 )
                 document_claim_ids.append(claim_id)
             if not document_claim_ids and str(document.get("document_kind") or "").strip() in {"web_article", "document"}:
-                for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy="conservative"), start=1):
+                selected_strategy = bundle_claim_strategy
+                for derived_index, claim_text in enumerate(self.extract_document_claims(base, document, strategy=selected_strategy), start=1):
                     derived_observation_id = f"obs_doclift_{index}_derived_{derived_index}"
                     claim_id = f"clm_doclift_{index}_derived_{derived_index}"
                     observation_rows.append(
@@ -420,6 +534,7 @@ class DocliftBundleSourceAdapter:
                             "metadata": {
                                 "source_path_kind": source_path_kind,
                                 "derived_from": "markdown_sentence",
+                                "claim_strategy": selected_strategy,
                             },
                             "grounding_status": "grounded",
                             "support_kind": "direct_source",
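Note: derived observations now record which track produced them. A short sketch of reading that metadata back downstream, assuming result is the structured import output with an observations list, as exercised in the tests below:

from collections import Counter

# Hypothetical post-import audit: count derived claims per strategy track.
derived = [
    observation
    for observation in result.observations
    if observation["metadata"].get("derived_from") == "markdown_sentence"
]
per_strategy = Counter(item["metadata"].get("claim_strategy", "unknown") for item in derived)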
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import json
 from pathlib import Path
 
 from groundrecall.doclift_claim_tournament import evaluate_doclift_claim_tracks
+from groundrecall.groundrecall_source_adapters.doclift_bundle import DocliftBundleSourceAdapter
 
 
 def _fixture_root() -> Path:
@@ -50,3 +52,27 @@ def test_doclift_claim_tournament_runs_on_real_corpus_fixture() -> None:
     assert tracks["balanced"]["matches"] >= 1
     assert tracks["balanced"]["recall"] >= tracks["broad"]["recall"]
     assert tracks["balanced"]["f1"] >= tracks["broad"]["f1"]
+
+
+def test_doclift_auto_bundle_strategy_prefers_balanced_on_real_corpus_fixture() -> None:
+    root = _pilot_fixture_root()
+    manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
+    adapter = DocliftBundleSourceAdapter()
+    documents = [item for item in manifest["documents"]]
+
+    strategy = adapter.select_bundle_claim_strategy(
+        root,
+        documents,
+        limit=6,
+    )
+    assert strategy == "balanced"
+
+
+def test_doclift_auto_strategy_returns_available_track_on_synthetic_fixture() -> None:
+    root = _fixture_root()
+    manifest = json.loads((root / "manifest.json").read_text(encoding="utf-8"))
+    adapter = DocliftBundleSourceAdapter()
+    documents = {str(item["document_id"]): item for item in manifest["documents"]}
+
+    strategy = adapter.select_document_claim_strategy(root, documents["drift-essay"], limit=6)
+    assert strategy in {"conservative", "balanced", "broad"}
@@ -350,3 +350,6 @@ def test_doclift_bundle_import_derives_claims_from_prose_when_chunks_are_body_on
 
     assert any("Random genetic drift can dominate allele-frequency change in small populations." in text for text in claim_texts)
     assert not any(text == "Drift Essay is a web_article in the imported doclift bundle." for text in claim_texts)
+    derived_observations = [item for item in result.observations if item["observation_id"].startswith("obs_doclift_1_derived_")]
+    assert derived_observations
+    assert derived_observations[0]["metadata"]["claim_strategy"] in {"conservative", "balanced", "broad"}