Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/set_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Objects in API
:toctree: generated
:template: method.rst

isin
unique_all
unique_counts
unique_inverse
Expand Down
38 changes: 36 additions & 2 deletions src/array_api_stubs/_draft/set_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,41 @@
__all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
__all__ = ["isin", "unique_all", "unique_counts", "unique_inverse", "unique_values"]


from ._types import Tuple, array
from ._types import Tuple, Union, array


def isin(
x1: Union[array, int],
x2: Union[array, int],
/,
*,
invert: bool = False,
) -> array:
"""
Tests for each element in ``x1`` whether the element is in ``x2``.

Parameters
----------
x1: Union[array, int]
first input array. **Should** have an integer data type.
x2: Union[array, int]
second input array. **Should** have an integer data type.
invert: bool
boolean indicating whether to invert the test criterion. If ``True``, the function **must** test whether each element in ``x1`` is *not* in ``x2``. If ``False``, the function **must** test whether each element in ``x1`` is in ``x2``. Default: ``False``.

Returns
-------
out: array
an array containing element-wise test results. The returned array **must** have a boolean data type. If ``x1`` is an array, the returned array **must** have the same shape as ``x1``; otherwise, the returned array **must** be a zero-dimensional array containing the result.

Notes
-----

- At least one of ``x1`` or ``x2`` **must** be an array.
- If an element in ``x1`` is in ``x2``, the corresponding element in the output array **must** be ``True``; otherwise, the corresponding element in the output array **must** be ``False``.
- Testing whether an element in ``x1`` corresponds to an element in ``x2`` **must** be determined based on value equality (see :func:`~array_api.equal`).
- Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is unspecified and thus implementation-defined.
"""


def unique_all(x: array, /) -> Tuple[array, array, array, array]:
Expand Down