ENH: Allow lazy_xp_function to work properly for class methods that are inherited from a parent class and for classmethods and staticmethods#582
Conversation
|
a few CI failures otherwise looks pretty good |
|
cc @crusaderky |
|
Thanks @lucascolley. I've pushed a commit which fixes the failures locally. |
|
Coverage misses some lines in the diff because ufuncs allow setting attributes as of NumPy 2.2 (numpy/numpy#27735) so those lines only get hit for older NumPy versions.
|
|
Are they hit by |
Interesting, I tried it locally and it seems like they are still not hit. But they should be hit by this test if an older NumPy is installed:: try:
# Test an arbitrary Cython ufunc (@cython.vectorize).
# When SCIPY_ARRAY_API is not set, this is the same as
# scipy.special.erf.
from scipy.special._ufuncs import erf # type: ignore[import-untyped]
lazy_xp_function(erf)
except ImportError:
erf = None
@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="device->host copy")
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # PyTorch
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
pytest.importorskip("scipy")
assert erf is not None
x = xp.asarray([6.0, 7.0])
if library.like(Backend.ARRAY_API_STRICT, Backend.JAX):
# array-api-strict arrays are auto-converted to NumPy
# which results in an assertion error for mismatched namespaces
# eager JAX arrays are auto-converted to NumPy in eager JAX
# and fail in jax.jit (which lazy_xp_function tests here)
with pytest.raises((TypeError, AssertionError)):
xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0]))
else:
# CuPy, Dask and sparse define __array_ufunc__ and dispatch accordingly
# note that when sparse reduces to scalar it returns a np.generic, which
# would make xp_assert_equal fail.
xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0])) |
|
There's no mystery. The test that would hit those lines depends on SciPy and there's no SciPy in the |
|
Wanna try |
I can confirm that adding scipy causes those lines to get hit locally. |
crusaderky
left a comment
There was a problem hiding this comment.
IMHO the API would feel more user-friendly if one could just write
lazy_xp_function(B.g)is there a strong reason against it?
Also, is there a draft scipy PR that shows how this gets integrated in scipy's xp_capabilities?
The draft PR is here scipy/scipy#24267.
In situations where |
There was a problem hiding this comment.
This is now obsolete. Could you either
- clarify that you need to explicitly list classes as well as modules
- change
patch_lazy_xp_functionsto descend into classes from the modules (preferrable)
There was a problem hiding this comment.
Currently classes need to be explicitly listed as well as modules but yes, it would be nice if patch_lazy_xp_functions descended into modules so I'm +1 for that.
tests/test_testing.py
Outdated
| def test_lazy_xp_function_class_inheritance(): | ||
| assert hasattr(B.g, "_lazy_xp_function") | ||
| assert not hasattr(A.g, "_lazy_xp_function") |
There was a problem hiding this comment.
What this new test doesn't verify is that B.g actually runs with the JAX/Dask wrapper in a test with the xp fixture. In fact, it doesn't because patch_lazy_xp_functions doesn't descend into the class.
There was a problem hiding this comment.
Yeah, because the class needs to be added to lazy_xp_modules and the test needs the xp fixture. Things will work in that case, and there are tests in SciPy that work after doing this. Maybe I should just update patch_lazy_xp_functions here too to sidestep that though.
|
Thanks @crusaderky, |
tests/test_testing.py
Outdated
| foo = B(x) | ||
| observed = foo.g(y, z) | ||
| expected = xp.asarray(44.0)[()] | ||
| xp_assert_close(observed, expected) |
There was a problem hiding this comment.
This runs the function, but it doesn't test that it's been wrapped.
You need to write another function that will fail when wrapped, e.g.
def w(self):
return bool(self._xp.any(self.x))`See tests earlier in this same test module for examples.
|
Thanks @crusaderky; the test is actually doing what it's supposed to now. |
| foo = A(x) | ||
| bar = B(x) | ||
|
|
||
| if library.like(Backend.JAX): |
There was a problem hiding this comment.
not just JAX, sparse raises a RuntimeError, see CI
There was a problem hiding this comment.
Yeah, but for different reasons orthogonal from what's being tested so I'm just going to add a skip for it. By the way, what's the pixi command to run tests with all backends?
There was a problem hiding this comment.
array-api-extra on 🎋 lazy-class-methods is 📦 v0.10.0.dev0 via 🐍 took 4s
❯ pixi run tests
? The task 'tests' can be run in multiple environments.
Please select an environment to run the task in: ›
❯ tests
tests-py313
dev
tests-backends
tests-backends-py311
dev-cuda
tests-cuda
tests-cuda-py311
tests-numpy1
tests-py311
tests-nogiltests-backends for all (CPU) backends!
There was a problem hiding this comment.
if in doubt the CI log will always show the exact Pixi task you need to reproduce
There was a problem hiding this comment.
Thanks. Also, is there a convenient way to add a skip for all GPU backends? In SciPy skip_xp_backends has a cpu_only=True option, but here it looks like I'd need to add a skip for each GPU backend separately? I guess I could also just find a way to test the same behavior without needing the skips.
There was a problem hiding this comment.
I don't think that exists yet, however it should be somewhat easy to add a method similar to
array-api-extra/src/array_api_extra/_lib/_backends.py
Lines 44 to 46 in b6518f3
self.name ends with :gpu (or something more general).
There was a problem hiding this comment.
Thanks. That's probably out of scope for this PR so I just added separate skips for all of the GPU backends.
|
I've just found that the change to make it so classes didn't have to be added to |
|
OK, so it wasn't actually related to the |
|
Marking as ready to review because CI is now fully green for scipy/scipy#24267. |
Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
5fec90b to
0c12809
Compare
|
the coverage for |
I added tests for it locally and found that classmethods currently don't get wrapped still. There's another layer of indirection for classmethods compared to staticmethods. Working on fixing that now. |
|
OK, that should do it. |
aha, wow I love codecov |
| hypothesis = ">=6.148.8" | ||
| array-api-strict = ">=2.4.1,<2.5" | ||
| numpy = ">=1.22.0" | ||
| scipy = ">=1.15.2,<2" |
There was a problem hiding this comment.
This looks unnecessary/unwanted. Dropping this can then also drop the pixi.lock changes which will resolve the merge conflict.
There was a problem hiding this comment.
It was deliberate for test coverage. I can handle conflicts

Closes #488
In #488 I asked about applying something like
lazy_xp_functionto classes. This was a half-baked idea motivated by wantingxp_capabilitiesin SciPy, which everyone here should be familiar with, to always be applied at the class level instead of the individual method level. In scipy/scipy#24267 I created a workflow that keepsxp_capabilitiesapplied to classes, but still allows capabilities to be defined separately for individual methods if needed, and allowslazy_xp_functionto be applied to individual methods separately. This means no nonsense about trying to automatically determine which methods of a class should havelazy_xp_functionapplied to them.While working on this, I settled on specifying methods with tuples of the form
Tuple[type, str]specifying an (uninstantiated) class and a method name. This is to allow distinguishing things likeA.ffromB.fwhenBis a subclass ofAthat inheritsffromA, since capabilities may differ at different levels of the inheritance hierarchy. Through changes in SciPy, I was able to allow precise declarations of capabilities for class methods in the presence of inheritance, but a separate change is needed here to allow things like applyinglazy_xp_functiontoB.fbut notA.fwhenfis inherited fromA.The change here modifies
lazy_xp_functionto also take tuples of the formTuple[type, str]. When this is done, for say(B, "f"), thenB.fis replaced with a shallow clone of itself before adding the tags. This allowsB.fgets the tags withoutA.fto get them. If replacing an inherited method with a shallow clone is too obtrusive, I have a workaround that keeps more obtrusive modifications only withinpatch_lazy_xp_functions, but it makes things considerably more complicated, so I hope that won't be necessary.https://github.com/scipy/scipy/blob/dfa1b87e4af7cf7ee5a8b8faf5c4360b63c86b36/scipy/_lib/tests/test_xp_capabilities.py has an (xfailed) test giving an example of what can go wrong with inheritance currently which can be made to pass with the change made in this PR.