Add source client and fixture cache support

This commit is contained in:
welsberr 2026-03-19 21:10:38 -04:00
parent 10280a6229
commit d4d31b371f
7 changed files with 128 additions and 31 deletions

View File

@ -51,6 +51,7 @@ The initial repo includes:
- identifier-first metadata resolution for DOI, DBLP, and arXiv-backed entries; - identifier-first metadata resolution for DOI, DBLP, and arXiv-backed entries;
- local citation-graph traversal over stored `cites`, `cited_by`, and `crossref` edges; - local citation-graph traversal over stored `cites`, `cited_by`, and `crossref` edges;
- Crossref-backed graph expansion that materializes draft referenced works and edge provenance; - Crossref-backed graph expansion that materializes draft referenced works and edge provenance;
- a dedicated source-client layer with fixture/cache support for live-source development;
- normalized tables for entries, creators, identifiers, and citation relations; - normalized tables for entries, creators, identifiers, and citation relations;
- full-text-search-ready indexing over title, abstract, and fulltext when SQLite FTS5 is available; - full-text-search-ready indexing over title, abstract, and fulltext when SQLite FTS5 is available;
- tests covering parsing, ingestion, relation storage, and search. - tests covering parsing, ingestion, relation storage, and search.
@ -76,6 +77,7 @@ cd citegeist
python3 -m virtualenv --always-copy .venv python3 -m virtualenv --always-copy .venv
.venv/bin/pip install -e . .venv/bin/pip install -e .
.venv/bin/pip install pytest .venv/bin/pip install pytest
mkdir -p .cache/citegeist
PYTHONPATH=src .venv/bin/python - <<'PY' PYTHONPATH=src .venv/bin/python - <<'PY'
from citegeist import BibliographyStore from citegeist import BibliographyStore
@ -120,6 +122,8 @@ PYTHONPATH=src .venv/bin/python -m citegeist --db library.sqlite3 expand smith20
PYTHONPATH=src .venv/bin/python -m citegeist --db library.sqlite3 export --output reviewed.bib PYTHONPATH=src .venv/bin/python -m citegeist --db library.sqlite3 export --output reviewed.bib
``` ```
For live-source development, prefer fixture-backed or cache-backed source clients so resolver and expansion work can be exercised repeatedly without re-hitting upstream APIs on every run.
## Near-Term Priorities ## Near-Term Priorities
- stronger plaintext extraction coverage for more citation styles; - stronger plaintext extraction coverage for more citation styles;

View File

