We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:
remove deprecated custom_transforms code
PiperOrigin-RevId: 366108489
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
PiperOrigin-RevId: 361681134
This can be useful when you need backend specific behaviour, e.g.:
if jax.default_backend() == 'gpu':
dataset = double_buffer(dataset)
Or if you want to assert a given backend is the default:
assert jax.default_backend() == 'tpu'
I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
Also move tests for device_put_sharded into pmap_test.py, since that
file tests with multiple devices even in our OSS CI.
Add both device_put_replicated and device_put_sharded to
jax/__init__.py.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
* Add jax.linear_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
* add failing test for complex numbers
* Add picky dtype check for linear_transpose
* Lint fix
* Allow truncating dtypes to match inputs in linear_transpose
* Fix typo in shape check error
* improve docstring
* Don't support integer inputs; better docstring
* fixup
* Fix doctest
Co-authored-by: Matthew Johnson <mattjj@google.com>
Something must have started logging earlier than before, causing INFO-level logging to be initialized before we disabled it. This change disables INFO logging sooner.
* Add jax.image.resize.
This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.
While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
* Use a whitelist to restrict visibility in top-level jax namespace.
The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.