diff --git a/README.md b/README.md index 0219a33..7bcecbd 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ The initial repo includes: - identifier-first metadata resolution for DOI, DBLP, and arXiv-backed entries; - local citation-graph traversal over stored `cites`, `cited_by`, and `crossref` edges; - Crossref-backed graph expansion that materializes draft referenced works and edge provenance; +- a dedicated source-client layer with fixture/cache support for live-source development; - normalized tables for entries, creators, identifiers, and citation relations; - full-text-search-ready indexing over title, abstract, and fulltext when SQLite FTS5 is available; - tests covering parsing, ingestion, relation storage, and search. @@ -76,6 +77,7 @@ cd citegeist python3 -m virtualenv --always-copy .venv .venv/bin/pip install -e . .venv/bin/pip install pytest +mkdir -p .cache/citegeist PYTHONPATH=src .venv/bin/python - <<'PY' from citegeist import BibliographyStore @@ -120,6 +122,8 @@ PYTHONPATH=src .venv/bin/python -m citegeist --db library.sqlite3 expand smith20 PYTHONPATH=src .venv/bin/python -m citegeist --db library.sqlite3 export --output reviewed.bib ``` +For live-source development, prefer fixture-backed or cache-backed source clients so resolver and expansion work can be exercised repeatedly without re-hitting upstream APIs on every run. + ## Near-Term Priorities - stronger plaintext extraction coverage for more citation styles; diff --git a/src/citegeist/__init__.py b/src/citegeist/__init__.py index 4022f8e..d52548d 100644 --- a/src/citegeist/__init__.py +++ b/src/citegeist/__init__.py @@ -2,6 +2,7 @@ from .bibtex import BibEntry, parse_bibtex from .expand import CrossrefExpander from .extract import extract_references from .resolve import MetadataResolver, merge_entries +from .sources import SourceClient from .storage import BibliographyStore __all__ = [ @@ -9,6 +10,7 @@ __all__ = [ "BibliographyStore", "CrossrefExpander", "MetadataResolver", + "SourceClient", "extract_references", "merge_entries", "parse_bibtex", diff --git a/src/citegeist/expand.py b/src/citegeist/expand.py index 82670e8..52159a5 100644 --- a/src/citegeist/expand.py +++ b/src/citegeist/expand.py @@ -34,7 +34,7 @@ class CrossrefExpander: if not doi: return [] - payload = self.resolver._get_json( # noqa: SLF001 + payload = self.resolver.source_client.get_json( f"https://api.crossref.org/works/{doi}?mailto=welsberr@gmail.com" ) references = payload.get("message", {}).get("reference", []) diff --git a/src/citegeist/resolve.py b/src/citegeist/resolve.py index 8f6da4c..f37d282 100644 --- a/src/citegeist/resolve.py +++ b/src/citegeist/resolve.py @@ -1,12 +1,11 @@ from __future__ import annotations -import json import urllib.parse -import urllib.request import xml.etree.ElementTree as ET from dataclasses import dataclass from .bibtex import BibEntry, parse_bibtex +from .sources import SourceClient @dataclass(slots=True) @@ -17,8 +16,13 @@ class Resolution: class MetadataResolver: - def __init__(self, user_agent: str = "citegeist/0.1 (local research tool)") -> None: + def __init__( + self, + user_agent: str = "citegeist/0.1 (local research tool)", + source_client: SourceClient | None = None, + ) -> None: self.user_agent = user_agent + self.source_client = source_client or SourceClient(user_agent=user_agent) def resolve_entry(self, entry: BibEntry) -> Resolution | None: if doi := entry.fields.get("doi"): @@ -40,7 +44,7 @@ class MetadataResolver: def resolve_doi(self, doi: str) -> Resolution | None: encoded = urllib.parse.quote(doi, safe="") - payload = self._get_json(f"https://api.crossref.org/works/{encoded}") + payload = self.source_client.get_json(f"https://api.crossref.org/works/{encoded}") message = payload.get("message", {}) if not message: return None @@ -52,13 +56,13 @@ class MetadataResolver: def search_crossref(self, title: str, limit: int = 5) -> list[BibEntry]: query = urllib.parse.urlencode({"query.title": title, "rows": limit}) - payload = self._get_json(f"https://api.crossref.org/works?{query}") + payload = self.source_client.get_json(f"https://api.crossref.org/works?{query}") items = payload.get("message", {}).get("items", []) return [_crossref_message_to_entry(item) for item in items] def resolve_dblp(self, dblp_key: str) -> Resolution | None: encoded_key = urllib.parse.quote(dblp_key, safe="/:") - text = self._get_text(f"https://dblp.org/rec/{encoded_key}.bib") + text = self.source_client.get_text(f"https://dblp.org/rec/{encoded_key}.bib") entries = parse_bibtex(text) if not entries: return None @@ -70,7 +74,7 @@ class MetadataResolver: def search_dblp(self, query_text: str, limit: int = 5) -> list[BibEntry]: query = urllib.parse.urlencode({"q": query_text, "format": "json", "h": limit}) - payload = self._get_json(f"https://dblp.org/search/publ/api?{query}") + payload = self.source_client.get_json(f"https://dblp.org/search/publ/api?{query}") hits = payload.get("result", {}).get("hits", {}).get("hit", []) if isinstance(hits, dict): hits = [hits] @@ -87,7 +91,7 @@ class MetadataResolver: def resolve_arxiv(self, arxiv_id: str) -> Resolution | None: query = urllib.parse.urlencode({"id_list": arxiv_id}) - root = self._get_xml(f"https://export.arxiv.org/api/query?{query}") + root = self.source_client.get_xml(f"https://export.arxiv.org/api/query?{query}") namespace = {"atom": "http://www.w3.org/2005/Atom"} entry = root.find("atom:entry", namespace) if entry is None: @@ -98,27 +102,6 @@ class MetadataResolver: source_label=f"arxiv:id:{arxiv_id}", ) - def _get_json(self, url: str) -> dict: - with urllib.request.urlopen(self._request(url)) as response: - return json.load(response) - - def _get_text(self, url: str) -> str: - with urllib.request.urlopen(self._request(url)) as response: - return response.read().decode("utf-8") - - def _get_xml(self, url: str) -> ET.Element: - with urllib.request.urlopen(self._request(url)) as response: - return ET.fromstring(response.read()) - - def _request(self, url: str) -> urllib.request.Request: - return urllib.request.Request( - url, - headers={ - "User-Agent": self.user_agent, - }, - ) - - def merge_entries(base: BibEntry, resolved: BibEntry) -> BibEntry: merged_fields = dict(base.fields) for key, value in resolved.fields.items(): diff --git a/src/citegeist/sources.py b/src/citegeist/sources.py new file mode 100644 index 0000000..63bd23d --- /dev/null +++ b/src/citegeist/sources.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import hashlib +import json +import urllib.request +import xml.etree.ElementTree as ET +from pathlib import Path + + +class SourceClient: + def __init__( + self, + user_agent: str = "citegeist/0.1 (local research tool)", + cache_dir: str | Path | None = None, + fixtures_dir: str | Path | None = None, + ) -> None: + self.user_agent = user_agent + self.cache_dir = Path(cache_dir) if cache_dir else None + self.fixtures_dir = Path(fixtures_dir) if fixtures_dir else None + + def get_json(self, url: str) -> dict: + cached = self._read_cached(url, "json") + if cached is not None: + return json.loads(cached) + + payload = self._fetch_bytes(url) + self._write_cache(url, "json", payload) + return json.loads(payload.decode("utf-8")) + + def get_text(self, url: str) -> str: + cached = self._read_cached(url, "txt") + if cached is not None: + return cached.decode("utf-8") + + payload = self._fetch_bytes(url) + self._write_cache(url, "txt", payload) + return payload.decode("utf-8") + + def get_xml(self, url: str) -> ET.Element: + cached = self._read_cached(url, "xml") + if cached is not None: + return ET.fromstring(cached) + + payload = self._fetch_bytes(url) + self._write_cache(url, "xml", payload) + return ET.fromstring(payload) + + def _fetch_bytes(self, url: str) -> bytes: + with urllib.request.urlopen(self._request(url)) as response: + return response.read() + + def _request(self, url: str) -> urllib.request.Request: + return urllib.request.Request( + url, + headers={ + "User-Agent": self.user_agent, + }, + ) + + def _cache_key(self, url: str, suffix: str) -> str: + digest = hashlib.sha1(url.encode("utf-8")).hexdigest() + return f"{digest}.{suffix}" + + def _read_cached(self, url: str, suffix: str) -> bytes | None: + for root in (self.fixtures_dir, self.cache_dir): + if root is None: + continue + path = root / self._cache_key(url, suffix) + if path.exists(): + return path.read_bytes() + return None + + def _write_cache(self, url: str, suffix: str, payload: bytes) -> None: + if self.cache_dir is None: + return + self.cache_dir.mkdir(parents=True, exist_ok=True) + path = self.cache_dir / self._cache_key(url, suffix) + path.write_bytes(payload) diff --git a/tests/test_expand.py b/tests/test_expand.py index 365ba2c..4f6425b 100644 --- a/tests/test_expand.py +++ b/tests/test_expand.py @@ -35,7 +35,7 @@ def test_crossref_expander_creates_draft_nodes_and_relations(): ) expander = CrossrefExpander() - expander.resolver._get_json = lambda _url: { # type: ignore[method-assign] + expander.resolver.source_client.get_json = lambda _url: { # type: ignore[method-assign] "message": { "reference": [ { diff --git a/tests/test_sources.py b/tests/test_sources.py new file mode 100644 index 0000000..fea995a --- /dev/null +++ b/tests/test_sources.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from citegeist.sources import SourceClient + + +def test_source_client_reads_fixture_before_network(tmp_path: Path): + fixtures_dir = tmp_path / "fixtures" + fixtures_dir.mkdir() + + client = SourceClient(cache_dir=tmp_path / "cache", fixtures_dir=fixtures_dir) + url = "https://api.crossref.org/works/10.1000/example" + fixture_path = fixtures_dir / client._cache_key(url, "json") # noqa: SLF001 + fixture_path.write_text('{"message": {"DOI": "10.1000/example"}}', encoding="utf-8") + + payload = client.get_json(url) + + assert payload["message"]["DOI"] == "10.1000/example" + + +def test_source_client_writes_cache_after_fetch(tmp_path: Path): + cache_dir = tmp_path / "cache" + client = SourceClient(cache_dir=cache_dir) + url = "https://example.org/test" + + client._fetch_bytes = lambda _url: b'{"ok": true}' # type: ignore[method-assign] + + payload = client.get_json(url) + + assert payload["ok"] is True + assert any(cache_dir.iterdir())