This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.
The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.
PiperOrigin-RevId: 414704700
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
The current error message for `jax.vmap(lambda x: 1)({})` is:
`ValueError: vmap must have at least one non-None value in in_axes`
With this PR, it becomes:
`ValueError: vmap wrapped function must be passed at least one argument
containing an array, got empty *args=({},) and **kwargs={}`
This is a more descriptive name and a better location (next to other facilities for building XLA IR).
Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.
PiperOrigin-RevId: 404270649
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.
Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
There's a bug we're struggling to repro.
To use the new checkpoint, just use
```python
from jax.ad_checkpoint import checkpoint
```
rather than `from jax import checkpoint.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
Before this change, primitives have a special case dispatch path that attempts
to avoid building a jaxpr in the cache miss case. However, there's no good
reason for this: it makes the code more complicated, and we're not particularly
optimizing for fast cache misses anyway (we care mostly about cache hits).
Make the primitive lowering path trace a small function using the xla_callable
lowering path instead.
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input
pytree elements have heterogeneous dtypes, e.g., real and complex
elements:
* Changed the dtypes of the pytree elements of the Jacobian produced by
jacfwd to be those of the input tangent basis.
* Changed the dtypes of the pytree elements of the Jacobian produced by
jacrev to be those of the output tangent basis.
* Changed the dtypes of the pytree elements of the primals and tangents
produced by jacfwd and jacrev to be the same as the corresponding
elements in the input.
Changed the behavior of the flags to `jacfwd` and `jacrev`:
* Changed the allow_int flag to only allows integer and Boolean dtypes.
Previously, this flag allowed all other types.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
``jaxpr_subcomp`` likes to lower control-flow primitives by tracing them
again as JAX callables, but they're all axis primitives now and so they
do require a properly initialized axis env.
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.