diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py
index 80d205dc..3d26381d 100644
--- a/mp_api/client/core/client.py
+++ b/mp_api/client/core/client.py
@@ -466,6 +466,186 @@ def _query_open_data(

         return decoded_data, len(decoded_data)  # type: ignore

+    def _query_delta_backed(
+        self,
+        bucket: str,
+        prefix: str,
+        timeout: int | None = None,
+    ) -> dict[str, Any]:
+        """Retrieve data from S3 backed by a DeltaTable.
+
+        Args:
+            bucket (str): S3 OpenData bucket
+            prefix (str): S3 object prefix
+            timeout (int or None): timeout on getting access-controlled groups
+
+        Returns:
+            dict of str to Any
+        """
+        # Check if user has access to GNoME
+        # temporarily suppress tqdm
+        re_enable = not self.mute_progress_bars
+        self.mute_progress_bars = True
+        has_gnome_access = bool(
+            self._submit_requests(
+                url=urljoin(self.base_endpoint, "materials/summary/"),
+                criteria={
+                    "batch_id": "gnome_r2scan_statics",
+                    "_fields": "material_id",
+                },
+                use_document_model=False,
+                num_chunks=1,
+                chunk_size=1,
+                timeout=timeout if timeout is not None else self.timeout,
+            )
+            .get("meta", {})
+            .get("total_doc", 0)
+        )
+        self.mute_progress_bars = not re_enable
+
+        suffix = prefix.rsplit("/")[1]
+
+        target_path = str(
+            self.local_dataset_cache.joinpath(
+                f"{bucket.split('materialsproject-')[1]}/{prefix}"
+            )
+        )
+        os.makedirs(target_path, exist_ok=True)
+
+        if DeltaTable.is_deltatable(target_path):
+            if self.force_renew:
+                shutil.rmtree(target_path)
+                logger.warning(f"Regenerating {suffix} dataset at {target_path}...")
+                os.makedirs(target_path, exist_ok=True)
+            else:
+                logger.warning(
+                    f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
+                )
+                logger.info(
+                    "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
+                    "to refresh local dataset.",
+                )
+
+                return {
+                    "data": MPDataset(
+                        path=target_path,
+                        document_model=self.document_model,
+                        use_document_model=self.use_document_model,
+                    )
+                }
+
+        tbl = DeltaTable(
+            f"s3a://{bucket}/{prefix}",
+            storage_options={
+                "AWS_SKIP_SIGNATURE": "true",
+                "AWS_REGION": "us-east-1",
+            },
+        )
+
+        controlled_batch_str = ",".join(
+            [f"'{tag}'" for tag in self.access_controlled_batch_ids]
+        )
+
+        predicate = (
+            f"WHERE batch_id NOT IN ({controlled_batch_str})"
+            if not has_gnome_access
+            else ""
+        )
+
+        builder = QueryBuilder().register("tbl", tbl)
+
+        # Set up progress bar
+        num_docs_needed: int = tbl.count()
+
+        if not has_gnome_access:
+            num_docs_needed = self.count(
+                {"batch_id_neq_any": self.access_controlled_batch_ids}
+            )
+
+        pbar = (
+            tqdm(
+                desc=(
+                    f"Retrieving DeltaTable-backed {self.document_model.__name__} documents"
+                    if self.document_model is not None
+                    else "Retrieving documents"
+                ),
+                total=num_docs_needed,
+            )
+            if not self.mute_progress_bars
+            else None
+        )
+
+        iterator = builder.execute(f"SELECT * FROM tbl {predicate}")
+
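+        # Pages streamed from the DataFusion query are buffered in memory, then
+        # flushed to zstd parquet once DATASET_FLUSH_THRESHOLD bytes accumulate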
+        file_options = ds.ParquetFileFormat().make_write_options(compression="zstd")
+
+        def _flush(accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema):
+            # somewhere post datafusion 51.0.0 / arrow-rs 57.0.0, results began
+            # arriving as *View types; cast back to the base schema, since
+            # pyarrow is behind on implementation support for *View types
+            tbl = (
+                pa.Table.from_batches(accumulator)
+                .select(schema.names)
+                .cast(target_schema=schema)
+            )
+
+            ds.write_dataset(
+                tbl,
+                base_dir=target_path,
+                format="parquet",
+                basename_template=f"group-{group}-" + "part-{i}.zstd.parquet",
+                existing_data_behavior="overwrite_or_ignore",
+                max_rows_per_group=1024,
+                file_options=file_options,
+            )
+
+        group = 1
+        size = 0
+        accumulator = []
+        schema = pa.schema(arrowize(self.document_model))
+        for page in iterator:
+            # arro3 RecordBatch -> pyarrow RecordBatch for compat with the pyarrow ds writer
+            rg = pa.record_batch(page)
+            accumulator.append(rg)
+            page_size = page.num_rows
+            size += rg.get_total_buffer_size()
+
+            if pbar is not None:
+                pbar.update(page_size)
+
+            if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD:
+                _flush(accumulator, group, schema)
+                group += 1
+                size = 0
+                accumulator.clear()
+
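+        # Flush any batches still buffered when the page iterator is exhausted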
+        if accumulator:
+            _flush(accumulator, group + 1, schema)
+
+        if pbar is not None:
+            pbar.close()
+
+        logger.info(f"Dataset for {suffix} written to {target_path}")
+        logger.info("Converting to DeltaTable...")
+
+        convert_to_deltalake(target_path)
+
+        logger.info(
+            "Consult the delta-rs and pyarrow documentation for advanced usage: "
+            "delta-io.github.io/delta-rs, arrow.apache.org/docs/python"
+        )
+
+        return {
+            "data": MPDataset(
+                path=target_path,
+                document_model=self.document_model,
+                use_document_model=self.use_document_model,
+            )
+        }
+
     def _query_resource(
         self,
         criteria: dict | None = None,
@@ -542,27 +722,6 @@ def _query_resource(
         suffix = infix if suffix == "core" else suffix
         suffix = suffix.replace("_", "-")

-        # Check if user has access to GNoMe
-        # temp suppress tqdm
-        re_enable = not self.mute_progress_bars
-        self.mute_progress_bars = True
-        has_gnome_access = bool(
-            self._submit_requests(
-                url=urljoin(self.base_endpoint, "materials/summary/"),
-                criteria={
-                    "batch_id": "gnome_r2scan_statics",
-                    "_fields": "material_id",
-                },
-                use_document_model=False,
-                num_chunks=1,
-                chunk_size=1,
-                timeout=timeout,
-            )
-            .get("meta", {})
-            .get("total_doc", 0)
-        )
-        self.mute_progress_bars = not re_enable
-
         if "tasks" in suffix:
             bucket_suffix, prefix = ("parsed", "core/tasks/")
         else:
@@ -572,156 +731,19 @@
         bucket = f"materialsproject-{bucket_suffix}"

         if self.delta_backed:
-            target_path = str(
-                self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
-            )
-            os.makedirs(target_path, exist_ok=True)
-
-            if DeltaTable.is_deltatable(target_path):
-                if self.force_renew:
-                    shutil.rmtree(target_path)
-                    logger.warning(
-                        f"Regenerating {suffix} dataset at {target_path}..."
-                    )
-                    os.makedirs(target_path, exist_ok=True)
-                else:
-                    logger.warning(
-                        f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
-                    )
-                    logger.info(
-                        "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
-                        "to refresh local dataset.",
-                    )
-
-                    return {
-                        "data": MPDataset(
-                            path=target_path,
-                            document_model=self.document_model,
-                            use_document_model=self.use_document_model,
-                        )
-                    }
-
-            tbl = DeltaTable(
-                f"s3a://{bucket}/{prefix}",
-                storage_options={
-                    "AWS_SKIP_SIGNATURE": "true",
-                    "AWS_REGION": "us-east-1",
-                },
-            )
-
-            controlled_batch_str = ",".join(
-                [f"'{tag}'" for tag in self.access_controlled_batch_ids]
-            )
-
-            predicate = (
-                f"WHERE batch_id NOT IN ({controlled_batch_str})"
-                if not has_gnome_access
-                else ""
-            )
-
-            builder = QueryBuilder().register("tbl", tbl)
-
-            # Setup progress bar
-            num_docs_needed: int = tbl.count()
-
-            if not has_gnome_access:
-                num_docs_needed = self.count(
-                    {"batch_id_neq_any": self.access_controlled_batch_ids}
-                )
-
-            pbar = (
-                tqdm(
-                    desc=pbar_message,
-                    total=num_docs_needed,
-                )
-                if not self.mute_progress_bars
-                else None
-            )
-
-            iterator = builder.execute(f"SELECT * FROM tbl {predicate}")
-
-            file_options = ds.ParquetFileFormat().make_write_options(
-                compression="zstd"
-            )
-
-            def _flush(
-                accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema
-            ):
-                # somewhere post datafusion 51.0.0 and arrow-rs 57.0.0
-                # casts to *View types began, need to cast back to base schema
-                # -> pyarrow is behind on implementation support for *View types
-                tbl = (
-                    pa.Table.from_batches(accumulator)
-                    .select(schema.names)
-                    .cast(target_schema=schema)
-                )
-
-                ds.write_dataset(
-                    tbl,
-                    base_dir=target_path,
-                    format="parquet",
-                    basename_template=f"group-{group}-"
-                    + "part-{i}.zstd.parquet",
-                    existing_data_behavior="overwrite_or_ignore",
-                    max_rows_per_group=1024,
-                    file_options=file_options,
-                )
-
-            group = 1
-            size = 0
-            accumulator = []
-            schema = pa.schema(arrowize(self.document_model))
-            for page in iterator:
-                # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
-                rg = pa.record_batch(page)
-                accumulator.append(rg)
-                page_size = page.num_rows
-                size += rg.get_total_buffer_size()
-
-                if pbar is not None:
-                    pbar.update(page_size)
-
-                if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD:
-                    _flush(accumulator, group, schema)
-                    group += 1
-                    size = 0
-                    accumulator.clear()
-
-            if accumulator:
-                _flush(accumulator, group + 1, schema)
-
-            if pbar is not None:
-                pbar.close()
-
-            logger.info(f"Dataset for {suffix} written to {target_path}")
-            logger.info("Converting to DeltaTable...")
-
-            convert_to_deltalake(target_path)
-
-            logger.info(
-                "Consult the delta-rs and pyarrow documentation for advanced usage: "
-                "delta-io.github.io/delta-rs, arrow.apache.org/docs/python"
-            )
-
-            return {
-                "data": MPDataset(
-                    path=target_path,
-                    document_model=self.document_model,
-                    use_document_model=self.use_document_model,
-                )
-            }
+            return self._query_delta_backed(bucket, prefix, timeout=timeout)

         # Paginate over all entries in the bucket.
         # TODO: change when a subset of entries needed from DB
         paginator = self.s3_client.get_paginator("list_objects_v2")
         pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

-        keys = []
-        for page in pages:
-            for obj in page.get("Contents", []):
-                key = obj.get("Key")
-                if key and "manifest" not in key:
-                    keys.append(key)
+        keys = [
+            obj["Key"]
+            for page in pages
+            for obj in page.get("Contents", [])
+            if obj.get("Key") and "manifest" not in obj["Key"]
+        ]

         if len(keys) < 1:
             return self._submit_requests(
@@ -769,8 +791,13 @@ def _flush(
             )
         ]

+        _chunks = chain.from_iterable(unzipped_chunks)
         data: dict[str, Any] = {
-            "data": list(chain.from_iterable(unzipped_chunks)),
+            "data": (
+                self._convert_to_model(_chunks)
+                if self.document_model and use_document_model
+                else list(_chunks)
+            ),
             "meta": {},
         }

@@ -1208,19 +1235,27 @@ def _submit_request_and_process(
         )

     def _convert_to_model(
-        self, data: list[dict[str, Any]]
+        self,
+        data: list[dict[str, Any]] | Iterator,
     ) -> list[BaseModel] | list[dict[str, Any]]:
         """Converts dictionary documents to instantiated MPDataDoc objects.

        Args:
-            data (list[dict]): Raw dictionary data objects
+            data (list[dict] or Iterator): Raw dictionary data objects

        Returns:
            (list[MPDataDoc]): List of MPDataDoc objects
        """

-        if len(data) > 0:
-            data_model, set_fields, _ = self._generate_returned_model(data[0])
+        if (hasattr(data, "__len__") and len(data) > 0) or (hasattr(data, "__next__")):  # type: ignore[arg-type]
+            is_list = hasattr(data, "__len__")
+            try:
+                # Handle both list-like and iterator input
+                first_doc = data[0] if is_list else next(data)  # type: ignore[index,arg-type]
+            except StopIteration:
+                # Return empty list if no data in iterator
+                return []
+            data_model, set_fields, _ = self._generate_returned_model(first_doc)

             return [
                 data_model(
@@ -1230,7 +1265,7 @@ def _submit_request_and_process(
                     if field in set_fields
                 }
            )
-            for raw_doc in data
+            for raw_doc in (data if is_list else chain([first_doc], data))
         ]

         return data
diff --git a/tests/client/materials/test_materials.py b/tests/client/materials/test_materials.py
index 4133b0d5..67f41020 100644
--- a/tests/client/materials/test_materials.py
+++ b/tests/client/materials/test_materials.py
@@ -69,7 +69,6 @@ def test_client(rester):
     )


-@pytest.mark.xfail(condition=True, reason="Needs new deployment.", strict=False)
 @pytest.mark.parametrize(
     "run_type, uncorrected_energy, use_document_model",
     [("PBE", None, True), ("r2SCAN", 1.0, False), ("GGA_U", (-50e4, 0.0), True)],
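
Review note, not part of the patch: a minimal sketch of reopening the locally materialized dataset once _query_delta_backed has written it and convert_to_deltalake has run. The cache path below is hypothetical (the client derives the real one from local_dataset_cache, the bucket suffix, and the object prefix); only delta-rs calls already used in the patch or documented upstream are assumed:

    from deltalake import DeltaTable

    # hypothetical local cache location, standing in for target_path above
    local_path = "/home/user/mp_api_datasets/build/core/summary/"

    if DeltaTable.is_deltatable(local_path):
        dt = DeltaTable(local_path)
        # eager read into a pyarrow Table; dt.to_pyarrow_dataset() is the
        # lazy alternative for filtered, larger-than-memory scans
        table = dt.to_pyarrow_table()
        print(table.num_rows)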