# Source code for ads_mcp.client

"""HTTP client for the SAO/NASA Astrophysics Data System (ADS) API.

This module provides a thin, async wrapper around the ADS REST API
(https://github.com/adsabs/adsabs-dev-api).  All network I/O is performed
with :mod:`httpx` so it plays nicely with FastMCP's async event loop.

Example:
    Basic usage with an async context manager::

        async with ADSClient(api_key="my-token") as client:
            result = await client.search("black hole accretion", rows=5)
            print(result)
"""

from __future__ import annotations

import json
import os
from typing import Any
from urllib.parse import urlencode

import httpx
from dotenv import load_dotenv

load_dotenv()

#: Root URL that every ADS v1 API request path is resolved against.
ADS_BASE_URL = "https://api.adsabs.harvard.edu/v1"

#: Field list (the Solr ``fl`` parameter) requested by default by the search
#: helpers, joined into the comma-separated form the ADS API expects.
DEFAULT_SEARCH_FIELDS = ",".join(
    (
        "bibcode",
        "title",
        "author",
        "year",
        "abstract",
        "doi",
        "arxiv_class",
        "identifier",
        "citation_count",
        "read_count",
        "pub",
        "volume",
        "page",
        "keyword",
    )
)


class ADSError(Exception):
    """Raised when the ADS API returns a non-2xx response.

    Attributes:
        status_code: HTTP status code returned by the server.
        message: Human-readable error description.
    """

    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message
        super().__init__(f"ADS API error {status_code}: {message}")

    def __reduce__(self) -> tuple[type, tuple[int, str]]:
        """Support ``pickle`` / ``copy`` by rebuilding from the constructor args.

        ``BaseException`` reconstructs instances as ``cls(*self.args)``.
        ``self.args`` holds only the formatted message here, so the default
        would call ``__init__`` with one argument and raise ``TypeError``.
        """
        return (type(self), (self.status_code, self.message))
