We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally committed in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the `local_devices` argument. This change is
identical except that it also adds staging for the argument name change.
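For illustration, a hedged sketch of the new naming; the exact keyword below
assumes the final form of the staged argument rename:

    import jax

    idx = jax.process_index()   # formerly jax.host_id()
    n = jax.process_count()     # formerly jax.host_count()

    # The staged rename: local_devices previously took host_id and will take
    # process_index once the staging completes (assumed final spelling).
    devices = jax.local_devices(process_index=idx)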
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:
remove deprecated custom_transforms code
PiperOrigin-RevId: 366108489
There's a small overhead to introducing an additional Python call frame, which we can avoid if we annotate only the cache miss case. This change does not seem to affect the filtered stack traces; it appears we filter all JAX-internal frames whenever any api_boundary is present.
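A rough sketch of the idea with stand-in names (not JAX's real internals):

    import functools

    def api_boundary(fn):  # stand-in for the traceback-filtering decorator
      @functools.wraps(fn)
      def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)  # frames below this call are filtered out
      return wrapper

    @api_boundary  # the extra Python frame is paid only on the slow path
    def cache_miss(fun, *args, **kwargs):
      return fun(*args, **kwargs)  # stand-in for trace/compile/execute

    def cache_hit(fun, *args, **kwargs):
      return fun(*args, **kwargs)  # common path: no extra frame added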
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if set, the C++ jit code calls with the function, args, kwargs, and outputs. Use the posthook feature to implement NaN checking.
Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.
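A hedged sketch of what such a posthook could look like on the Python side; the
hook signature follows the description above, and the tree/NaN helpers are
ordinary JAX calls rather than the actual debug_nans implementation:

    import jax
    import jax.numpy as jnp

    def nan_check_posthook(fun, args, kwargs, outputs):
      # Called by the C++ jit fast path after execution. Only when a NaN/Inf
      # shows up do we fall back to the Python cache-miss path, which can
      # re-run the computation with checking enabled.
      for out in jax.tree_util.tree_leaves(outputs):
        if jnp.isnan(out).any() or jnp.isinf(out).any():
          fun._cache_miss(*args, **kwargs)
          break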
PiperOrigin-RevId: 365108787
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.
This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.
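A minimal sketch of the intended use, assuming the pybind11-exposed thread-local
state object described further below (import path and field name are assumed):

    from jax._src.lib import jax_jit  # assumed import path

    # Any hashable Python object placed here becomes part of the C++ jit cache
    # key, so a change in interpreter state forces a fresh cache entry.
    jax_jit.thread_local_state().extra_jit_context = ('interpreter-state', 3)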
PiperOrigin-RevId: 364611475
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.
Expose the global and per-thread state objects to Python via pybind11.
Refactoring only; no functional changes intended.
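A minimal sketch of how the exposed state objects might be used from Python;
the import path and exact field semantics are assumptions, but the field names
match the description above:

    from jax._src.lib import jax_jit  # assumed import path

    global_state = jax_jit.global_state()
    thread_state = jax_jit.thread_local_state()

    # A thread-local value, when set, overrides the global default; this is how
    # context managers like disable_jit() can scope their effect to one thread.
    global_state.enable_x64 = False
    thread_state.disable_jit = True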
PiperOrigin-RevId: 363510449
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.
At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) avoid materializing a lazy expression in its expanded form when it is
lexically captured by a `jit` computation.
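As an illustration of point (b), a hypothetical example of the behaviour being
removed (shapes and names are made up):

    import jax
    import jax.numpy as jnp

    big = jnp.zeros((16384, 16384))   # eager: previously a lazy broadcast of 0.0
    f = jax.jit(lambda x: x + big)    # `big` is lexically captured by the jit
    # Before this change, the captured constant could be staged as a broadcast
    # rather than materialized and embedded as a ~1 GB literal.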
It is not clear that laziness has a sufficient power-to-weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
Updated version of #4536.
This removes the device constant part of #1668. We can do this because, after #3370 and #4038, omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power-to-weight ratio isn't worthwhile anymore.)
After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).
This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
This is done to simplify the code, and not at all for performance, because it's only executed during the compilation phase.
One possible design question: should we let the user access the value of the flag if it has not been set? Right now, I think the Python code allows it (meaning the behavior may not match the flag value, which has not been parsed yet).
We could raise an error by setting the flag value to absl::nullopt and checking that it's not null. But that would be a breaking change, so I am a little reluctant to do so.
PiperOrigin-RevId: 360407549
This adds a primitive with a corresponding traceable function in
`custom_derivatives` that takes a callee and its transpose, both
functions. When the primitive is encountered during transposition, the
given transpose function is invoked instead of transpose-transforming
the callee. The invocation of the custom transposition is itself done
via a `linear_call`, with the original callee set as the transpose.
This maintains, in particular, that transposing twice is an identity.
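A minimal sketch, assuming the traceable function is exposed as
`jax.custom_derivatives.linear_call(f, f_transpose, residual_args, linear_args)`:

    import jax
    from jax import custom_derivatives as cd

    def f(r, x):      # linear in x; r is a residual (non-linear) argument
      return x / r

    def f_t(r, ct):   # the user-supplied transpose of f with respect to x
      return ct / r

    def div(x, denom):
      return cd.linear_call(f, f_t, denom, x)

    # Transposition invokes f_t directly instead of transpose-transforming f,
    # and transposing twice recovers the original callee:
    ct, = jax.linear_transpose(lambda x: div(x, 2.0), 1.0)(1.0)  # ct == 0.5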
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes
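A hypothetical sketch of the new aval field, using internal APIs of that era
(the constructor signature is assumed):

    from jax import core
    import jax.numpy as jnp

    aval = core.ShapedArray((3,), jnp.float32, named_shape={'i': 8})
    # The positional shape stays (3,); the named axis 'i' of size 8 rides
    # alongside it, and mapped_aval/unmapped_aval move axes between the two.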
PiperOrigin-RevId: 355880632
This can be useful when you need backend-specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert that a given backend is the default:

    assert jax.default_backend() == 'tpu'
I am a bit conflicted by the naming: "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices, etc.), which accept a "backend" string used to
look up an XLA backend by platform name.