From d4099c13a5b1b635fe9bf1fc54929281e22ed6fb Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 20 Apr 2026 17:13:18 -0700 Subject: [PATCH 01/15] pd retrieval with delta backed --- dev/generate_mcp_tools.py | 3 +- mp_api/client/core/utils.py | 6 +-- mp_api/client/mprester.py | 2 +- .../routes/materials/electronic_structure.py | 2 +- mp_api/client/routes/materials/thermo.py | 39 ++++++++++++------- mp_api/mcp/tools.py | 3 +- pyproject.toml | 4 +- 7 files changed, 35 insertions(+), 24 deletions(-) diff --git a/dev/generate_mcp_tools.py b/dev/generate_mcp_tools.py index d4037715..57ca9511 100644 --- a/dev/generate_mcp_tools.py +++ b/dev/generate_mcp_tools.py @@ -53,13 +53,14 @@ def regenerate_tools( from datetime import datetime from typing import Literal +from emmet.core.band_theory import BSPathType from emmet.core.chemenv import ( COORDINATION_GEOMETRIES, COORDINATION_GEOMETRIES_IUCR, COORDINATION_GEOMETRIES_IUPAC, COORDINATION_GEOMETRIES_NAMES, ) -from emmet.core.electronic_structure import BSPathType, DOSProjectionType +from emmet.core.electronic_structure import DOSProjectionType from emmet.core.grain_boundary import GBTypeEnum from emmet.core.mpid import MPID from emmet.core.thermo import ThermoType diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 5c15f8b4..95793591 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -116,10 +116,8 @@ def validate_ids(id_list: list[str]) -> list[str]: " data for all IDs and filter locally." 
) - # TODO: after the transition to AlphaID in the document models, - # The following line should be changed to - # return [validate_identifier(idx,serialize=True) for idx in id_list] - return [str(validate_identifier(idx)) for idx in id_list] + [validate_identifier(idx, serialize=False) for idx in id_list] + return [getattr(idx, "string", str(idx)) for idx in id_list] def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str: diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index c2dfefde..2cf29e29 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -7,7 +7,7 @@ from functools import cache, lru_cache from typing import TYPE_CHECKING -from emmet.core.electronic_structure import BSPathType +from emmet.core.band_theory import BSPathType from emmet.core.mpid import MPID, AlphaID from emmet.core.types.enums import ThermoType from emmet.core.vasp.calc_types import CalcType diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 28b02ed2..b0bee09e 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -4,8 +4,8 @@ from collections import defaultdict from typing import TYPE_CHECKING +from emmet.core.band_theory import BSPathType from emmet.core.electronic_structure import ( - BSPathType, DOSProjectionType, ElectronicStructureDoc, ) diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 0ed19818..7d5182ac 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -3,13 +3,17 @@ from collections import defaultdict import numpy as np +import pyarrow as pa +from deltalake import DeltaTable, QueryBuilder from emmet.core.thermo import ThermoDoc from emmet.core.types.enums import ThermoType +from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType +from 
pydantic import TypeAdapter from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element from mp_api.client.core import BaseRester -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import validate_ids class ThermoRester(BaseRester): @@ -163,21 +167,28 @@ def get_phase_diagram_from_chemsys( ) sorted_chemsys = "-".join(sorted(chemsys.split("-"))) - phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" version = self.db_version.replace(".", "-") - obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd_dct = self._query_open_data( # type: ignore[union-attr] - bucket="materialsproject-build", - key=obj_key, - decoder=lambda x: load_json(x, deser=False), - )[0][0].get("phase_diagram") - - pd = PhaseDiagram.from_dict( - { - k: v if k != "elements" else [e.get("element", e) for e in v] - for k, v in pd_dct.items() # type: ignore[union-attr] - } + + pd_tbl = DeltaTable( + "s3://materialsproject-build/objects/phase-diagrams/", + storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + ) + qb = QueryBuilder().register("phase_diagrams", pd_tbl) + table = pa.table( + qb.execute( + f"""SELECT phase_diagram + FROM phase_diagrams + WHERE chemsys='{sorted_chemsys}' + AND version='{version}' + AND thermo_type='{thermo_type}' + """ + ) ) + as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict") + + pd: PhaseDiagram | None = None + if len(pds := TypeAdapter(list[PhaseDiagramType]).validate_python(as_py)) > 0: + pd = pds[0] # Ensure el_ref keys are Element objects for PDPlotter. 
# Ensure qhull_data is a numpy array diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index 21a0602f..6fcfe8e2 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -12,7 +12,8 @@ COORDINATION_GEOMETRIES_IUPAC, COORDINATION_GEOMETRIES_NAMES, ) -from emmet.core.electronic_structure import BSPathType, DOSProjectionType +from emmet.core.band_theory import BSPathType +from emmet.core.electronic_structure import DOSProjectionType from emmet.core.grain_boundary import GBTypeEnum from emmet.core.mpid import MPID from emmet.core.summary import HasProps diff --git a/pyproject.toml b/pyproject.toml index 4541d657..4efea2c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2024.12.10", - "emmet-core>=0.86.4rc1,<0.86.5", + "emmet-core>=0.87.0.dev,<0.87.2", "boto3", "orjson >= 3.10,<4", "pyarrow >= 20.0.0", @@ -37,7 +37,7 @@ mcp = ["fastmcp"] server = ["flask"] all = [ "custodian", - "emmet-core[all]>=0.86.4rc1,<0.86.5", + "emmet-core[all]>=0.87.0.dev,<0.87.2", "fastmcp", "flask", "mpcontribs-client>=5.10", From 26a4a5b68a7be1b391801a609614ad7277e05239 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 20 Apr 2026 17:35:18 -0700 Subject: [PATCH 02/15] machinery for persistent query builder --- mp_api/client/core/client.py | 6 +++++- mp_api/client/mprester.py | 5 +++++ mp_api/client/routes/materials/thermo.py | 4 ++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 75f8fd0d..cbc4159c 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -129,6 +129,7 @@ def __init__( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilder | None = None, **kwargs, ): """Initialize the REST API helper class. @@ -163,6 +164,7 @@ def __init__( local_dataset_cache: Target directory for downloading full datasets. 
Defaults to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables **kwargs: access to legacy kwargs that may be in the process of being deprecated """ self.api_key = validate_api_key(api_key) @@ -185,6 +187,7 @@ def __init__( self.force_renew = force_renew self._session = session + self._query_builder = query_builder or QueryBuilder() self._s3_client = s3_client if "monty_decode" in kwargs: @@ -545,7 +548,7 @@ def _query_delta_backed( else "" ) - builder = QueryBuilder().register("tbl", tbl) + builder = self._query_builder.register("tbl", tbl) # Setup progress bar num_docs_needed: int = tbl.count() @@ -1619,6 +1622,7 @@ def __getattr__(self, v: str): mute_progress_bars=self.mute_progress_bars, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ) return self.sub_resters[v] raise AttributeError(f"{self.__class__} has no attribute {v}") diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 2cf29e29..5525da69 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -7,6 +7,7 @@ from functools import cache, lru_cache from typing import TYPE_CHECKING +from deltalake import QueryBuilder from emmet.core.band_theory import BSPathType from emmet.core.mpid import MPID, AlphaID from emmet.core.types.enums import ThermoType @@ -102,6 +103,7 @@ def __init__( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilder | None = None, **kwargs, ): """Initialize the MPRester. @@ -139,6 +141,7 @@ def __init__( local_dataset_cache: Target directory for downloading full datasets. 
Defaults to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables **kwargs: access to legacy kwargs that may be in the process of being deprecated """ self.api_key = get_user_api_key(api_key=api_key) @@ -220,6 +223,7 @@ def __init__( # Instantiate top level core molecules, materials, and DOI resters, as well # as the sunder resters to allow the web server to work. + self._query_builder = query_builder or QueryBuilder() for rest_name, lazy_rester in (RESTER_LAYOUT | GENERIC_RESTERS).items(): if rest_name in TOP_LEVEL_RESTERS: setattr( @@ -235,6 +239,7 @@ def __init__( mute_progress_bars=self.mute_progress_bars, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ), ) diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 7d5182ac..65f05489 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -4,7 +4,7 @@ import numpy as np import pyarrow as pa -from deltalake import DeltaTable, QueryBuilder +from deltalake import DeltaTable from emmet.core.thermo import ThermoDoc from emmet.core.types.enums import ThermoType from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType @@ -173,7 +173,7 @@ def get_phase_diagram_from_chemsys( "s3://materialsproject-build/objects/phase-diagrams/", storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, ) - qb = QueryBuilder().register("phase_diagrams", pd_tbl) + qb = self._query_builder.register("phase_diagrams", pd_tbl) table = pa.table( qb.execute( f"""SELECT phase_diagram From 8f735b0c1a63b1553541a7188c49b417268a7ee3 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 21 Apr 2026 09:09:16 -0700 Subject: [PATCH 03/15] remove unused debug kwarg --- mp_api/client/core/client.py | 23
++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index cbc4159c..8d8cd3db 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -120,7 +120,6 @@ def __init__( include_user_agent: bool = True, session: requests.Session | None = None, s3_client: Any | None = None, - debug: bool = False, use_document_model: bool = True, timeout: int = 20, headers: dict | None = None, @@ -154,7 +153,6 @@ def __init__( session: requests Session object with which to connect to the API, for advanced usage only. s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. - debug: if True, print the URL for every request use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. 
@@ -171,7 +169,6 @@ def __init__( self.base_endpoint = validate_endpoint(endpoint) self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) - self.debug = debug self.include_user_agent = include_user_agent self.use_document_model = use_document_model self.timeout = timeout @@ -187,7 +184,7 @@ def __init__( self.force_renew = force_renew self._session = session - self._query_builder = query_builder or QueryBuilder() + self._query_builder = query_builder self._s3_client = s3_client if "monty_decode" in kwargs: @@ -215,6 +212,12 @@ def s3_client(self): ) return self._s3_client + @property + def query_builder(self): + if not self._query_builder: + self._query_builder = QueryBuilder() + return self._query_builder + @staticmethod def _create_session(api_key, include_user_agent, headers): session = requests.Session() @@ -462,6 +465,16 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore + @staticmethod + def _get_delta_table(bucket : str, prefix : str) -> DeltaTable: + return DeltaTable( + f"s3a://{bucket}/{prefix}", + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + }, + ) + def _query_delta_backed( self, bucket: str, @@ -548,7 +561,7 @@ def _query_delta_backed( else "" ) - builder = self._query_builder.register("tbl", tbl) + builder = self.query_builder.register("tbl", tbl) # Setup progress bar num_docs_needed: int = tbl.count() From 36eb8ce019e269cc039a60f23114e21efd3c58a7 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 21 Apr 2026 09:47:51 -0700 Subject: [PATCH 04/15] align base and mp resters to inherit from same parent; cache delta table --- mp_api/client/core/client.py | 202 ++++++++++++++++------- mp_api/client/mprester.py | 69 +++----- mp_api/client/routes/materials/thermo.py | 5 +- 3 files changed, 169 insertions(+), 107 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 8d8cd3db..f2fa740a 100644 --- a/mp_api/client/core/client.py +++ 
b/mp_api/client/core/client.py @@ -43,6 +43,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry +from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -104,24 +105,16 @@ def get(self, item: str, default: Any = None) -> Any: except AttributeError: return default - -class BaseRester: - """Base client class with core stubs.""" - - suffix: str = "" - document_model: type[BaseModel] = _DictLikeAccess - primary_key: str = "material_id" - delta_backed: bool = False - +class _Rester: + """Define base attributes of a REST interface.""" + def __init__( self, api_key: str | None = None, endpoint: str | None = None, include_user_agent: bool = True, - session: requests.Session | None = None, - s3_client: Any | None = None, use_document_model: bool = True, - timeout: int = 20, + session: requests.Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, local_dataset_cache: ( @@ -130,8 +123,8 @@ def __init__( force_renew: bool = False, query_builder: QueryBuilder | None = None, **kwargs, - ): - """Initialize the REST API helper class. + ) -> None: + """Initialize a RESTer. Arguments: api_key: A String API key for accessing the MaterialsProject @@ -150,13 +143,11 @@ def __init__( making the API request. This helps MP support pymatgen users, and is similar to what most web browsers send with each page request. Set to False to disable the user agent. - session: requests Session object with which to connect to the API, for - advanced usage only. - s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. 
- timeout: Time in seconds to wait until a request timeout error is thrown + session: requests Session object with which to connect to the API, for + advanced usage only. headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. local_dataset_cache: Target directory for downloading full datasets. Defaults @@ -165,34 +156,36 @@ def __init__( query_builder : Instance of deltalake QueryBuilder to use in querying delta tables **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = validate_api_key(api_key) - self.base_endpoint = validate_endpoint(endpoint) - self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + self.api_key = get_user_api_key(api_key=api_key) + self.endpoint = validate_endpoint(endpoint) self.include_user_agent = include_user_agent self.use_document_model = use_document_model - self.timeout = timeout - self.headers = headers or {} - self.mute_progress_bars = mute_progress_bars - ( - self.db_version, - self.access_controlled_batch_ids, - ) = BaseRester._get_heartbeat_info(self.base_endpoint) + self.headers = headers or get_consumer() + self._session = session or _Rester._create_session( + api_key=self.api_key, + include_user_agent=self.include_user_agent, + headers=self.headers, + ) - self.local_dataset_cache: Path = Path(local_dataset_cache) - self.force_renew = force_renew + if is_dev_env(): + self._session.headers["x-api-key"] = self.api_key or "" - self._session = session + self.use_document_model = use_document_model + self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self._query_builder = query_builder - self._s3_client = s3_client if "monty_decode" in kwargs: + # Pop to not repeatedly trigger warning to the user + kwargs.pop("monty_decode",None) warnings.warn( "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." 
"The client by default returns results consistent with `monty_decode=True`.", - category=MPRestWarning, stacklevel=2, + category=MPRestWarning, ) @property @@ -203,15 +196,6 @@ def session(self) -> requests.Session: ) return self._session - @property - def s3_client(self): - if not self._s3_client: - self._s3_client = boto3.client( - "s3", - config=Config(signature_version=UNSIGNED), # type: ignore - ) - return self._s3_client - @property def query_builder(self): if not self._query_builder: @@ -256,6 +240,107 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover self.session.close() self._session = None + +class BaseRester(_Rester): + """Base client class with core stubs.""" + + suffix: str = "" + document_model: type[BaseModel] = _DictLikeAccess + primary_key: str = "material_id" + delta_backed: bool = False + + def __init__( + self, + api_key: str | None = None, + endpoint: str | None = None, + include_user_agent: bool = True, + use_document_model: bool = True, + session: requests.Session | None = None, + headers: dict | None = None, + mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: ( + str | os.PathLike + ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, + query_builder: QueryBuilder | None = None, + s3_client: Any | None = None, + timeout: int = 20, + **kwargs, + ): + """Initialize the REST API helper class. + + s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. + timeout: Time in seconds to wait until a request timeout error is thrown + + Arguments: + api_key: A String API key for accessing the MaterialsProject + REST interface. Please obtain your API key at + https://www.materialsproject.org/dashboard. If this is None, + the code will check if there is a "PMG_MAPI_KEY" setting. + If so, it will use that environment variable. 
This makes + easier for heavy users to simply add this environment variable to + their setups and MPRester can then be called without any arguments. + endpoint: Url of endpoint to access the MaterialsProject REST + interface. Defaults to the standard Materials Project REST + address at "https://api.materialsproject.org", but + can be changed to other urls implementing a similar interface. + include_user_agent: If True, will include a user agent with the + HTTP request including information on pymatgen and system version + making the API request. This helps MP support pymatgen users, and + is similar to what most web browsers send with each page request. + Set to False to disable the user agent. + session: requests Session object with which to connect to the API, for + advanced usage only. + use_document_model: If False, skip the creating the document model and return data + as a dictionary. This can be simpler to work with but bypasses data validation + and will not give auto-complete for available fields. + headers: Custom headers for localhost connections. + mute_progress_bars: Whether to disable progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to 'mp_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. 
+ timeout: Time in seconds to wait until a request timeout error is thrown + **kwargs: access to legacy kwargs that may be in the process of being deprecated + """ + + super().__init__( + api_key = api_key, + endpoint = endpoint, + include_user_agent = include_user_agent, + use_document_model = use_document_model, + session = session, + headers = headers, + mute_progress_bars = mute_progress_bars, + local_dataset_cache = local_dataset_cache, + force_renew = force_renew, + query_builder = query_builder, + ) + + self.base_endpoint = validate_endpoint(endpoint) + self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + + ( + self.db_version, + self.access_controlled_batch_ids, + ) = BaseRester._get_heartbeat_info(self.base_endpoint) + + self.timeout = timeout + self._s3_client = s3_client + + self._delta_tables : dict[str,DeltaTable] = {} + + @property + def s3_client(self): + if not self._s3_client: + self._s3_client = boto3.client( + "s3", + config=Config(signature_version=UNSIGNED), # type: ignore + ) + return self._s3_client + + @staticmethod @cache def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]: @@ -465,15 +550,22 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore - @staticmethod - def _get_delta_table(bucket : str, prefix : str) -> DeltaTable: - return DeltaTable( - f"s3a://{bucket}/{prefix}", - storage_options={ - "AWS_SKIP_SIGNATURE": "true", - "AWS_REGION": "us-east-1", - }, - ) + def _get_delta_table(self, bucket : str, prefix : str, connector : str = "s3a") -> DeltaTable: + """Either create a new DeltaTable, or retrieve a cached one. + + Args: + bucket (str) : name of the bucket in S3 + prefix (str) : name of the prefix in S3 + connector (str) : s3, s3n, s3a (default), or other + valid Hadoop connector string. + + Returns: + DeltaTable : If one exists at the specified bucket / prefix, + will retrieve the cached instance. 
+ """ + if (uri := f"{connector}://{bucket}/{prefix}") not in self._delta_tables: + self._delta_tables[uri] = DeltaTable(uri,storage_options={"AWS_SKIP_SIGNATURE": "true","AWS_REGION": "us-east-1"}) + return self._delta_tables[uri] def _query_delta_backed( self, @@ -543,13 +635,7 @@ def _query_delta_backed( ) } - tbl = DeltaTable( - f"s3a://{bucket}/{prefix}", - storage_options={ - "AWS_SKIP_SIGNATURE": "true", - "AWS_REGION": "us-east-1", - }, - ) + tbl = self._get_delta_table(bucket, prefix) controlled_batch_str = ",".join( [f"'{tag}'" for tag in self.access_controlled_batch_ids] diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 5525da69..00d6f4ab 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -22,8 +22,7 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env -from mp_api.client.core import BaseRester +from mp_api.client.core.client import _Rester from mp_api.client.core._oxygen_evolution import OxygenEvolution from mp_api.client.core.exceptions import ( MPRestError, @@ -86,17 +85,16 @@ ] -class MPRester: +class MPRester(_Rester): """Access the new Materials Project API.""" def __init__( self, api_key: str | None = None, endpoint: str | None = None, - notify_db_version: bool = False, include_user_agent: bool = True, use_document_model: bool = True, - session: Session | None = None, + session: requests.Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, local_dataset_cache: ( @@ -104,6 +102,7 @@ def __init__( ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, query_builder: QueryBuilder | None = None, + notify_db_version: bool = False, **kwargs, ): """Initialize the MPRester. @@ -120,13 +119,6 @@ def __init__( interface. 
Defaults to the standard Materials Project REST address at "https://api.materialsproject.org", but can be changed to other URLs implementing a similar interface. - notify_db_version (bool): If True, the current MP database version will - be retrieved and logged locally in the ~/.mprester.log.yaml. If the database - version changes, you will be notified. The current database version is - also printed on instantiation. These local logs are not sent to - materialsproject.org and are not associated with your API key, so be - aware that a notification may not be presented if you run MPRester - from multiple computing environments. include_user_agent (bool): If True, will include a user agent with the HTTP request including information on pymatgen and system version making the API request. This helps MP support pymatgen users, and @@ -142,25 +134,29 @@ def __init__( to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + notify_db_version (bool): If True, the current MP database version will + be retrieved and logged locally in the ~/.mprester.log.yaml. If the database + version changes, you will be notified. The current database version is + also printed on instantiation. These local logs are not sent to + materialsproject.org and are not associated with your API key, so be + aware that a notification may not be presented if you run MPRester + from multiple computing environments. 
**kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = get_user_api_key(api_key=api_key) - self.endpoint = validate_endpoint(endpoint) - - self.headers = headers or get_consumer() - self.session = session or BaseRester._create_session( - api_key=self.api_key, - include_user_agent=include_user_agent, - headers=self.headers, + super().__init__( + api_key = api_key, + endpoint = endpoint, + include_user_agent = include_user_agent, + use_document_model = use_document_model, + session = session, + headers = headers, + mute_progress_bars = mute_progress_bars, + local_dataset_cache = local_dataset_cache, + force_renew = force_renew, + query_builder = query_builder, ) - if is_dev_env(): - self.session.headers["x-api-key"] = self.api_key or "" - self._include_user_agent = include_user_agent - self.use_document_model = use_document_model - self.mute_progress_bars = mute_progress_bars - self.local_dataset_cache = local_dataset_cache - self.force_renew = force_renew + self._contribs = None self._deprecated_attributes = [ @@ -193,14 +189,6 @@ def __init__( "chemenv", ] - if "monty_decode" in kwargs: - warnings.warn( - "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." - "The client by default returns results consistent with `monty_decode=True`.", - stacklevel=2, - category=MPRestWarning, - ) - # Check if emmet version of server is compatible if (emmet_version := MPRester.get_emmet_version(self.endpoint)) and ( version.parse(emmet_version.base_version) @@ -223,7 +211,6 @@ def __init__( # Instantiate top level core molecules, materials, and DOI resters, as well # as the sunder resters to allow the web server to work. 
- self._query_builder = query_builder or QueryBuilder() for rest_name, lazy_rester in (RESTER_LAYOUT | GENERIC_RESTERS).items(): if rest_name in TOP_LEVEL_RESTERS: setattr( @@ -232,7 +219,7 @@ def __init__( lazy_rester( api_key=self.api_key, endpoint=self.endpoint, - include_user_agent=self._include_user_agent, + include_user_agent=self.include_user_agent, session=self.session, use_document_model=self.use_document_model, headers=self.headers, @@ -274,14 +261,6 @@ def contribs(self): return self._contribs - def __enter__(self): - """Support for "with" context.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Support for "with" context.""" - self.session.close() - def __getattr__(self, attr): if attr in self._deprecated_attributes: warnings.warn( diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 65f05489..a679f467 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -169,10 +169,7 @@ def get_phase_diagram_from_chemsys( sorted_chemsys = "-".join(sorted(chemsys.split("-"))) version = self.db_version.replace(".", "-") - pd_tbl = DeltaTable( - "s3://materialsproject-build/objects/phase-diagrams/", - storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, - ) + pd_tbl = self._get_delta_table("materialsproject-build","objects/phase-diagrams") qb = self._query_builder.register("phase_diagrams", pd_tbl) table = pa.table( qb.execute( From 6fd4671100ebbc0d871673b6c683853ebbf43545 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 21 Apr 2026 09:58:48 -0700 Subject: [PATCH 05/15] precommit --- mp_api/client/core/client.py | 51 +++++++++++++----------- mp_api/client/mprester.py | 26 ++++++------ mp_api/client/routes/materials/thermo.py | 5 ++- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index f2fa740a..7c20f2e7 100644 --- 
a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -23,7 +23,6 @@ from itertools import chain, islice from json import JSONDecodeError from math import ceil -from pathlib import Path from typing import TYPE_CHECKING, ForwardRef, Optional, get_args from urllib.parse import urljoin @@ -53,7 +52,6 @@ from mp_api.client.core.utils import ( MPDataset, load_json, - validate_api_key, validate_endpoint, validate_ids, ) @@ -105,9 +103,10 @@ def get(self, item: str, default: Any = None) -> Any: except AttributeError: return default + class _Rester: - """Define base attributes of a REST interface.""" - + """Define base attributes of a REST client.""" + def __init__( self, api_key: str | None = None, @@ -180,7 +179,7 @@ def __init__( if "monty_decode" in kwargs: # Pop to not repeatedly trigger warning to the user - kwargs.pop("monty_decode",None) + kwargs.pop("monty_decode", None) warnings.warn( "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." "The client by default returns results consistent with `monty_decode=True`.", @@ -304,18 +303,17 @@ def __init__( timeout: Time in seconds to wait until a request timeout error is thrown **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - super().__init__( - api_key = api_key, - endpoint = endpoint, - include_user_agent = include_user_agent, - use_document_model = use_document_model, - session = session, - headers = headers, - mute_progress_bars = mute_progress_bars, - local_dataset_cache = local_dataset_cache, - force_renew = force_renew, - query_builder = query_builder, + api_key=api_key, + endpoint=endpoint, + include_user_agent=include_user_agent, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, ) self.base_endpoint = validate_endpoint(endpoint) @@ -329,7 +327,7 @@ def __init__( 
self.timeout = timeout self._s3_client = s3_client - self._delta_tables : dict[str,DeltaTable] = {} + self._delta_tables: dict[str, DeltaTable] = {} @property def s3_client(self): @@ -340,7 +338,6 @@ def s3_client(self): ) return self._s3_client - @staticmethod @cache def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]: @@ -550,21 +547,29 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore - def _get_delta_table(self, bucket : str, prefix : str, connector : str = "s3a") -> DeltaTable: + def _get_delta_table( + self, bucket: str, prefix: str, connector: str = "s3a" + ) -> DeltaTable: """Either create a new DeltaTable, or retrieve a cached one. Args: bucket (str) : name of the bucket in S3 prefix (str) : name of the prefix in S3 - connector (str) : s3, s3n, s3a (default), or other + connector (str) : s3, s3n, s3a (default), or other valid Hadoop connector string. - + Returns: DeltaTable : If one exists at the specified bucket / prefix, will retrieve the cached instance. 
""" - if (uri := f"{connector}://{bucket}/{prefix}") not in self._delta_tables: - self._delta_tables[uri] = DeltaTable(uri,storage_options={"AWS_SKIP_SIGNATURE": "true","AWS_REGION": "us-east-1"}) + if (uri := f"{connector}://{bucket}/{prefix}") not in self._delta_tables: + self._delta_tables[uri] = DeltaTable( + uri, + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + }, + ) return self._delta_tables[uri] def _query_delta_backed( diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 00d6f4ab..0a962528 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -22,8 +22,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client.core.client import _Rester from mp_api.client.core._oxygen_evolution import OxygenEvolution +from mp_api.client.core.client import _Rester from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -33,7 +33,6 @@ from mp_api.client.core.utils import ( LazyImport, load_json, - validate_endpoint, validate_ids, ) from mp_api.client.routes import GENERIC_RESTERS @@ -94,7 +93,7 @@ def __init__( endpoint: str | None = None, include_user_agent: bool = True, use_document_model: bool = True, - session: requests.Session | None = None, + session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, local_dataset_cache: ( @@ -143,18 +142,17 @@ def __init__( from multiple computing environments. 
**kwargs: access to legacy kwargs that may be in the process of being deprecated """ - super().__init__( - api_key = api_key, - endpoint = endpoint, - include_user_agent = include_user_agent, - use_document_model = use_document_model, - session = session, - headers = headers, - mute_progress_bars = mute_progress_bars, - local_dataset_cache = local_dataset_cache, - force_renew = force_renew, - query_builder = query_builder, + api_key=api_key, + endpoint=endpoint, + include_user_agent=include_user_agent, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, ) self._contribs = None diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index a679f467..b74742b7 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -4,7 +4,6 @@ import numpy as np import pyarrow as pa -from deltalake import DeltaTable from emmet.core.thermo import ThermoDoc from emmet.core.types.enums import ThermoType from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType @@ -169,7 +168,9 @@ def get_phase_diagram_from_chemsys( sorted_chemsys = "-".join(sorted(chemsys.split("-"))) version = self.db_version.replace(".", "-") - pd_tbl = self._get_delta_table("materialsproject-build","objects/phase-diagrams") + pd_tbl = self._get_delta_table( + "materialsproject-build", "objects/phase-diagrams" + ) qb = self._query_builder.register("phase_diagrams", pd_tbl) table = pa.table( qb.execute( From b3874ea228f9737b0dbd295ccec5feda18a7dcab Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 21 Apr 2026 13:40:25 -0700 Subject: [PATCH 06/15] query delta tables for electronic structure objects --- mp_api/client/core/client.py | 31 ++++-- mp_api/client/mprester.py | 29 ++++- .../routes/materials/electronic_structure.py | 103 
+++++++++++++----- mp_api/client/routes/materials/thermo.py | 11 +- 4 files changed, 127 insertions(+), 47 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 7c20f2e7..9f24dfae 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -268,7 +268,7 @@ def __init__( ): """Initialize the REST API helper class. - s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. + s3_client: boto3 S3 client object with which to connect to the object stores. timeout: Time in seconds to wait until a request timeout error is thrown Arguments: @@ -548,21 +548,32 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore def _get_delta_table( - self, bucket: str, prefix: str, connector: str = "s3a" - ) -> DeltaTable: + self, + bucket: str, + prefix: str, + connector: str = "s3a", + label: str | None = None, + ) -> tuple[str, DeltaTable]: """Either create a new DeltaTable, or retrieve a cached one. + If creating a new DeltaTable, will also register in self.query_builder + Args: bucket (str) : name of the bucket in S3 prefix (str) : name of the prefix in S3 connector (str) : s3, s3n, s3a (default), or other valid Hadoop connector string. + label (str or None) : optional label for the table in QueryBuilder + If `None`, will be gleaned from the URI Returns: + str : the table name in QueryBuilder DeltaTable : If one exists at the specified bucket / prefix, will retrieve the cached instance. 
""" - if (uri := f"{connector}://{bucket}/{prefix}") not in self._delta_tables: + full_key = f"{bucket}/{prefix}" + qb_label = label or full_key.replace("/", "_").replace("-", "_") + if (uri := f"{connector}://{full_key}") not in self._delta_tables: self._delta_tables[uri] = DeltaTable( uri, storage_options={ @@ -570,13 +581,16 @@ def _get_delta_table( "AWS_REGION": "us-east-1", }, ) - return self._delta_tables[uri] + self.query_builder.register(qb_label, self._delta_tables[uri]) + + return qb_label, self._delta_tables[uri] def _query_delta_backed( self, bucket: str, prefix: str, timeout: int | None = None, + label: str | None = None, ) -> dict[str, Any]: """Retrieve data from S3 backed by a DeltaTable. @@ -584,6 +598,7 @@ def _query_delta_backed( bucket (str) : S3 OpenData bucket prefix (str) : S3 object prefix timeout (int or None) : timeout on getting access-controlled groups + label (str or None) : label of the table in QueryBuilder Returns: dict of str to Any @@ -640,7 +655,7 @@ def _query_delta_backed( ) } - tbl = self._get_delta_table(bucket, prefix) + tbl_lbl, tbl = self._get_delta_table(bucket, prefix, label=label) controlled_batch_str = ",".join( [f"'{tag}'" for tag in self.access_controlled_batch_ids] @@ -652,8 +667,6 @@ def _query_delta_backed( else "" ) - builder = self.query_builder.register("tbl", tbl) - # Setup progress bar num_docs_needed: int = tbl.count() @@ -675,7 +688,7 @@ def _query_delta_backed( else None ) - iterator = builder.execute(f"SELECT * FROM tbl {predicate}") + iterator = self.query_builder.execute(f"SELECT * FROM {tbl_lbl} {predicate}") file_options = ds.ParquetFileFormat().make_write_options(compression="zstd") diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 0a962528..7423b39e 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -48,6 +48,7 @@ from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.pourbaix_diagram import 
PourbaixEntry + from pymatgen.electronic_structure.dos import Dos from pymatgen.entries.compatibility import Compatibility from pymatgen.entries.computed_entries import ( ComputedEntry, @@ -1121,18 +1122,34 @@ def get_bandstructure_by_material_id( material_id=material_id, path_type=path_type, line_mode=line_mode ) - def get_dos_by_material_id(self, material_id: str): - """Get the complete density of states pymatgen object associated with a Materials Project ID. + def get_dos_by_material_id(self, material_id: str) -> Dos: + """Get the density of states pymatgen object associated with a Materials Project ID. Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ - return self.materials.electronic_structure_dos.get_dos_from_material_id( - material_id=material_id - ) # type: ignore + if ( + not ( + es_doc := self.materials.electronic_structure.search( + material_ids=material_id, fields=["dos"] + ) + ) + or not es_doc[0]["dos"] + ): + raise MPRestError(f"No DOS found for {material_id}") + + dos_data = es_doc[0]["dos"] + task_id = dos_data.task_id if self.use_document_model else dos_data["task_id"] + run_type = self.materials.tasks.search(task_ids=[task_id], fields=["run_type"])[ + 0 + ]["run_type"] + return self.materials.electronic_structure_dos.get_dos_from_task_id( + task_id, + run_type=run_type, + ) def get_phonon_dos_by_material_id(self, material_id: str): """Get phonon density of states data corresponding to a material_id. 
diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index b0bee09e..e94e35fe 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -4,20 +4,27 @@ from collections import defaultdict from typing import TYPE_CHECKING -from emmet.core.band_theory import BSPathType +import pyarrow as pa +from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos from emmet.core.electronic_structure import ( DOSProjectionType, ElectronicStructureDoc, ) +from emmet.core.mpid import AlphaID +from emmet.core.vasp.calc_types.enums import RunType from pymatgen.analysis.magnetism.analyzer import Ordering from pymatgen.core.periodic_table import Element +from pymatgen.electronic_structure.bandstructure import ( + BandStructure, + BandStructureSymmLine, +) from pymatgen.electronic_structure.core import OrbitalType, Spin from mp_api.client.core import BaseRester, MPRestError -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: - from pymatgen.electronic_structure.dos import CompleteDos + from pymatgen.electronic_structure.dos import Dos class ElectronicStructureRester(BaseRester): @@ -255,20 +262,47 @@ def search( **query_params, ) - def get_bandstructure_from_task_id(self, task_id: str): + def get_bandstructure_from_task_id( + self, + task_id: str, + run_type: str | RunType | None = None, + path_type: str | BSPathType | None = None, + ) -> BandStructure: """Get the band structure pymatgen object associated with a given task ID. Arguments: task_id (str): Task ID for the band structure calculation - + run_type (str, RunType, or None): Optional run type, + will speed up query due to delta table partitioning. 
+ path_type (str, BSPathType, or None) : Optional path type to + speed up query Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] + bs_lbl, bs_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/bandstructures/", + label="bandstructure", + ) + + selection_string = f"""SELECT * +FROM {bs_lbl} +WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + selection_string += f"\nAND run_type='{rt.value}'" + if path_type: + selection_string += f"\nAND path_convention='{path_type}'" + table = pa.table(self.query_builder.execute(selection_string)) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + emmet_bs = ElectronicBS(**deser[0]) + return emmet_bs.to_pmg( + pmg_cls=BandStructureSymmLine if emmet_bs.labels_dict else BandStructure + ) + raise MPRestError( + f"No bandstructure data found for {task_id=}" + + (f"run_type={rt}" if run_type else "") + ) def get_bandstructure_from_material_id( self, @@ -329,7 +363,10 @@ def get_bandstructure_from_material_id( ) bs_task_id = bs_data["total"]["1"]["task_id"] - bs_obj = self.get_bandstructure_from_task_id(bs_task_id) + bs_obj = self.get_bandstructure_from_task_id( + bs_task_id, + path_type=path_type if line_mode else BSPathType.unknown, + ) if bs_obj: return bs_obj @@ -451,29 +488,46 @@ def search( **query_params, ) - def get_dos_from_task_id(self, task_id: str) -> CompleteDos: + def get_dos_from_task_id( + self, task_id: str, run_type: str | RunType | None = None + ) -> Dos: """Get the density of states pymatgen object associated with a given calculation ID. 
Arguments: task_id (str): Task ID for the density of states calculation + run_type (str, RunType, or None): Optional run type to query by. + Will speed up query due to delta table partitioning. Returns: - bandstructure (CompleteDos): CompleteDos object + pymatgen Dos """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"dos/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] + dos_lbl, dos_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/total-dos/", + label="total_dos", + ) - def get_dos_from_material_id(self, material_id: str): + selection_string = f"""SELECT * +FROM {dos_lbl} +WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + selection_string += f"\nAND run_type='{rt.value}'" + table = pa.table(self.query_builder.execute(selection_string)) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + return ElectronicDos(**deser[0]).to_pmg() + raise MPRestError( + f"No DOS data found for {task_id=}" + (f"run_type={rt}" if run_type else "") + ) + + def get_dos_from_material_id(self, material_id: str) -> Dos: """Get the complete density of states pymatgen object associated with a Materials Project ID. 
Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ if not ( dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"]) @@ -484,9 +538,6 @@ def get_dos_from_material_id(self, material_id: str): raise MPRestError(f"No density of states data found for {material_id}") dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[ - "total" - ]["1"]["task_id"] - if dos_obj := self.get_dos_from_task_id(dos_task_id): - return dos_obj - - raise MPRestError("No density of states object found.") + "task_id" + ] + return self.get_dos_from_task_id(dos_task_id) diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index b74742b7..d1ffd28a 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -166,16 +166,15 @@ def get_phase_diagram_from_chemsys( ) sorted_chemsys = "-".join(sorted(chemsys.split("-"))) - version = self.db_version.replace(".", "-") + version = "2026-04-13" # self.db_version.replace(".", "-") - pd_tbl = self._get_delta_table( - "materialsproject-build", "objects/phase-diagrams" + pd_lbl, pd_tbl = self._get_delta_table( + "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" ) - qb = self._query_builder.register("phase_diagrams", pd_tbl) table = pa.table( - qb.execute( + self.query_builder.execute( f"""SELECT phase_diagram - FROM phase_diagrams + FROM {pd_lbl} WHERE chemsys='{sorted_chemsys}' AND version='{version}' AND thermo_type='{thermo_type}' From 900e69225c7767a6a47d2f0562b438e4318c5c84 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 21 Apr 2026 14:05:03 -0700 Subject: [PATCH 07/15] update static collection-backed endpoints --- mp_api/client/core/client.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 9f24dfae..d2e92e88 100644 --- 
a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -69,6 +69,17 @@ except PackageNotFoundError: # pragma: no cover __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") +STATIC_COLLECTIONS = [ + "eos", + "grain_boundaries", + "jcesr", + "molecules", + "phonon", + "snls", + "surface-properties", + "synth-descriptions", + "xas", +] hdlr = logging.StreamHandler() fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s") @@ -834,6 +845,9 @@ def _query_resource( if "tasks" in suffix: bucket_suffix, prefix = ("parsed", "core/tasks/") + elif suffix in STATIC_COLLECTIONS: + bucket_suffix = "build" + prefix = f"static-collections/{suffix}" else: bucket_suffix = "build" prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}" From d4344a18792cd00f39303aaa6f325da2072aaae5 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Wed, 22 Apr 2026 09:52:11 -0700 Subject: [PATCH 08/15] move task trajectory to cached deltatable --- mp_api/client/mprester.py | 2 +- mp_api/client/routes/materials/tasks.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 7423b39e..27af9f0b 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -7,7 +7,6 @@ from functools import cache, lru_cache from typing import TYPE_CHECKING -from deltalake import QueryBuilder from emmet.core.band_theory import BSPathType from emmet.core.mpid import MPID, AlphaID from emmet.core.types.enums import ThermoType @@ -43,6 +42,7 @@ from collections.abc import Sequence from typing import Any, Literal + from deltalake import QueryBuilder import numpy as np from emmet.core.tasks import CoreTaskDoc from packaging.version import Version diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index cc707656..97aa9440 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -4,7 +4,6 @@ 
from typing import TYPE_CHECKING import pyarrow as pa -from deltalake import DeltaTable, QueryBuilder from emmet.core.mpid import MPID, AlphaID from emmet.core.tasks import CoreTaskDoc from emmet.core.trajectory import RelaxTrajectory @@ -36,22 +35,20 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]: dict representing emmet.core.trajectory.RelaxTrajectory """ as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] - traj_tbl = DeltaTable( - "s3a://materialsproject-parsed/core/trajectories/", - storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + traj_lbl, traj_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/trajectories/", + label = "traj", ) traj_data = pa.table( - QueryBuilder() - .register("traj", traj_tbl) - .execute( + self.query_builder.execute( f""" SELECT * - FROM traj + FROM {traj_lbl} WHERE identifier='{as_alpha}' """ - ) - .read_all() + ).read_all() ).to_pylist(maps_as_pydicts="strict") if not traj_data: From bca2ebcf603e602d814db41d3880bee40ced5ff4 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Wed, 22 Apr 2026 10:25:06 -0700 Subject: [PATCH 09/15] change default to delta backed, mark certain endpoints as not delta backed --- mp_api/client/core/client.py | 2 +- mp_api/client/mprester.py | 2 +- mp_api/client/routes/materials/doi.py | 1 + mp_api/client/routes/materials/electronic_structure.py | 2 ++ mp_api/client/routes/materials/grain_boundaries.py | 1 + mp_api/client/routes/materials/phonon.py | 1 + mp_api/client/routes/materials/similarity.py | 1 + mp_api/client/routes/materials/substrates.py | 1 + mp_api/client/routes/materials/surface_properties.py | 1 + mp_api/client/routes/materials/synthesis.py | 1 + mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/xas.py | 1 + mp_api/client/routes/molecules/jcesr.py | 1 + mp_api/client/routes/molecules/molecules.py | 1 + mp_api/client/routes/molecules/summary.py | 1 + 15 files changed, 16 insertions(+), 3 
deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index d2e92e88..0d5083fe 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -257,7 +257,7 @@ class BaseRester(_Rester): suffix: str = "" document_model: type[BaseModel] = _DictLikeAccess primary_key: str = "material_id" - delta_backed: bool = False + delta_backed: bool = True def __init__( self, diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 27af9f0b..02a7d675 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -42,8 +42,8 @@ from collections.abc import Sequence from typing import Any, Literal - from deltalake import QueryBuilder import numpy as np + from deltalake import QueryBuilder from emmet.core.tasks import CoreTaskDoc from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index c55e3758..26b268ca 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -12,6 +12,7 @@ class DOIRester(BaseRester): suffix = "doi" document_model = DOIDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index e94e35fe..01cb6130 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -174,6 +174,7 @@ def es_rester(self) -> ElectronicStructureRester: class BandStructureRester(BaseESPropertyRester): suffix = "materials/electronic_structure/bandstructure" + delta_backed = False def search_bandstructure_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" @@ -375,6 +376,7 @@ def get_bandstructure_from_material_id( class DosRester(BaseESPropertyRester): suffix = "materials/electronic_structure/dos" + delta_backed = 
False def search_dos_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" diff --git a/mp_api/client/routes/materials/grain_boundaries.py b/mp_api/client/routes/materials/grain_boundaries.py index 6949b9de..d9ac75c3 100644 --- a/mp_api/client/routes/materials/grain_boundaries.py +++ b/mp_api/client/routes/materials/grain_boundaries.py @@ -12,6 +12,7 @@ class GrainBoundaryRester(BaseRester): suffix = "materials/grain_boundaries" document_model = GrainBoundaryDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 0373cd0d..c3d9db45 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -18,6 +18,7 @@ class PhononRester(BaseRester): suffix = "materials/phonon" document_model = PhononBSDOSDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index aa6cab71..0ba8c5b7 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -26,6 +26,7 @@ class SimilarityRester(BaseRester): suffix = "materials/similarity" document_model = SimilarityDoc # type: ignore primary_key = "material_id" + delta_backed = False _fingerprinter: SimilarityScorer | None = None diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 62eaa676..6f1096b1 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -11,6 +11,7 @@ class SubstratesRester(BaseRester): suffix = "materials/substrates" document_model = SubstratesDoc # type: ignore primary_key = "film_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/surface_properties.py 
b/mp_api/client/routes/materials/surface_properties.py index 76d9e60c..3a36d5f9 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -12,6 +12,7 @@ class SurfacePropertiesRester(BaseRester): suffix = "materials/surface_properties" document_model = SurfacePropDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/synthesis.py b/mp_api/client/routes/materials/synthesis.py index 6788814c..4567c51f 100644 --- a/mp_api/client/routes/materials/synthesis.py +++ b/mp_api/client/routes/materials/synthesis.py @@ -12,6 +12,7 @@ class SynthesisRester(BaseRester): suffix = "materials/synthesis" document_model = SynthesisSearchResultModel # type: ignore + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 97aa9440..3f3a7e31 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -38,7 +38,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]: traj_lbl, traj_tbl = self._get_delta_table( "materialsproject-parsed", "core/trajectories/", - label = "traj", + label="traj", ) traj_data = pa.table( diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index fc71ae7f..8fbbb3db 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -17,6 +17,7 @@ class XASRester(BaseRester): suffix = "materials/xas" document_model = XASDoc # type: ignore primary_key = "spectrum_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/molecules/jcesr.py b/mp_api/client/routes/molecules/jcesr.py index 2d462c19..24d3f5e6 100644 --- a/mp_api/client/routes/molecules/jcesr.py +++ b/mp_api/client/routes/molecules/jcesr.py @@ -15,6 +15,7 @@ class JcesrMoleculesRester(BaseRester): suffix = "molecules/jcesr" 
document_model = MoleculesDoc # type: ignore primary_key = "task_id" + delta_backed = False def __init__(self, **kwargs): """Throw deprecation warning when JCESR client is initialized.""" diff --git a/mp_api/client/routes/molecules/molecules.py b/mp_api/client/routes/molecules/molecules.py index b7600328..3171b55c 100644 --- a/mp_api/client/routes/molecules/molecules.py +++ b/mp_api/client/routes/molecules/molecules.py @@ -20,3 +20,4 @@ class MoleculeRester(CoreRester): primary_key = "molecule_id" suffix = "molecules/core" _sub_resters = MOLECULES_RESTERS + delta_backed = False diff --git a/mp_api/client/routes/molecules/summary.py b/mp_api/client/routes/molecules/summary.py index 4be3aab5..2f91677e 100644 --- a/mp_api/client/routes/molecules/summary.py +++ b/mp_api/client/routes/molecules/summary.py @@ -12,6 +12,7 @@ class MoleculesSummaryRester(BaseRester): suffix = "molecules/summary" document_model = MoleculeSummaryDoc # type: ignore primary_key = "molecule_id" + delta_backed = False def search( self, From f4d0348c8250d406bd5d46049d6841750fdda222 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:25:18 -0700 Subject: [PATCH 10/15] add some default timeout params for deltatables + single entry point for querying delta objects that aren't full downloads + multline SQL style formatting --- mp_api/client/core/client.py | 35 +++++++++++++++++++ .../routes/materials/electronic_structure.py | 35 +++++++++++-------- mp_api/client/routes/materials/tasks.py | 16 ++++----- mp_api/client/routes/materials/thermo.py | 15 ++++---- 4 files changed, 69 insertions(+), 32 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 0d5083fe..96f04544 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -582,6 +582,7 @@ def _get_delta_table( DeltaTable : If one exists at the specified bucket / prefix, will retrieve the cached instance. 
""" + delta_timeout = f"{self.timeout}s" full_key = f"{bucket}/{prefix}" qb_label = label or full_key.replace("/", "_").replace("-", "_") if (uri := f"{connector}://{full_key}") not in self._delta_tables: @@ -590,12 +591,46 @@ def _get_delta_table( storage_options={ "AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1", + "timeout": delta_timeout, + "connect_timeout": delta_timeout, + "retry_delay": "3", + "max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}", }, ) self.query_builder.register(qb_label, self._delta_tables[uri]) return qb_label, self._delta_tables[uri] + def _query_delta_single(self, query: str) -> pa.Table: + """Execute a SQL query against a registered Delta table. + + Wraps the query execution in a try/except to provide a more + actionable error message when the underlying Delta query engine + fails (e.g., due to network timeouts, missing tables, or + malformed queries). + + Args: + query (str): A SQL query string compatible with the + QueryBuilder engine. + + Returns: + pa.Table: The query result as a PyArrow Table. + + Raises: + MPRestError: If query execution fails for any reason, + including network timeouts, connectivity issues, or + invalid queries. Inspect the chained exception for + the underlying cause. + """ + try: + return pa.table(self.query_builder.execute(query).read_all()) + except Exception as e: + raise MPRestError( + f"Failed to retrieve object due to: {e}. " + f"If this is a timeout error, try increasing the 'timeout' " + f"parameter on MPRester (current value: {self.timeout}s)." 
+ ) from e + def _query_delta_backed( self, bucket: str, diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 01cb6130..c97bf389 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -6,10 +6,7 @@ import pyarrow as pa from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos -from emmet.core.electronic_structure import ( - DOSProjectionType, - ElectronicStructureDoc, -) +from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc from emmet.core.mpid import AlphaID from emmet.core.vasp.calc_types.enums import RunType from pymatgen.analysis.magnetism.analyzer import Ordering @@ -286,15 +283,19 @@ def get_bandstructure_from_task_id( label="bandstructure", ) - selection_string = f"""SELECT * -FROM {bs_lbl} -WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + query = f""" + SELECT * + FROM {bs_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + if run_type: rt = RunType(run_type) if isinstance(run_type, str) else run_type - selection_string += f"\nAND run_type='{rt.value}'" + query += f"\nAND run_type='{rt.value}'" if path_type: - selection_string += f"\nAND path_convention='{path_type}'" - table = pa.table(self.query_builder.execute(selection_string)) + query += f"\nAND path_convention='{path_type}'" + + table = self._query_delta_single(query) if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: emmet_bs = ElectronicBS(**deser[0]) return emmet_bs.to_pmg( @@ -509,13 +510,17 @@ def get_dos_from_task_id( label="total_dos", ) - selection_string = f"""SELECT * -FROM {dos_lbl} -WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + query = f""" + SELECT * + FROM {dos_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + if run_type: rt = RunType(run_type) if 
isinstance(run_type, str) else run_type - selection_string += f"\nAND run_type='{rt.value}'" - table = pa.table(self.query_builder.execute(selection_string)) + query += f"\nAND run_type='{rt.value}'" + + table = self._query_delta_single(query) if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: return ElectronicDos(**deser[0]).to_pmg() raise MPRestError( diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 3f3a7e31..af5dae5e 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -41,15 +41,13 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]: label="traj", ) - traj_data = pa.table( - self.query_builder.execute( - f""" - SELECT * - FROM {traj_lbl} - WHERE identifier='{as_alpha}' - """ - ).read_all() - ).to_pylist(maps_as_pydicts="strict") + query = f""" + SELECT * + FROM {traj_lbl} + WHERE identifier='{as_alpha}' + """ + + traj_data = self._query_delta_single(query).to_pylist(maps_as_pydicts="strict") if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index d1ffd28a..47cdf159 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -171,16 +171,15 @@ def get_phase_diagram_from_chemsys( pd_lbl, pd_tbl = self._get_delta_table( "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" ) - table = pa.table( - self.query_builder.execute( - f"""SELECT phase_diagram + + query = f""" + SELECT phase_diagram FROM {pd_lbl} WHERE chemsys='{sorted_chemsys}' - AND version='{version}' - AND thermo_type='{thermo_type}' - """ - ) - ) + AND version='{version}' + AND thermo_type='{thermo_type}' + """ + table = self._query_delta_single(query) as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict") pd: PhaseDiagram | None = None From 
3f9bec4c44de85b10212341082bdebf607fdc74e Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Fri, 1 May 2026 14:13:35 -0700 Subject: [PATCH 11/15] pre-commit --- mp_api/client/routes/materials/thermo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index fa17f4c0..9f2f82b5 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -3,14 +3,12 @@ from collections import defaultdict import numpy as np -import pyarrow as pa from emmet.core.thermo import ThermoDoc from emmet.core.types.enums import ThermoType from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType from pydantic import TypeAdapter from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element -from pymatgen.core import __version__ as __pmg_version__ from mp_api.client.core import BaseRester from mp_api.client.core.utils import validate_ids From 2655407101348994ecd62ffa3ccf383f71b12895 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 4 May 2026 11:30:33 -0700 Subject: [PATCH 12/15] fix eos, task traj, add __dir__ to lazy import + fix attr access --- mp_api/_test_utils.py | 1 + mp_api/client/core/client.py | 19 ++++++----- mp_api/client/core/utils.py | 12 +++++-- .../routes/materials/electronic_structure.py | 12 +++++-- mp_api/client/routes/materials/eos.py | 32 ++++++++++++++----- mp_api/client/routes/materials/tasks.py | 7 ++-- .../materials/test_electronic_structure.py | 2 +- tests/client/materials/test_eos.py | 17 ++++++++-- 8 files changed, 73 insertions(+), 29 deletions(-) diff --git a/mp_api/_test_utils.py b/mp_api/_test_utils.py index 5d4044c9..b8065d70 100644 --- a/mp_api/_test_utils.py +++ b/mp_api/_test_utils.py @@ -78,6 +78,7 @@ def client_search_testing( doc = docs[0].model_dump() else: raise ValueError("No documents returned") + print(doc) for sub_field 
in sub_doc_fields: if sub_field in doc: diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 2eb055b3..5398fbdb 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -23,6 +23,7 @@ from itertools import chain, islice from json import JSONDecodeError from math import ceil +from pathlib import Path from typing import TYPE_CHECKING, ForwardRef, Optional, get_args from urllib.parse import urljoin @@ -184,7 +185,7 @@ def __init__( self.use_document_model = use_document_model self.mute_progress_bars = mute_progress_bars - self.local_dataset_cache = local_dataset_cache + self.local_dataset_cache = Path(local_dataset_cache) self.force_renew = force_renew self._query_builder = query_builder @@ -1436,12 +1437,7 @@ def _convert_to_model( ) return [ - data_model( - **{ - field: raw_doc[field] - for field in set_fields.intersection(raw_doc) - } - ) + data_model(**raw_doc) for raw_doc in (data if is_list else chain([first_doc], data)) ] @@ -1464,7 +1460,14 @@ def _generate_returned_model( set of str: set_fields, fields_not_requested) """ model_fields = self.document_model.model_fields - set_fields = set(doc).intersection(model_fields) + aliases = { + anno.alias: field for field, anno in model_fields.items() if anno.alias + } + set_fields = ( + set(doc) + .intersection(model_fields) + .union({aliases[k] for k in set(doc).intersection(aliases)}) + ) unset_fields = set(model_fields).difference(set_fields) user_requested_fields: list[str] = requested_fields or [] fields_not_requested = unset_fields.difference(user_requested_fields) diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 95793591..a6857ef1 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -116,8 +116,8 @@ def validate_ids(id_list: list[str]) -> list[str]: " data for all IDs and filter locally." 
) - [validate_identifier(idx, serialize=False) for idx in id_list] - return [getattr(idx, "string", str(idx)) for idx in id_list] + validated = [validate_identifier(idx, serialize=False) for idx in id_list] + return [getattr(idx, "string", str(idx)) for idx in validated] def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str: @@ -241,6 +241,14 @@ def __getattr__(self, v: str) -> Any: if hasattr(self._imported, v): return getattr(self._imported, v) + raise AttributeError( + f"{self._module_name}{'.' + self._class_name if self._class_name else ''} " + f"has no attribute {v}" + ) + + def __dir__(self) -> list[str]: + return self._obj.__dir__() + class MPDataset: """Convenience wrapper for pyarrow datasets stored on disk.""" diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index eee9a8a5..87bb40b7 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -326,7 +326,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["bandstructure"] ) if not bs_doc: - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if (_bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( @@ -349,7 +351,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["dos"] ) ): - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if (_bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( @@ -538,7 +542,9 @@ def get_dos_from_material_id(self, material_id: str) -> Dos: if not ( dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"]) ): - return None + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." 
+ ) if not (dos_data := dos_doc[0].get("dos")): raise MPRestError(f"No density of states data found for {material_id}") diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index 0182eb6f..2300db94 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -1,32 +1,34 @@ from __future__ import annotations +import warnings from collections import defaultdict from emmet.core.eos import EOSDoc -from mp_api.client.core import BaseRester +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning from mp_api.client.core.utils import validate_ids class EOSRester(BaseRester): suffix = "materials/eos" document_model = EOSDoc # type: ignore - primary_key = "material_id" + primary_key = "task_id" def search( self, - material_ids: str | list[str] | None = None, + task_ids: str | list[str] | None = None, energies: tuple[float, float] | None = None, volumes: tuple[float, float] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + **kwargs, ) -> list[EOSDoc] | list[dict]: """Query equations of state docs using a variety of search criteria. Arguments: - material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs + task_ids (str, List[str]): Search for equation of states associated with the specified task IDs energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range. volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. @@ -34,17 +36,31 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in EOSDoc to return data for. Default is material_id only if all_fields is False. 
+ **kwargs : used for handling deprecated kwargs Returns: ([EOSDoc], [dict]) List of equations of state docs or dictionaries. """ query_params: dict = defaultdict(dict) - if material_ids: - if isinstance(material_ids, str): - material_ids = [material_ids] + if "material_ids" in kwargs: + if task_ids: + raise MPRestError( + "You have specified both `task_ids` and the deprecated `material_ids` tag. " + "Please specify only `task_ids`." + ) + task_ids = kwargs.pop("material_ids") + warnings.warn( + "`material_id` has been replaced by `task_id` in the EOS endpoint. " + "Please migrate to using the newer field name.", + stacklevel=2, + category=MPRestWarning, + ) - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if task_ids: + query_params["material_ids"] = ",".join( + validate_ids([task_ids] if isinstance(task_ids, str) else task_ids) + ) if volumes: query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 1314254b..c55aeb85 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -39,10 +39,8 @@ def get_trajectory( """ as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] predicate = ( - f"WHERE run_type='{str(run_type)}' AND identifier='{as_alpha}'" - if run_type - else f"WHERE identifier='{as_alpha}'" - ) + f"WHERE run_type='{str(run_type)}' AND " if run_type else "" + ) + f"WHERE identifier='{as_alpha}'" traj_lbl, traj_tbl = self._get_delta_table( "materialsproject-parsed", @@ -53,7 +51,6 @@ def get_trajectory( query = f""" SELECT * FROM {traj_lbl} - WHERE identifier='{as_alpha}' {predicate}; """ diff --git a/tests/client/materials/test_electronic_structure.py b/tests/client/materials/test_electronic_structure.py index a89cc730..6b4b37c6 100644 --- a/tests/client/materials/test_electronic_structure.py +++ b/tests/client/materials/test_electronic_structure.py @@ -104,7 
+104,7 @@ def test_bs_client(): with pytest.raises(MPRestError, match="No electronic structure data found."): _ = bs_rester.get_bandstructure_from_material_id("mp-0") - with pytest.raises(MPRestError, match="No object found"): + with pytest.raises(MPRestError, match="No bandstructure data found"): _ = bs_rester.get_bandstructure_from_task_id("mp-0") diff --git a/tests/client/materials/test_eos.py b/tests/client/materials/test_eos.py index 3e633e49..e71fc010 100644 --- a/tests/client/materials/test_eos.py +++ b/tests/client/materials/test_eos.py @@ -4,6 +4,7 @@ from mp_api._test_utils import client_search_testing, requires_api_key +from mp_api.client.core.exceptions import MPRestError, MPRestWarning from mp_api.client.routes.materials.eos import EOSRester @@ -26,9 +27,9 @@ def rester(): sub_doc_fields: list = [] -alt_name_dict: dict = {"material_ids": "material_id"} +alt_name_dict: dict = {"task_ids": "task_id"} -custom_field_tests: dict = {"material_ids": ["mp-149"]} +custom_field_tests: dict = {"task_ids": ["mp-149"]} @requires_api_key @@ -42,3 +43,15 @@ def test_client(rester): custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, ) + + +@requires_api_key +def test_warnings_errors(rester): + + with pytest.warns( + MPRestWarning, match="`material_id` has been replaced by `task_id`" + ): + rester.search(material_ids=["mp-149"], num_chunks=1, chunk_size=1) + + with pytest.raises(MPRestError, match="You have specified both"): + rester.search(material_ids=["mp-149"], task_ids=["mp-1"]) From 00c2423f7600153bd487dfa20b1ab37c2b00807b Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 4 May 2026 14:49:45 -0700 Subject: [PATCH 13/15] remove print --- mp_api/_test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mp_api/_test_utils.py b/mp_api/_test_utils.py index b8065d70..5d4044c9 100644 --- a/mp_api/_test_utils.py +++ b/mp_api/_test_utils.py @@ -78,7 +78,6 @@ def client_search_testing( doc = docs[0].model_dump() else: raise ValueError("No 
documents returned") - print(doc) for sub_field in sub_doc_fields: if sub_field in doc: From 45f8e949a37611dc7700a63a43d6c9dccf07d178 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 8 May 2026 11:00:01 -0700 Subject: [PATCH 14/15] mypy --- mp_api/client/core/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index a6857ef1..25c27645 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -27,6 +27,7 @@ from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS if TYPE_CHECKING: + from collections.abc import Iterable from typing import Any, Literal from pydantic._internal._model_construction import ModelMetaclass @@ -246,8 +247,8 @@ def __getattr__(self, v: str) -> Any: f"has no attribute {v}" ) - def __dir__(self) -> list[str]: - return self._obj.__dir__() + def __dir__(self) -> Iterable[str]: + return self._obj.__dir__() if hasattr(self._obj, "__dir__") else [] class MPDataset: From 81ee53b8f0a166e6b318bf7b0a6b945e9b1db17b Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 12 May 2026 15:57:10 -0700 Subject: [PATCH 15/15] subclass query builder to use saved delta tables --- mp_api/client/core/client.py | 77 +++++++++++++++---- mp_api/client/mprester.py | 7 +- .../routes/materials/electronic_structure.py | 4 +- mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/thermo.py | 4 +- 5 files changed, 73 insertions(+), 21 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 4e40b35e..58926f22 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator - from typing import Any + from typing import Any, Self from mp_api.client.core.utils import LazyImport @@ -97,6 +97,31 @@ def _batched(iterable: Iterable, n: int) -> Iterator: yield batch +class 
QueryBuilderWithCache(QueryBuilder): + + def __init__(self) -> None: + """Extend deltalake.QueryBuilder with stored DeltaTables. + + The deltalake.QueryBuilder class does not permit introspection + of registered DeltaTables through the python API. + + Re-registering a DeltaTable + (1) wastes time by reading its metadata + (2) raises an exception because a table is already registered + + This class simply allows for caching the DeltaTable instances + and table names on the QueryBuilder class. + """ + # Dict of table names (labels) to DeltaTable instances + self._delta_tables: dict[str, DeltaTable] = {} + super().__init__() + + def register(self, table_name: str, delta_table: DeltaTable) -> Self: + """Register and cache a DeltaTable.""" + self._delta_tables[table_name] = delta_table + return super().register(table_name, delta_table) + + class _Rester: """Define base attributes of a REST client.""" @@ -113,7 +138,7 @@ def __init__( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, - query_builder: QueryBuilder | None = None, + query_builder: QueryBuilderWithCache | None = None, **kwargs, ) -> None: """Initialize a RESTer. @@ -145,7 +170,8 @@ def __init__( local_dataset_cache: Target directory for downloading full datasets. Defaults to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset - query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + query_builder : Instance of QueryBuilderWithCache to use in querying delta tables + NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored. 
**kwargs: access to legacy kwargs that may be in the process of being deprecated """ self.api_key = get_user_api_key(api_key=api_key) @@ -168,7 +194,9 @@ def __init__( self.mute_progress_bars = mute_progress_bars self.local_dataset_cache = Path(local_dataset_cache) self.force_renew = force_renew - self._query_builder = query_builder + self._query_builder = ( + query_builder if isinstance(query_builder, QueryBuilderWithCache) else None + ) if "monty_decode" in kwargs: # Pop to not repeatedly trigger warning to the user @@ -191,7 +219,7 @@ def session(self) -> requests.Session: @property def query_builder(self): if not self._query_builder: - self._query_builder = QueryBuilder() + self._query_builder = QueryBuilderWithCache() return self._query_builder @staticmethod @@ -320,8 +348,6 @@ def __init__( self.timeout = timeout self._s3_client = s3_client - self._delta_tables: dict[str, DeltaTable] = {} - @property def s3_client(self): if not self._s3_client: @@ -556,19 +582,34 @@ def _get_delta_table( prefix (str) : name of the prefix in S3 connector (str) : s3, s3n, s3a (default), or other valid Hadoop connector string. - label (str or None) : optional label for the table in QueryBuilder + label (str or None) : optional label for the table in the + cached query builder If `None`, will be gleaned from the URI Returns: - str : the table name in QueryBuilder + str : the table name in the stored query builder DeltaTable : If one exists at the specified bucket / prefix, will retrieve the cached instance. 
""" delta_timeout = f"{self.timeout}s" full_key = f"{bucket}/{prefix}" qb_label = label or full_key.replace("/", "_").replace("-", "_") - if (uri := f"{connector}://{full_key}") not in self._delta_tables: - self._delta_tables[uri] = DeltaTable( + + uri = f"{connector}://{full_key}" + if not uri.endswith("/"): + uri += "/" + + try: + stored_label, delta_table = next( + (_label, _table) + for _label, _table in self.query_builder._delta_tables.items() + if _table.table_uri == uri + ) + except StopIteration: + stored_label = None + + if stored_label is None: + delta_table = DeltaTable( uri, storage_options={ "AWS_SKIP_SIGNATURE": "true", @@ -579,9 +620,19 @@ def _get_delta_table( "max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}", }, ) - self.query_builder.register(qb_label, self._delta_tables[uri]) + self.query_builder.register(qb_label, delta_table) + + elif stored_label != qb_label: + warnings.warn( + f"DeltaTable with URI {uri} already found with different label: " + f"Stored label = {stored_label}; submitted label {qb_label}. " + "Using stored DeltaTable.", + category=MPRestWarning, + stacklevel=2, + ) + return stored_label, delta_table - return qb_label, self._delta_tables[uri] + return qb_label, delta_table def _query_delta_single(self, query: str) -> pa.Table: """Execute a SQL query against a registered Delta table. 
diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 8df0cb00..50239e7b 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -43,7 +43,6 @@ from typing import Any, Literal import numpy as np - from deltalake import QueryBuilder from emmet.core.tasks import CoreTaskDoc from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry @@ -56,6 +55,7 @@ ) from pymatgen.util.typing import SpeciesLike + from mp_api.client.core.client import QueryBuilderWithCache from mp_api.client.core.schemas import _DictLikeAccess DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U_R2SCAN"]} @@ -101,7 +101,7 @@ def __init__( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, - query_builder: QueryBuilder | None = None, + query_builder: QueryBuilderWithCache | None = None, notify_db_version: bool = False, **kwargs, ): @@ -133,7 +133,8 @@ def __init__( local_dataset_cache: Target directory for downloading full datasets. Defaults to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset - query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + query_builder : Instance of QueryBuilderWithCache to use in querying delta tables + NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored. notify_db_version (bool): If True, the current MP database version will be retrieved and logged locally in the ~/.mprester.log.yaml. If the database version changes, you will be notified. 
The current database version is diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 87bb40b7..da300110 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -276,7 +276,7 @@ def get_bandstructure_from_task_id( Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - bs_lbl, bs_tbl = self._get_delta_table( + bs_lbl, _ = self._get_delta_table( "materialsproject-parsed", "core/electronic-structure/bandstructures/", label="bandstructure", @@ -507,7 +507,7 @@ def get_dos_from_task_id( Returns: pymatgen Dos """ - dos_lbl, dos_tbl = self._get_delta_table( + dos_lbl, _ = self._get_delta_table( "materialsproject-parsed", "core/electronic-structure/total-dos/", label="total_dos", diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index c55aeb85..6feefbf9 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -42,7 +42,7 @@ def get_trajectory( f"WHERE run_type='{str(run_type)}' AND " if run_type else "" ) + f"WHERE identifier='{as_alpha}'" - traj_lbl, traj_tbl = self._get_delta_table( + traj_lbl, _ = self._get_delta_table( "materialsproject-parsed", "core/trajectories/", label="traj", diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 9f2f82b5..b68123f2 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -165,9 +165,9 @@ def get_phase_diagram_from_chemsys( ) sorted_chemsys = "-".join(sorted(chemsys.split("-"))) - version = "2026-04-13" # self.db_version.replace(".", "-") + version = self.db_version.replace(".", "-") - pd_lbl, pd_tbl = self._get_delta_table( + pd_lbl, _ = self._get_delta_table( "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" )