class ADSClient:
    """Async HTTP client for the ADS API.

    Wraps the ADS REST API endpoints into straightforward async methods.
    An API key is required; obtain one at
    https://ui.adsabs.harvard.edu/user/settings/token.

    Args:
        api_key: ADS API bearer token. If *None* (or empty), the value of the
            ``ADS_API_TOKEN`` environment variable is used.
        timeout: Request timeout in seconds. Defaults to ``30``.

    Raises:
        ValueError: If no API key is provided and ``ADS_API_TOKEN`` is not
            set in the environment.

    Example:
        ::

            async with ADSClient() as client:
                papers = await client.search("exoplanet atmospheres", rows=10)
    """

    def __init__(self, api_key: str | None = None, timeout: float = 30.0) -> None:
        # ``or`` rather than ``is None`` so an empty-string key also falls back
        # to the environment variable.
        resolved_key = api_key or os.environ.get("ADS_API_TOKEN")
        if not resolved_key:
            raise ValueError(
                "No ADS API key provided. Set the ADS_API_TOKEN environment variable "
                "or pass api_key= to ADSClient."
            )
        self._api_key = resolved_key
        self._timeout = timeout
        # Opened in __aenter__ and released in __aexit__; None while closed.
        self._client: httpx.AsyncClient | None = None

    # ------------------------------------------------------------------
    # Context-manager helpers
    # ------------------------------------------------------------------

    async def __aenter__(self) -> ADSClient:
        """Open the underlying :class:`httpx.AsyncClient` and return *self*."""
        if self._client is not None:
            # Re-entering an already-open client would otherwise orphan the
            # previous connection pool without ever closing it.
            await self._client.aclose()
        self._client = httpx.AsyncClient(
            base_url=ADS_BASE_URL,
            headers={"Authorization": f"Bearer {self._api_key}"},
            timeout=self._timeout,
        )
        return self

    async def __aexit__(self, *_: Any) -> None:
        """Close the HTTP client opened by :meth:`__aenter__`, if any."""
        # Detach first so the client is marked closed even if aclose() raises.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @property
    def _http(self) -> httpx.AsyncClient:
        """Return the open HTTP client.

        Raises:
            RuntimeError: If called outside an ``async with`` block.
        """
        if self._client is None:
            raise RuntimeError("ADSClient must be used as an async context manager.")
        return self._client

    def _check(self, response: httpx.Response) -> httpx.Response:
        """Raise :class:`ADSError` for non-2xx responses.

        Args:
            response: The HTTP response to check.

        Returns:
            The same *response* object if the status is successful.

        Raises:
            ADSError: If the status code indicates an error. The message is the
                ``"error"`` field of a JSON object body when present, otherwise
                the raw response text.
        """
        if not response.is_error:
            return response
        detail: Any = response.text
        try:
            payload = response.json()
        except ValueError:
            # Body is not valid JSON (e.g. an HTML error page); keep raw text.
            payload = None
        if isinstance(payload, dict):
            detail = payload.get("error", response.text)
        raise ADSError(response.status_code, str(detail))

    @staticmethod
    def _bibcode_clause(bibcode: str) -> str:
        """Return a ``bibcode:"..."`` query clause with *bibcode* quoted.

        Quoting keeps characters such as ``&`` or ``:`` in a bibcode from being
        parsed as query syntax. Embedded backslashes and double quotes are
        escaped so the value cannot terminate the phrase early.
        """
        escaped = bibcode.strip().replace("\\", "\\\\").replace('"', '\\"')
        return f'bibcode:"{escaped}"'

    # ------------------------------------------------------------------
    # Search endpoint
    # ------------------------------------------------------------------

    async def search(
        self,
        query: str,
        fields: str = DEFAULT_SEARCH_FIELDS,
        rows: int = 10,
        start: int = 0,
        sort: str = "date desc",
    ) -> dict[str, Any]:
        """Search the ADS database.

        Calls ``GET /v1/search/query``.

        Args:
            query: Solr/ADS query string, e.g. ``"author:Einstein relativity"``.
            fields: Comma-separated list of fields to return.
            rows: Maximum number of results. Capped at 2000 by ADS.
            start: Offset into the result set (for pagination).
            sort: Sort order, e.g. ``"citation_count desc"``.

        Returns:
            Parsed JSON response from ADS containing ``response.docs`` and
            ``response.numFound``.

        Raises:
            ADSError: If ADS returns a non-2xx status.
        """
        params = urlencode(
            {"q": query, "fl": fields, "rows": rows, "start": start, "sort": sort}
        )
        r = await self._http.get(f"/search/query?{params}")
        return self._check(r).json()

    # ------------------------------------------------------------------
    # Export endpoint
    # ------------------------------------------------------------------

    async def export(self, bibcodes: list[str], fmt: str = "bibtex") -> str:
        """Export one or more records in a bibliographic format.

        Calls ``POST /v1/export/{fmt}``.

        Args:
            bibcodes: List of ADS bibcodes to export.
            fmt: Export format. One of ``bibtex``, ``bibtexabs``, ``ris``,
                ``endnote``, ``procite``, ``refworks``, ``aastex``, ``icarus``,
                ``mnras``, ``soph``, ``dcxml``, ``refxml``, ``ads``,
                ``medlars``, ``votable``.

        Returns:
            The exported bibliography as a plain string (empty if ADS returns
            no ``export`` field).

        Raises:
            ValueError: If *fmt* is not a plain ASCII alphanumeric name.
            ADSError: If ADS returns a non-2xx status.
        """
        # *fmt* becomes part of the URL path. Reject anything that is not a
        # simple format name so values like "../metrics" cannot reach other
        # endpoints.
        if not (fmt.isascii() and fmt.isalnum()):
            raise ValueError(f"Invalid ADS export format: {fmt!r}")
        payload = json.dumps({"bibcode": bibcodes})
        r = await self._http.post(
            f"/export/{fmt}",
            content=payload,
            headers={"Content-Type": "application/json"},
        )
        data = self._check(r).json()
        return data.get("export", "")

    # ------------------------------------------------------------------
    # Metrics endpoint
    # ------------------------------------------------------------------

    async def metrics(
        self,
        bibcodes: list[str],
        types: list[str] | None = None,
    ) -> dict[str, Any]:
        """Retrieve citation and usage metrics for a set of papers.

        Calls ``POST /v1/metrics``.

        Args:
            bibcodes: List of ADS bibcodes to retrieve metrics for.
            types: Optional list selecting which metric types to return.
                Allowed values: ``"basic"``, ``"citations"``,
                ``"indicators"``, ``"histograms"``, ``"timeseries"``.
                If *None* (or empty), the ``types`` key is omitted from the
                request body.

        Returns:
            Parsed JSON metrics object from ADS.

        Raises:
            ADSError: If ADS returns a non-2xx status.
        """
        body: dict[str, Any] = {"bibcodes": bibcodes}
        if types:
            body["types"] = types
        r = await self._http.post(
            "/metrics",
            content=json.dumps(body),
            headers={"Content-Type": "application/json"},
        )
        return self._check(r).json()

    # ------------------------------------------------------------------
    # Convenience searches
    # ------------------------------------------------------------------

    async def search_by_bibcode(
        self, bibcodes: list[str], fields: str = DEFAULT_SEARCH_FIELDS, rows: int = 2000
    ) -> dict[str, Any]:
        """Retrieve records for a list of bibcodes via a big-query POST.

        Calls ``POST /v1/search/bigquery``.

        Args:
            bibcodes: List of ADS bibcodes.
            fields: Comma-separated list of fields to return.
            rows: Maximum number of results.

        Returns:
            Parsed JSON response from ADS.

        Raises:
            ADSError: If ADS returns a non-2xx status.
        """
        params = urlencode({"q": "*:*", "fl": fields, "rows": rows})
        # Body is a one-column CSV: a "bibcode" header line, then one per row.
        payload = "bibcode\n" + "\n".join(bibcodes)
        r = await self._http.post(
            f"/search/bigquery?{params}",
            content=payload,
            headers={"Content-Type": "big-query/csv"},
        )
        return self._check(r).json()

    async def get_references(self, bibcode: str, rows: int = 200) -> dict[str, Any]:
        """Fetch the reference list of a paper.

        Uses the ``references()`` operator in the ADS search API.

        Args:
            bibcode: ADS bibcode of the paper whose references are requested.
            rows: Maximum number of references to return.

        Returns:
            Parsed JSON search response from ADS.

        Raises:
            ADSError: If ADS returns a non-2xx status.
        """
        return await self.search(
            f"references({self._bibcode_clause(bibcode)})",
            fields="bibcode,title,author,year,doi,pub",
            rows=rows,
            sort="date desc",
        )

    async def get_citations(self, bibcode: str, rows: int = 200) -> dict[str, Any]:
        """Fetch the papers that cite a given paper.

        Uses the ``citations()`` operator in the ADS search API.

        Args:
            bibcode: ADS bibcode of the paper whose citations are requested.
            rows: Maximum number of citing papers to return.

        Returns:
            Parsed JSON search response from ADS.

        Raises:
            ADSError: If ADS returns a non-2xx status.
        """
        return await self.search(
            f"citations({self._bibcode_clause(bibcode)})",
            fields="bibcode,title,author,year,doi,citation_count,pub",
            rows=rows,
            sort="date desc",
        )