@ -2,6 +2,7 @@ from .bibtex import BibEntry, parse_bibtex
from .expand import CrossrefExpander from .expand import CrossrefExpander
from .extract import extract_references from .extract import extract_references
from .resolve import MetadataResolver, merge_entries from .resolve import MetadataResolver, merge_entries
from .sources import SourceClient
from .storage import BibliographyStore from .storage import BibliographyStore
__all__ = [ __all__ = [
@ -9,6 +10,7 @@ __all__ = [
"BibliographyStore", "BibliographyStore",
"CrossrefExpander", "CrossrefExpander",
"MetadataResolver", "MetadataResolver",
"SourceClient",
"extract_references", "extract_references",
"merge_entries", "merge_entries",
"parse_bibtex", "parse_bibtex",

View File

@ -34,7 +34,7 @@ class CrossrefExpander:
if not doi: if not doi:
return [] return []
payload = self.resolver._get_json( # noqa: SLF001 payload = self.resolver.source_client.get_json(
f"https://api.crossref.org/works/{doi}?mailto=welsberr@gmail.com" f"https://api.crossref.org/works/{doi}?mailto=welsberr@gmail.com"
) )
references = payload.get("message", {}).get("reference", []) references = payload.get("message", {}).get("reference", [])

View File

@ -1,12 +1,11 @@
from __future__ import annotations from __future__ import annotations
import json
import urllib.parse import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from dataclasses import dataclass from dataclasses import dataclass
from .bibtex import BibEntry, parse_bibtex from .bibtex import BibEntry, parse_bibtex
from .sources import SourceClient
@dataclass(slots=True) @dataclass(slots=True)
@ -17,8 +16,13 @@ class Resolution:
class MetadataResolver: class MetadataResolver:
def __init__(self, user_agent: str = "citegeist/0.1 (local research tool)") -> None: def __init__(
self,
user_agent: str = "citegeist/0.1 (local research tool)",
source_client: SourceClient | None = None,
) -> None:
self.user_agent = user_agent self.user_agent = user_agent
self.source_client = source_client or SourceClient(user_agent=user_agent)
def resolve_entry(self, entry: BibEntry) -> Resolution | None: def resolve_entry(self, entry: BibEntry) -> Resolution | None:
if doi := entry.fields.get("doi"): if doi := entry.fields.get("doi"):
@ -40,7 +44,7 @@ class MetadataResolver:
def resolve_doi(self, doi: str) -> Resolution | None: def resolve_doi(self, doi: str) -> Resolution | None:
encoded = urllib.parse.quote(doi, safe="") encoded = urllib.parse.quote(doi, safe="")
payload = self._get_json(f"https://api.crossref.org/works/{encoded}") payload = self.source_client.get_json(f"https://api.crossref.org/works/{encoded}")
message = payload.get("message", {}) message = payload.get("message", {})
if not message: if not message:
return None return None
@ -52,13 +56,13 @@ class MetadataResolver:
def search_crossref(self, title: str, limit: int = 5) -> list[BibEntry]: def search_crossref(self, title: str, limit: int = 5) -> list[BibEntry]:
query = urllib.parse.urlencode({"query.title": title, "rows": limit}) query = urllib.parse.urlencode({"query.title": title, "rows": limit})
payload = self._get_json(f"https://api.crossref.org/works?{query}") payload = self.source_client.get_json(f"https://api.crossref.org/works?{query}")
items = payload.get("message", {}).get("items", []) items = payload.get("message", {}).get("items", [])
return [_crossref_message_to_entry(item) for item in items] return [_crossref_message_to_entry(item) for item in items]
def resolve_dblp(self, dblp_key: str) -> Resolution | None: def resolve_dblp(self, dblp_key: str) -> Resolution | None:
encoded_key = urllib.parse.quote(dblp_key, safe="/:") encoded_key = urllib.parse.quote(dblp_key, safe="/:")
text = self._get_text(f"https://dblp.org/rec/{encoded_key}.bib") text = self.source_client.get_text(f"https://dblp.org/rec/{encoded_key}.bib")
entries = parse_bibtex(text) entries = parse_bibtex(text)
if not entries: if not entries:
return None return None
@ -70,7 +74,7 @@ class MetadataResolver:
def search_dblp(self, query_text: str, limit: int = 5) -> list[BibEntry]: def search_dblp(self, query_text: str, limit: int = 5) -> list[BibEntry]:
query = urllib.parse.urlencode({"q": query_text, "format": "json", "h": limit}) query = urllib.parse.urlencode({"q": query_text, "format": "json", "h": limit})
payload = self._get_json(f"https://dblp.org/search/publ/api?{query}") payload = self.source_client.get_json(f"https://dblp.org/search/publ/api?{query}")
hits = payload.get("result", {}).get("hits", {}).get("hit", []) hits = payload.get("result", {}).get("hits", {}).get("hit", [])
if isinstance(hits, dict): if isinstance(hits, dict):
hits = [hits] hits = [hits]
@ -87,7 +91,7 @@ class MetadataResolver:
def resolve_arxiv(self, arxiv_id: str) -> Resolution | None: def resolve_arxiv(self, arxiv_id: str) -> Resolution | None:
query = urllib.parse.urlencode({"id_list": arxiv_id}) query = urllib.parse.urlencode({"id_list": arxiv_id})
root = self._get_xml(f"https://export.arxiv.org/api/query?{query}") root = self.source_client.get_xml(f"https://export.arxiv.org/api/query?{query}")
namespace = {"atom": "http://www.w3.org/2005/Atom"} namespace = {"atom": "http://www.w3.org/2005/Atom"}
entry = root.find("atom:entry", namespace) entry = root.find("atom:entry", namespace)
if entry is None: if entry is None:
@ -98,27 +102,6 @@ class MetadataResolver:
source_label=f"arxiv:id:{arxiv_id}", source_label=f"arxiv:id:{arxiv_id}",
) )
def _get_json(self, url: str) -> dict:
with urllib.request.urlopen(self._request(url)) as response:
return json.load(response)
def _get_text(self, url: str) -> str:
with urllib.request.urlopen(self._request(url)) as response:
return response.read().decode("utf-8")
def _get_xml(self, url: str) -> ET.Element:
with urllib.request.urlopen(self._request(url)) as response:
return ET.fromstring(response.read())
def _request(self, url: str) -> urllib.request.Request:
return urllib.request.Request(
url,
headers={
"User-Agent": self.user_agent,
},
)
def merge_entries(base: BibEntry, resolved: BibEntry) -> BibEntry: def merge_entries(base: BibEntry, resolved: BibEntry) -> BibEntry:
merged_fields = dict(base.fields) merged_fields = dict(base.fields)
for key, value in resolved.fields.items(): for key, value in resolved.fields.items():

78
src/citegeist/sources.py Normal file
View File

