diff --git a/.gitignore b/.gitignore index e41471e..e34b1cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .pytest_cache/ .venv/ +.cache/ *.pyc library.sqlite3 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8b20c95 --- /dev/null +++ b/Makefile @@ -0,0 +1,13 @@ +PYTHONPATH_SRC=PYTHONPATH=src +VENV_PYTHON=.venv/bin/python + +.PHONY: test test-live live-smoke + +test: + $(PYTHONPATH_SRC) $(VENV_PYTHON) -m pytest -q + +test-live: + CITEGEIST_LIVE_TESTS=1 CITEGEIST_SOURCE_CACHE=.cache/citegeist $(PYTHONPATH_SRC) $(VENV_PYTHON) -m pytest -m live -q + +live-smoke: + CITEGEIST_SOURCE_CACHE=.cache/citegeist $(PYTHONPATH_SRC) $(VENV_PYTHON) scripts/live_smoke.py diff --git a/README.md b/README.md index d00168f..e795a94 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ PYTHONPATH=src .venv/bin/python scripts/live_smoke.py By default, live tests are skipped. They only run when `CITEGEIST_LIVE_TESTS=1` is set. +Convenience targets: + +```bash +make test +make test-live +make live-smoke +``` + ## Near-Term Priorities - additional resolvers and expansion paths for non-DOI scholarly ecosystems. diff --git a/src/citegeist/__init__.py b/src/citegeist/__init__.py index d52548d..952a02c 100644 --- a/src/citegeist/__init__.py +++ b/src/citegeist/__init__.py @@ -1,5 +1,5 @@ from .bibtex import BibEntry, parse_bibtex -from .expand import CrossrefExpander +from .expand import CrossrefExpander, OpenAlexExpander from .extract import extract_references from .resolve import MetadataResolver, merge_entries from .sources import SourceClient @@ -10,6 +10,7 @@ __all__ = [ "BibliographyStore", "CrossrefExpander", "MetadataResolver", + "OpenAlexExpander", "SourceClient", "extract_references", "merge_entries", diff --git a/src/citegeist/cli.py b/src/citegeist/cli.py index 25c5739..2af8de9 100644 --- a/src/citegeist/cli.py +++ b/src/citegeist/cli.py @@ -7,7 +7,7 @@ import sys from pathlib import Path from .bibtex import parse_bibtex, render_bibtex -from .expand import CrossrefExpander +from .expand import CrossrefExpander, OpenAlexExpander from .extract import extract_references from .resolve import MetadataResolver, merge_entries from .storage import BibliographyStore @@ -69,10 +69,17 @@ def build_parser() -> argparse.ArgumentParser: expand_parser.add_argument("citation_keys", nargs="+", help="Seed citation keys to expand") expand_parser.add_argument( "--source", - choices=["crossref"], + choices=["crossref", "openalex"], default="crossref", help="External source used for graph expansion", ) + expand_parser.add_argument( + "--relation", + choices=["cites", "cited_by"], + default="cites", + help="Graph direction to expand for sources that support it", + ) + expand_parser.add_argument("--limit", type=int, default=25, help="Maximum related works to fetch per seed") return parser @@ -107,7 +114,7 @@ def main(argv: list[str] | None = None) -> int: args.missing_only, ) if args.command == "expand": - return _run_expand(store, args.citation_keys, args.source) + return _run_expand(store, args.citation_keys, args.source, args.relation, args.limit) finally: store.close() @@ -237,14 +244,25 @@ def _run_graph( return 0 -def _run_expand(store: BibliographyStore, citation_keys: list[str], source: str) -> int: - if source != "crossref": +def _run_expand( + store: BibliographyStore, + citation_keys: list[str], + source: str, + relation: str, + limit: int, +) -> int: + if source == "crossref": + expander = CrossrefExpander() + expand_fn = lambda key: expander.expand_entry_references(store, key) + elif source == "openalex": + expander = OpenAlexExpander() + expand_fn = lambda key: expander.expand_entry(store, key, relation_type=relation, limit=limit) + else: print(f"Unsupported expansion source: {source}", file=sys.stderr) return 1 - expander = CrossrefExpander() all_results = [] for citation_key in citation_keys: - all_results.extend(expander.expand_entry_references(store, citation_key)) + all_results.extend(expand_fn(citation_key)) print(json.dumps([asdict(result) for result in all_results], indent=2)) return 0 diff --git a/src/citegeist/expand.py b/src/citegeist/expand.py index 52159a5..a9079a1 100644 --- a/src/citegeist/expand.py +++ b/src/citegeist/expand.py @@ -2,8 +2,9 @@ from __future__ import annotations import re from dataclasses import dataclass +from urllib.parse import quote, urlencode -from .bibtex import BibEntry +from .bibtex import BibEntry, parse_bibtex from .resolve import MetadataResolver from .storage import BibliographyStore @@ -73,6 +74,95 @@ class CrossrefExpander: return results +class OpenAlexExpander: + def __init__(self, resolver: MetadataResolver | None = None) -> None: + self.resolver = resolver or MetadataResolver() + + def expand_entry( + self, + store: BibliographyStore, + citation_key: str, + relation_type: str = "cites", + limit: int = 25, + ) -> list[ExpansionResult]: + entry = store.get_entry(citation_key) + if entry is None: + return [] + + openalex_id = entry.get("openalex") or self._lookup_openalex_id(entry) + if not openalex_id: + return [] + if not entry.get("openalex"): + bibtex = store.get_entry_bibtex(citation_key) + if bibtex: + seed_entry = parse_bibtex(bibtex)[0] + seed_entry.fields["openalex"] = openalex_id + store.replace_entry( + citation_key, + seed_entry, + source_type="resolver", + source_label=f"openalex:id:{openalex_id}", + review_status=str(entry.get("review_status") or "draft"), + ) + + filter_name = "cited_by" if relation_type == "cites" else "cites" + query = urlencode({"filter": f"{filter_name}:{openalex_id}", "per-page": limit}) + payload = self.resolver.source_client.get_json(f"https://api.openalex.org/works?{query}") + works = payload.get("results", []) + + results: list[ExpansionResult] = [] + for work in works: + discovered = _openalex_work_to_entry(work) + created = False + if store.get_entry(discovered.citation_key) is None: + store.upsert_entry( + discovered, + raw_bibtex=None, + source_type="graph_expand", + source_label=f"openalex:{relation_type}:{openalex_id}", + review_status="draft", + ) + store.connection.commit() + created = True + + if relation_type == "cites": + source_key = citation_key + target_key = discovered.citation_key + else: + source_key = discovered.citation_key + target_key = citation_key + + store.add_relation( + source_key, + target_key, + "cites", + source_type="graph_expand", + source_label=f"openalex:{relation_type}:{openalex_id}", + confidence=0.9, + ) + results.append( + ExpansionResult( + source_citation_key=source_key, + discovered_citation_key=discovered.citation_key, + created_entry=created, + relation_type=relation_type, + source_label=f"openalex:{relation_type}:{openalex_id}", + ) + ) + return results + + def _lookup_openalex_id(self, entry: dict[str, object]) -> str | None: + doi = entry.get("doi") + if not doi: + return None + query = urlencode({"filter": f"doi:https://doi.org/{doi}"}) + payload = self.resolver.source_client.get_json(f"https://api.openalex.org/works?{query}") + results = payload.get("results", []) + if not results: + return None + return _normalize_openalex_id(results[0].get("id", "")) + + def _crossref_reference_to_entry(reference: dict, source_citation_key: str, ordinal: int) -> BibEntry: title = ( reference.get("article-title") @@ -119,3 +209,83 @@ def _reference_citation_key(reference: dict, title: str, year: str, ordinal: int def _normalize_text(value: str) -> str: return " ".join(value.split()) + + +def _openalex_work_to_entry(work: dict) -> BibEntry: + title = _normalize_text(work.get("display_name", "") or "Untitled work") + year = str(work.get("publication_year") or "") + doi = _normalize_openalex_doi(work.get("doi")) + openalex_id = _normalize_openalex_id(work.get("id", "")) + authors = " and ".join(_openalex_author_name(item) for item in work.get("authorships", [])) + source = ((work.get("primary_location") or {}).get("source") or {}).get("display_name", "") + work_type = work.get("type", "") + + fields: dict[str, str] = {"title": title} + if year: + fields["year"] = year + if authors: + fields["author"] = authors + if doi: + fields["doi"] = doi + fields["url"] = f"https://doi.org/{doi}" + if openalex_id: + fields["openalex"] = openalex_id + if abstract := work.get("abstract_inverted_index"): + fields["abstract"] = _openalex_abstract_text(abstract) + if source: + if work_type == "article": + fields["journal"] = source + else: + fields["booktitle"] = source + + citation_key = _openalex_citation_key(openalex_id, authors, year, title) + entry_type = _openalex_type_to_bibtype(work_type) + return BibEntry(entry_type=entry_type, citation_key=citation_key, fields=fields) + + +def _openalex_author_name(authorship: dict) -> str: + author = authorship.get("author") or {} + name = author.get("display_name", "") + return _normalize_text(name) + + +def _openalex_abstract_text(inverted_index: dict) -> str: + positions: dict[int, str] = {} + for word, indexes in inverted_index.items(): + for index in indexes: + positions[int(index)] = word + return " ".join(word for _, word in sorted(positions.items())) + + +def _openalex_type_to_bibtype(work_type: str) -> str: + mapping = { + "article": "article", + "book": "book", + "book-chapter": "incollection", + "dissertation": "phdthesis", + "proceedings-article": "inproceedings", + } + return mapping.get(work_type, "misc") + + +def _openalex_citation_key(openalex_id: str, authors: str, year: str, title: str) -> str: + if openalex_id: + return f"openalex{re.sub(r'[^A-Za-z0-9]+', '', openalex_id).lower()}" + author = authors.split(" and ")[0] if authors else "ref" + family = re.sub(r"[^A-Za-z0-9]+", "", author.split()[-1]).lower() or "ref" + first_word = re.sub(r"[^A-Za-z0-9]+", "", title.split()[0]).lower() if title.split() else "untitled" + return f"{family}{year or 'nd'}{first_word}" + + +def _normalize_openalex_id(value: str) -> str: + if not value: + return "" + return value.rsplit("/", 1)[-1] + + +def _normalize_openalex_doi(value: str | None) -> str: + if not value: + return "" + if value.startswith("https://doi.org/"): + return value[len("https://doi.org/") :] + return value diff --git a/src/citegeist/extract.py b/src/citegeist/extract.py index 5df2eb6..782f984 100644 --- a/src/citegeist/extract.py +++ b/src/citegeist/extract.py @@ -5,11 +5,13 @@ import re from .bibtex import BibEntry YEAR_PATTERN = re.compile(r"\b(19|20)\d{2}\b") +YEAR_PAREN_PATTERN = re.compile(r"\((19|20)\d{2}\)") +REF_START_PATTERN = re.compile(r"^(?:\[\d+\]|\d+\.|\(\d+\))\s*") def extract_references(text: str) -> list[BibEntry]: entries: list[BibEntry] = [] - for index, line in enumerate(_iter_reference_lines(text), start=1): + for index, line in enumerate(_iter_reference_blocks(text), start=1): parsed = _parse_reference_line(line, index) if parsed is not None: entries.append(parsed) @@ -22,22 +24,95 @@ def render_extracted_bibtex(text: str) -> str: return render_bibtex(extract_references(text)) -def _iter_reference_lines(text: str) -> list[str]: +def _iter_reference_blocks(text: str) -> list[str]: lines: list[str] = [] + current: list[str] = [] for raw_line in text.splitlines(): line = raw_line.strip() if not line: + if current: + lines.append(" ".join(current)) + current = [] continue - line = re.sub(r"^\[\d+\]\s*", "", line) - line = re.sub(r"^\d+\.\s*", "", line) - line = re.sub(r"^\(\d+\)\s*", "", line) - if len(line) < 20: + starts_new = bool(REF_START_PATTERN.match(line)) + line = REF_START_PATTERN.sub("", line) + normalized = " ".join(line.split()) + if len(normalized) < 20: continue - lines.append(" ".join(line.split())) + if starts_new and current: + lines.append(" ".join(current)) + current = [normalized] + else: + current.append(normalized) + if current: + lines.append(" ".join(current)) return lines def _parse_reference_line(line: str, ordinal: int) -> BibEntry | None: + for parser in (_parse_apa_style_reference, _parse_publisher_style_reference, _parse_plain_year_reference): + parsed = parser(line, ordinal) + if parsed is not None: + return parsed + return None + + +def _parse_apa_style_reference(line: str, ordinal: int) -> BibEntry | None: + year_match = YEAR_PAREN_PATTERN.search(line) + if year_match is None: + return None + + year = year_match.group(0).strip("()") + author_part = line[: year_match.start()].strip(" .") + remainder = line[year_match.end() :].strip(" .") + if not author_part or not remainder: + return None + + segments = _segments_after_year(remainder) + if not segments: + return None + + title = _clean_title(segments[0]) + venue = segments[1] if len(segments) > 1 else "" + authors = _normalize_authors(author_part) + return _build_entry(line, ordinal, authors, year, title, venue) + + +def _parse_publisher_style_reference(line: str, ordinal: int) -> BibEntry | None: + year_match = YEAR_PATTERN.search(line) + if year_match is None: + return None + + prefix = line[: year_match.start()].strip(" .,;") + if "." not in prefix: + return None + + head, publisher = prefix.rsplit(".", 1) + if "." not in head: + return None + author_part, title = head.split(".", 1) + + authors = _normalize_authors(author_part) + title = _clean_title(title) + publisher = publisher.strip(" .,;") + if not authors or not title or not publisher: + return None + + citation_key = _make_citation_key(authors, year_match.group(0), title, ordinal) + return BibEntry( + entry_type="book", + citation_key=citation_key, + fields={ + "author": authors, + "year": year_match.group(0), + "title": title, + "publisher": publisher, + "note": f"extracted_reference = {{true}}; raw_reference = {{{line}}}", + }, + ) + + +def _parse_plain_year_reference(line: str, ordinal: int) -> BibEntry | None: year_match = YEAR_PATTERN.search(line) if year_match is None: return None @@ -48,14 +123,42 @@ def _parse_reference_line(line: str, ordinal: int) -> BibEntry | None: if not author_part or not remainder: return None - segments = [segment.strip(" .") for segment in remainder.split(".") if segment.strip(" .")] + segments = _segments_after_year(remainder) if not segments: return None - title = segments[0] + title = _clean_title(segments[0]) venue = segments[1] if len(segments) > 1 else "" - authors = _normalize_authors(author_part) + return _build_entry(line, ordinal, authors, year, title, venue) + + +def _normalize_authors(author_part: str) -> str: + normalized = author_part.replace(" & ", " and ") + normalized = re.sub(r"\bet al\.?$", "and others", normalized) + normalized = re.sub(r"\s+and\s+", " and ", normalized) + normalized = re.sub(r"\s*,\s*", ", ", normalized) + return normalized.strip(" .") + + +def _segments_after_year(remainder: str) -> list[str]: + return [segment.strip(" .") for segment in remainder.split(". ") if segment.strip(" .")] + + +def _clean_title(title: str) -> str: + cleaned = title.strip(" .\"'") + cleaned = re.sub(r"\s+", " ", cleaned) + return cleaned + + +def _build_entry( + raw_line: str, + ordinal: int, + authors: str, + year: str, + title: str, + venue: str, +) -> BibEntry: citation_key = _make_citation_key(authors, year, title, ordinal) entry_type = _guess_entry_type(venue) @@ -63,25 +166,19 @@ def _parse_reference_line(line: str, ordinal: int) -> BibEntry | None: "author": authors, "year": year, "title": title, - "note": f"extracted_reference = {{true}}; raw_reference = {{{line}}}", + "note": f"extracted_reference = {{true}}; raw_reference = {{{raw_line}}}", } if venue: if entry_type == "article": fields["journal"] = venue - else: + elif entry_type == "inproceedings": fields["booktitle"] = venue + else: + fields["howpublished"] = venue return BibEntry(entry_type=entry_type, citation_key=citation_key, fields=fields) -def _normalize_authors(author_part: str) -> str: - normalized = author_part.replace(" & ", " and ") - normalized = re.sub(r"\bet al\.$", "and others", normalized) - normalized = re.sub(r"\s+and\s+", " and ", normalized) - normalized = re.sub(r"\s*,\s*", ", ", normalized) - return normalized.strip(" .") - - def _make_citation_key(authors: str, year: str, title: str, ordinal: int) -> str: first_author = authors.split(" and ")[0] family_name = first_author.split(",")[0] if "," in first_author else first_author.split()[-1] @@ -99,4 +196,6 @@ def _guess_entry_type(venue: str) -> str: return "article" if any(token in lowered for token in ("proceedings", "conference", "workshop", "symposium")): return "inproceedings" + if any(token in lowered for token in ("press", "publisher", "university")): + return "book" return "misc" diff --git a/src/citegeist/resolve.py b/src/citegeist/resolve.py index f37d282..4d3ce28 100644 --- a/src/citegeist/resolve.py +++ b/src/citegeist/resolve.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import urllib.parse import xml.etree.ElementTree as ET from dataclasses import dataclass @@ -30,6 +31,11 @@ class MetadataResolver: if resolved is not None: return resolved + if openalex_id := entry.fields.get("openalex"): + resolved = self.resolve_openalex(openalex_id) + if resolved is not None: + return resolved + if dblp_key := entry.fields.get("dblp"): resolved = self.resolve_dblp(dblp_key) if resolved is not None: @@ -40,6 +46,15 @@ class MetadataResolver: if resolved is not None: return resolved + if title := entry.fields.get("title"): + resolved = self.search_openalex_best_match( + title=title, + author_text=entry.fields.get("author", ""), + year=entry.fields.get("year", ""), + ) + if resolved is not None: + return resolved + return None def resolve_doi(self, doi: str) -> Resolution | None: @@ -102,6 +117,55 @@ class MetadataResolver: source_label=f"arxiv:id:{arxiv_id}", ) + def resolve_openalex(self, openalex_id: str) -> Resolution | None: + normalized_id = _normalize_openalex_id(openalex_id) + payload = self.source_client.get_json(f"https://api.openalex.org/works/{normalized_id}") + if not payload: + return None + return Resolution( + entry=_openalex_work_to_entry(payload), + source_type="resolver", + source_label=f"openalex:id:{normalized_id}", + ) + + def search_openalex(self, title: str, limit: int = 5) -> list[BibEntry]: + query = urllib.parse.urlencode({"search": title, "per-page": limit}) + payload = self.source_client.get_json(f"https://api.openalex.org/works?{query}") + return [_openalex_work_to_entry(item) for item in payload.get("results", [])] + + def search_openalex_best_match( + self, + title: str, + author_text: str = "", + year: str = "", + ) -> Resolution | None: + candidates = self.search_openalex(title, limit=5) + if not candidates: + return None + + title_norm = _normalize_match_text(title) + author_norm = _normalize_match_text(author_text) + for candidate in candidates: + candidate_title = _normalize_match_text(candidate.fields.get("title", "")) + candidate_author = _normalize_match_text(candidate.fields.get("author", "")) + candidate_year = candidate.fields.get("year", "") + if candidate_title == title_norm: + if author_norm and candidate_author and author_norm.split(" and ")[0] not in candidate_author: + continue + if year and candidate_year and year != candidate_year: + continue + return Resolution( + entry=candidate, + source_type="resolver", + source_label=f"openalex:search:{title}", + ) + + return Resolution( + entry=candidates[0], + source_type="resolver", + source_label=f"openalex:search:{title}", + ) + def merge_entries(base: BibEntry, resolved: BibEntry) -> BibEntry: merged_fields = dict(base.fields) for key, value in resolved.fields.items(): @@ -221,3 +285,81 @@ def _make_resolution_key(author_text: str, year: str, title: str) -> str: family_name = "".join(ch for ch in family_name.lower() if ch.isalnum()) or "ref" first_word = "".join(ch for ch in title.split()[0].lower() if ch.isalnum()) if title.split() else "untitled" return f"{family_name}{year}{first_word}" + + +def _openalex_work_to_entry(work: dict) -> BibEntry: + title = work.get("display_name", "") or "Untitled work" + year = str(work.get("publication_year") or "") + doi = _normalize_openalex_doi(work.get("doi")) + openalex_id = _normalize_openalex_id(work.get("id", "")) + authors = " and ".join(_openalex_author_name(item) for item in work.get("authorships", [])) + source = ((work.get("primary_location") or {}).get("source") or {}).get("display_name", "") + work_type = work.get("type", "") + + fields: dict[str, str] = {} + if authors: + fields["author"] = authors + if title: + fields["title"] = title + if year: + fields["year"] = year + if doi: + fields["doi"] = doi + fields["url"] = f"https://doi.org/{doi}" + if openalex_id: + fields["openalex"] = openalex_id + fields.setdefault("url", f"https://openalex.org/{openalex_id}") + if abstract := work.get("abstract_inverted_index"): + fields["abstract"] = _openalex_abstract_text(abstract) + if source: + if work_type == "article": + fields["journal"] = source + else: + fields["booktitle"] = source + + citation_key = f"openalex{re.sub(r'[^A-Za-z0-9]+', '', openalex_id).lower()}" if openalex_id else _make_resolution_key(authors or "openalex", year or "n.d.", title or "untitled") + return BibEntry(entry_type=_openalex_type_to_bibtype(work_type), citation_key=citation_key, fields=fields) + + +def _openalex_author_name(authorship: dict) -> str: + author = authorship.get("author") or {} + return " ".join(str(author.get("display_name", "")).split()) + + +def _openalex_abstract_text(inverted_index: dict) -> str: + positions: dict[int, str] = {} + for word, indexes in inverted_index.items(): + for index in indexes: + positions[int(index)] = word + return " ".join(word for _, word in sorted(positions.items())) + + +def _openalex_type_to_bibtype(work_type: str) -> str: + mapping = { + "article": "article", + "book": "book", + "book-chapter": "incollection", + "dissertation": "phdthesis", + "proceedings-article": "inproceedings", + } + return mapping.get(work_type, "misc") + + +def _normalize_openalex_id(value: str) -> str: + if not value: + return "" + return value.rsplit("/", 1)[-1] + + +def _normalize_openalex_doi(value: str | None) -> str: + if not value: + return "" + if value.startswith("https://doi.org/"): + return value[len("https://doi.org/") :] + return value + + +def _normalize_match_text(value: str) -> str: + lowered = value.lower() + lowered = re.sub(r"\W+", " ", lowered) + return " ".join(lowered.split()) diff --git a/src/citegeist/storage.py b/src/citegeist/storage.py index 10fdc6a..57e75ee 100644 --- a/src/citegeist/storage.py +++ b/src/citegeist/storage.py @@ -8,7 +8,7 @@ from pathlib import Path from .bibtex import BibEntry, parse_bibtex, render_bibtex -IDENTIFIER_FIELDS = ("doi", "isbn", "issn", "pmid", "arxiv", "dblp", "oai", "url") +IDENTIFIER_FIELDS = ("doi", "isbn", "issn", "pmid", "arxiv", "dblp", "oai", "openalex", "url") RELATION_FIELDS = { "references": "cites", "cites": "cites", @@ -383,7 +383,7 @@ class BibliographyStore: "SELECT * FROM entries WHERE citation_key = ?", (citation_key,), ).fetchone() - return dict(row) if row else None + return self._row_to_entry_dict(row) if row else None def list_entries(self, limit: int = 50) -> list[dict[str, object]]: rows = self.connection.execute( @@ -601,6 +601,13 @@ class BibliographyStore: ).fetchall() return [str(row["full_name"]) for row in rows] + def _row_to_entry_dict(self, row: sqlite3.Row) -> dict[str, object]: + payload = dict(row) + extra_fields = json.loads(str(payload.get("extra_fields_json") or "{}")) + for key, value in extra_fields.items(): + payload.setdefault(key, value) + return payload + def _iter_graph_edges(self, citation_key: str, allowed_relations: set[str]) -> list[sqlite3.Row]: rows = self.connection.execute( """