diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index c8001a84..18c285c1 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1034,25 +1034,30 @@ def _validate_table_annotation_metadata(cls, data: AnnData) -> None: raise ValueError(f"`{attr[cls.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") if attr[cls.INSTANCE_KEY] not in data.obs: raise ValueError(f"`{attr[cls.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") - if ( - (dtype := data.obs[attr[cls.INSTANCE_KEY]].dtype) - not in [ - int, - np.int16, - np.uint16, - np.int32, - np.uint32, - np.int64, - np.uint64, - "O", - ] - and not pd.api.types.is_string_dtype(data.obs[attr[cls.INSTANCE_KEY]]) - or (dtype == "O" and (val_dtype := type(data.obs[attr[cls.INSTANCE_KEY]].iloc[0])) is not str) - ): - dtype = dtype if dtype != "O" else val_dtype + instance_col = data.obs[attr[cls.INSTANCE_KEY]] + dtype = instance_col.dtype + + _INT_TYPES = [int, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64] + + def _is_int_or_str_dtype(d: np.dtype) -> bool: + return d in _INT_TYPES or isinstance(d, pd.StringDtype) + + # First, check the top-level dtype (covers plain int and StringDtype cases) + is_valid = _is_int_or_str_dtype(dtype) + # Explicitly handle categorical dtypes by inspecting the categories' dtype, including + # object-backed string categories via is_string_dtype on the categories' dtype. + if isinstance(dtype, pd.CategoricalDtype): + cat_dtype = dtype.categories.dtype + is_valid = is_valid or _is_int_or_str_dtype(cat_dtype) or pd.api.types.is_string_dtype(cat_dtype) + # the string case is already covered above, the check below covers the case of dtype("O") with string dtype + is_valid = is_valid or pd.api.types.is_string_dtype(instance_col) + + if not is_valid: raise TypeError( - f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " - f"instance_key column in obs. Dtype found to be {dtype}" + f"Only integer (int, np.int16, np.int32, np.int64, and uint equivalents), string " + f"(including pandas StringDtype and object dtype with string values), or categorical " + f"with integer/string categories allowed as dtype for instance_key column in obs. " + f"Dtype found to be {dtype}" ) expected_regions = attr[cls.REGION_KEY] if isinstance(attr[cls.REGION_KEY], list) else [attr[cls.REGION_KEY]] found_regions = data.obs[attr[cls.REGION_KEY_KEY]].unique().tolist() diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e2087ace..c4ac3347 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -466,6 +466,63 @@ def test_table_model( del table.uns[TableModel.ATTRS_KEY] _ = TableModel.parse(table) + @pytest.mark.parametrize( + "instance_key_values,instance_key_dtype,should_pass", + [ + # pd.StringDtype: accepted (issue #1062) + (["id_0", "id_1", "id_2", "id_3", "id_4"], pd.StringDtype(), True), + # object dtype with string values: accepted + (["id_0", "id_1", "id_2", "id_3", "id_4"], object, True), + # CategoricalDtype with object (string) categories: accepted (issue #1062) + ( + pd.Categorical(["id_0", "id_1", "id_2", "id_3", "id_4"]), + None, + True, + ), + # CategoricalDtype with StringDtype categories: accepted (issue #1062) + ( + pd.Categorical(pd.array(["id_0", "id_1", "id_2", "id_3", "id_4"], dtype="string")), + None, + True, + ), + # CategoricalDtype with integer categories: accepted + ( + pd.Categorical([0, 1, 2, 3, 4]), + None, + True, + ), + # CategoricalDtype with float categories: rejected + ( + pd.Categorical([0.0, 1.0, 2.0, 3.0, 4.0]), + None, + False, + ), + # integer dtype: accepted + ([0, 1, 2, 3, 4], np.int64, True), + # float dtype: rejected + ([0.0, 1.0, 2.0, 3.0, 4.0], np.float64, False), + # object dtype with non-string values: rejected + ([0, 1, 2, 3, 4], object, False), + ], + ) + def test_table_instance_key_dtype_validation(self, instance_key_values, instance_key_dtype, should_pass): + """Test that _validate_table_annotation_metadata accepts/rejects the correct dtypes for instance_key.""" + n = 5 + region = "sample" + region_key = "region" + obs = pd.DataFrame(index=list(map(str, range(n)))) + obs[region_key] = pd.Categorical([region] * n) + if instance_key_dtype is not None: + obs["instance_id"] = pd.array(instance_key_values, dtype=instance_key_dtype) + else: + obs["instance_id"] = instance_key_values + adata = AnnData(RNG.normal(size=(n, 2)), obs=obs) + if should_pass: + _ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="instance_id") + else: + with pytest.raises(TypeError, match="allowed as dtype for instance_key column"): + TableModel.parse(adata, region=region, region_key=region_key, instance_key="instance_id") + @pytest.mark.parametrize( "name", [