ENH: cupy: add a workaround for cp.searchorted 2nd argument#374
ENH: cupy: add a workaround for cp.searchorted 2nd argument#374ev-br wants to merge 1 commit intodata-apis:mainfrom
Conversation
Array API 2025.12 allows python scalars for the x2 argument of `searchsorted`. CuPy only supports python scalars for x2 from CuPy 14.0. Until this is the minimum supported version, array-api-compat needs a workaround.
|
Marking as a draft until 2025.12 is out. |
There was a problem hiding this comment.
Pull request overview
This PR adds a workaround for CuPy's searchsorted function to support Python scalars for the x2 argument, which is now allowed by Array API 2025.12 but only supported in CuPy 14.0+. Until CuPy 14.0 becomes the minimum supported version, this compatibility wrapper converts scalar inputs to arrays before passing them to the underlying CuPy function.
Key Changes
- Added new
searchsortedfunction inarray_api_compat/cupy/_aliases.pythat wrapscp.searchsortedwith scalar-to-array conversion - Added
Literalimport fromtypingfor type annotations - Updated
__all__exports to include the newsearchsortedfunction
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| def searchsorted( | ||
| x1: Array, | ||
| x2: Array | int | float, |
There was a problem hiding this comment.
The type annotation for x2 is missing complex which is checked in the isinstance check on line 154. The type annotation should be Array | int | float | complex to match the runtime validation logic.
| x2: Array | int | float, | |
| x2: Array | int | float | complex, |
| raise NotImplementedError( | ||
| 'Only python scalars or ndarrays are supported for x2') | ||
| x2 = cp.asarray(x2) | ||
| return cp.searchsorted(x1, x2, side, sorter) |
There was a problem hiding this comment.
The arguments side and sorter should be passed as keyword arguments to cp.searchsorted, not positional arguments. The call should be cp.searchsorted(x1, x2, side=side, sorter=sorter) to match the API signature.
| return cp.searchsorted(x1, x2, side, sorter) | |
| return cp.searchsorted(x1, x2, side=side, sorter=sorter) |
Array API 2025.12 allows python scalars for the x2 argument of
searchsorted. CuPy only supports python scalars for x2 from CuPy 14.0. Until this is the minimum supported version, array-api-compat needs a workaround.Array API spec PR: data-apis/array-api#982
A matching test: data-apis/array-api-tests#394
The matching CuPy enhancement: cupy/cupy#9512