This can be useful when you need backend specific behaviour, e.g.:
if jax.default_backend() == 'gpu':
dataset = double_buffer(dataset)
Or if you want to assert a given backend is the default:
assert jax.default_backend() == 'tpu'
I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
* Add jax.linear_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
* add failing test for complex numbers
* Add picky dtype check for linear_transpose
* Lint fix
* Allow truncating dtypes to match inputs in linear_transpose
* Fix typo in shape check error
* improve docstring
* Don't support integer inputs; better docstring
* fixup
* Fix doctest
Co-authored-by: Matthew Johnson <mattjj@google.com>
* Add jax.image.resize.
This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.
While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
- Show `numpy.jax.vectorize` explicitly in the JAX docs, rather than the
original `numpy.vectorize.
- Updated regex for identifying function signatures in NumPy. This now correctly
parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
narrative content in the docstring but it should go somewhere else.
* Add jax.numpy.vectorize
This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
`static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).
Eventually we might want want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).
* Add deprecation warning to jax.experimental.vectorize
* improve implementation
The current index of the API docs page seems to have broken links: when I
click on "Automatic differentiation" for example, I get sent to the "JIT"
section.
This change fixes the links.
Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.
Progress on issue #101. Fixes#122.
Reenable some disabled TPU indexing tests that now pass.