Draft
Conversation
This references the PR data-apis/array-api-tests#274.
top_k compatibilitytop_k compatibility [DO NOT MERGE]
top_k compatibility [DO NOT MERGE]top_k compatibility
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This references the PR data-apis/array-api-tests#274 and implements the compatibility layer for
top_k.Summary of Compatibility
jax:top_kdoes not implementaxisorlargestarguments. Whileaxisis easily implemented withjax.numpy.swapaxes,largestis not. Implementing the spec in JAX can be done similar to the pure python implementation in WIP: top_k draft implementation numpy/numpy#26666.jax.numpy.partitionandjax.numpy.argpartitionjax-ml/jax#22137.numpy:dask:top_kis currently about 2x longer than it has to be since computing the indices and values has to be done separately. This can be rectified whentake_along_axisis implemented in dask: Add NumPy's new take_along_axis dask/dask#3663.torch:Process
As mentioned in the referenced PR, since the process I went through is likely going to be repeated again, here are the steps I took:
array-apithat adds the corresponding specification..draft.array-api-testswhich implements the new tests and has itsarray-apisubmodule pointing to the newly createdarray-apibranch.array-api-compat(This PR) that implements the compatibility and points the CI to the newly createdarray-api-testsbranch.ARRAY_API_TESTS_VERSION=draftin the CI.Since I was implementing tests and compatibility on a non-existent spec, developing all 3 concurrently was incredibly messy. As of now I don't have much opinions on how to improve this process, but a documentation page of the necessary steps will be really helpful for future contributors.