Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,25 +1034,30 @@ def _validate_table_annotation_metadata(cls, data: AnnData) -> None:
raise ValueError(f"`{attr[cls.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
if attr[cls.INSTANCE_KEY] not in data.obs:
raise ValueError(f"`{attr[cls.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")
if (
(dtype := data.obs[attr[cls.INSTANCE_KEY]].dtype)
not in [
int,
np.int16,
np.uint16,
np.int32,
np.uint32,
np.int64,
np.uint64,
"O",
]
and not pd.api.types.is_string_dtype(data.obs[attr[cls.INSTANCE_KEY]])
or (dtype == "O" and (val_dtype := type(data.obs[attr[cls.INSTANCE_KEY]].iloc[0])) is not str)
):
dtype = dtype if dtype != "O" else val_dtype
instance_col = data.obs[attr[cls.INSTANCE_KEY]]
dtype = instance_col.dtype

_INT_TYPES = [int, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64]

def _is_int_or_str_dtype(d: np.dtype) -> bool:
return d in _INT_TYPES or isinstance(d, pd.StringDtype)

# First, check the top-level dtype (covers plain int and StringDtype cases)
is_valid = _is_int_or_str_dtype(dtype)
# Explicitly handle categorical dtypes by inspecting the categories' dtype, including
# object-backed string categories via is_string_dtype on the categories' dtype.
if isinstance(dtype, pd.CategoricalDtype):
cat_dtype = dtype.categories.dtype
is_valid = is_valid or _is_int_or_str_dtype(cat_dtype) or pd.api.types.is_string_dtype(cat_dtype)
# the string case is already covered above, the check below covers the case of dtype("O") with string dtype
is_valid = is_valid or pd.api.types.is_string_dtype(instance_col)

if not is_valid:
raise TypeError(
f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for "
f"instance_key column in obs. Dtype found to be {dtype}"
f"Only integer (int, np.int16, np.int32, np.int64, and uint equivalents), string "
f"(including pandas StringDtype and object dtype with string values), or categorical "
f"with integer/string categories allowed as dtype for instance_key column in obs. "
f"Dtype found to be {dtype}"
)
expected_regions = attr[cls.REGION_KEY] if isinstance(attr[cls.REGION_KEY], list) else [attr[cls.REGION_KEY]]
found_regions = data.obs[attr[cls.REGION_KEY_KEY]].unique().tolist()
Expand Down
57 changes: 57 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,63 @@ def test_table_model(
del table.uns[TableModel.ATTRS_KEY]
_ = TableModel.parse(table)

@pytest.mark.parametrize(
"instance_key_values,instance_key_dtype,should_pass",
[
# pd.StringDtype: accepted (issue #1062)
(["id_0", "id_1", "id_2", "id_3", "id_4"], pd.StringDtype(), True),
# object dtype with string values: accepted
(["id_0", "id_1", "id_2", "id_3", "id_4"], object, True),
# CategoricalDtype with object (string) categories: accepted (issue #1062)
(
pd.Categorical(["id_0", "id_1", "id_2", "id_3", "id_4"]),
None,
True,
),
# CategoricalDtype with StringDtype categories: accepted (issue #1062)
(
pd.Categorical(pd.array(["id_0", "id_1", "id_2", "id_3", "id_4"], dtype="string")),
None,
True,
),
# CategoricalDtype with integer categories: accepted
(
pd.Categorical([0, 1, 2, 3, 4]),
None,
True,
),
# CategoricalDtype with float categories: rejected
(
pd.Categorical([0.0, 1.0, 2.0, 3.0, 4.0]),
None,
False,
),
# integer dtype: accepted
([0, 1, 2, 3, 4], np.int64, True),
# float dtype: rejected
([0.0, 1.0, 2.0, 3.0, 4.0], np.float64, False),
# object dtype with non-string values: rejected
([0, 1, 2, 3, 4], object, False),
],
)
def test_table_instance_key_dtype_validation(self, instance_key_values, instance_key_dtype, should_pass):
"""Test that _validate_table_annotation_metadata accepts/rejects the correct dtypes for instance_key."""
n = 5
region = "sample"
region_key = "region"
obs = pd.DataFrame(index=list(map(str, range(n))))
obs[region_key] = pd.Categorical([region] * n)
if instance_key_dtype is not None:
obs["instance_id"] = pd.array(instance_key_values, dtype=instance_key_dtype)
else:
obs["instance_id"] = instance_key_values
adata = AnnData(RNG.normal(size=(n, 2)), obs=obs)
if should_pass:
_ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="instance_id")
else:
with pytest.raises(TypeError, match="allowed as dtype for instance_key column"):
TableModel.parse(adata, region=region, region_key=region_key, instance_key="instance_id")

@pytest.mark.parametrize(
"name",
[
Expand Down
Loading