@ -0,0 +1,78 @@
from __future__ import annotations
import hashlib
import json
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path
class SourceClient:
    """Small HTTP client with optional fixture and cache directories.

    Lookup order for every request: the fixtures directory (read-only,
    intended for checked-in test payloads), then the cache directory
    (read/write), then the live network. Network responses are written
    to the cache only after they parse successfully, so a malformed
    payload can never poison the cache for later runs.
    """

    def __init__(
        self,
        user_agent: str = "citegeist/0.1 (local research tool)",
        cache_dir: str | Path | None = None,
        fixtures_dir: str | Path | None = None,
    ) -> None:
        self.user_agent = user_agent
        # Normalize to Path; None disables that storage layer entirely.
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.fixtures_dir = Path(fixtures_dir) if fixtures_dir else None

    def get_json(self, url: str) -> dict:
        """Fetch *url* and return its body parsed as JSON."""
        cached = self._read_cached(url, "json")
        if cached is not None:
            return json.loads(cached)
        payload = self._fetch_bytes(url)
        # Parse before caching: an unparseable payload must not be
        # persisted, or it would be served from cache forever after.
        document = json.loads(payload.decode("utf-8"))
        self._write_cache(url, "json", payload)
        return document

    def get_text(self, url: str) -> str:
        """Fetch *url* and return its body decoded as UTF-8 text."""
        cached = self._read_cached(url, "txt")
        if cached is not None:
            return cached.decode("utf-8")
        payload = self._fetch_bytes(url)
        # Decode before caching so invalid UTF-8 is never persisted.
        text = payload.decode("utf-8")
        self._write_cache(url, "txt", payload)
        return text

    def get_xml(self, url: str) -> ET.Element:
        """Fetch *url* and return its body parsed as an XML element."""
        cached = self._read_cached(url, "xml")
        if cached is not None:
            return ET.fromstring(cached)
        payload = self._fetch_bytes(url)
        # Parse before caching, for the same reason as get_json.
        root = ET.fromstring(payload)
        self._write_cache(url, "xml", payload)
        return root

    def _fetch_bytes(self, url: str) -> bytes:
        """Perform the actual network request; return the raw body bytes."""
        with urllib.request.urlopen(self._request(url)) as response:
            return response.read()

    def _request(self, url: str) -> urllib.request.Request:
        """Build a Request carrying this client's User-Agent header."""
        return urllib.request.Request(
            url,
            headers={
                "User-Agent": self.user_agent,
            },
        )

    def _cache_key(self, url: str, suffix: str) -> str:
        """Stable filename for *url*: SHA-1 digest plus a format suffix."""
        digest = hashlib.sha1(url.encode("utf-8")).hexdigest()
        return f"{digest}.{suffix}"

    def _read_cached(self, url: str, suffix: str) -> bytes | None:
        """Return fixture or cached bytes for *url*; fixtures win."""
        for root in (self.fixtures_dir, self.cache_dir):
            if root is None:
                continue
            path = root / self._cache_key(url, suffix)
            if path.exists():
                return path.read_bytes()
        return None

    def _write_cache(self, url: str, suffix: str, payload: bytes) -> None:
        """Persist *payload* in the cache dir (no-op when caching is off)."""
        if self.cache_dir is None:
            return
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        path = self.cache_dir / self._cache_key(url, suffix)
        path.write_bytes(payload)

View File

@ -35,7 +35,7 @@ def test_crossref_expander_creates_draft_nodes_and_relations():
) )
expander = CrossrefExpander() expander = CrossrefExpander()
expander.resolver._get_json = lambda _url: { # type: ignore[method-assign] expander.resolver.source_client.get_json = lambda _url: { # type: ignore[method-assign]
"message": { "message": {
"reference": [ "reference": [
{ {

30
tests/test_sources.py Normal file
View File

@ -0,0 +1,30 @@
from pathlib import Path
from citegeist.sources import SourceClient
def test_source_client_reads_fixture_before_network(tmp_path: Path):
    """A fixture payload, when present, satisfies get_json with no fetch."""
    fixtures_dir = tmp_path / "fixtures"
    fixtures_dir.mkdir()
    client = SourceClient(cache_dir=tmp_path / "cache", fixtures_dir=fixtures_dir)

    url = "https://api.crossref.org/works/10.1000/example"
    body = '{"message": {"DOI": "10.1000/example"}}'
    (fixtures_dir / client._cache_key(url, "json")).write_text(body, encoding="utf-8")  # noqa: SLF001

    payload = client.get_json(url)

    assert payload["message"]["DOI"] == "10.1000/example"
def test_source_client_writes_cache_after_fetch(tmp_path: Path):
    """A payload obtained via the (stubbed) network lands in the cache dir."""
    cache_dir = tmp_path / "cache"
    client = SourceClient(cache_dir=cache_dir)
    # Stub the network layer so the test never touches the real internet.
    client._fetch_bytes = lambda _url: b'{"ok": true}'  # type: ignore[method-assign]

    payload = client.get_json("https://example.org/test")

    assert payload["ok"] is True
    assert list(cache_dir.iterdir())