This only affects the Python dispatch path; it has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).
If your code ends up taking the Python dispatch path, something is going wrong anyway.
PiperOrigin-RevId: 596081987
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are friendlier to type checkers and IDEs.
This is a follow-up to #18008.
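A rough sketch of the before/after pattern (the old dynamic attribute lookup versus a typed state object with a `.value` property; the exact object names live in `jax._src.config` and are illustrative here):

```python
import jax
from jax._src import config

# Old style: a dynamic attribute lookup on jax.config, resolved at runtime
# and opaque to type checkers and IDEs.
if jax.config.jax_enable_x64:
    ...

# New style (sketch): a statically-typed state object whose `.value` property
# type checkers and IDEs can follow.
if config.enable_x64.value:
    ...
```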
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
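A hedged sketch of the before/after in a test, assuming the `jtu` alias for `jax._src.test_util` that JAX's tests conventionally use:

```python
import jax._src.test_util as jtu

# Before: an equality test against one backend string.
if jtu.device_under_test() == "tpu":
    ...

# After: tag matching, which can later be extended so that e.g. "gpu"
# matches both "cuda" and "rocm" devices.
if jtu.test_device_matches(["tpu"]):
    ...
```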
PiperOrigin-RevId: 568923117
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (where enabling higher-precision matrix multiplication is usually the right choice if precision is needed) and certain special functions (e.g., log, exp, pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
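For illustration, a typical shape such a relaxation might take (the test class, operation, and tolerance values are made up for the sketch):

```python
import numpy as np
import jax.numpy as jnp
import jax._src.test_util as jtu

class SpecialFunctionTest(jtu.JaxTestCase):  # hypothetical example test
  def test_log(self):
    x = np.linspace(1.0, 10.0, 100, dtype=np.float32)
    # Relax the tolerance only on TPU; keep the tighter default elsewhere.
    tol = 1e-3 if jtu.test_device_matches(["tpu"]) else 1e-6
    self.assertAllClose(jnp.log(x), np.log(x), atol=tol, rtol=tol)
```

For matmul-heavy tests, enabling higher-precision matrix multiplication (e.g. via `jax.default_matmul_precision("float32")`) is usually the better fix, per the note above.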
PiperOrigin-RevId: 562953488
Handling for trivial computations was added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no longer a need for this to exist.
In the future, if we want to support forwarding of inputs to outputs, it would need to be done through a different mechanism that the C++ dispatch path knows about.
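For context, a "trivial" computation is one like the identity below, where an input is forwarded directly to an output (a minimal sketch):

```python
import jax

# A jaxpr with no real work: the argument is forwarded to the output.
f = jax.jit(lambda x: x)
f(1.0)  # previously handled by the special-cased trivial-computation path
```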
```
jit_trivial_dispatch  246µs ± 3%  4µs ± 1%  -98.52%  (p=0.008 n=5+5)
jit_trivial           250µs ± 3%  5µs ± 1%  -98.19%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 560141018
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
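A sketch of the new-style checks, using a typed PRNG key as an example of a value carrying an extended dtype (assumes a jax version with `jax.random.key`):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new-style typed PRNG key; its dtype is extended

# Replaces jax.core.is_opaque_dtype(key.dtype) / jax.core.has_opaque_dtype(key):
jnp.issubdtype(key.dtype, jax.dtypes.extended)    # True

# Ordinary array dtypes are not extended:
jnp.issubdtype(jnp.float32, jax.dtypes.extended)  # False
```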
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.generic. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
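A minimal sketch of the kind of program that exercises this rule: vmapping an indexed update, which lowers to a scatter variant (shapes and values are illustrative):

```python
import jax
import jax.numpy as jnp

def f(v):
  # Lowers to scatter-add, whose update computation (addition) differs
  # from plain scatter's (overwrite).
  return v.at[0].add(1.0)

# Batching now re-uses the jaxpr carried by the scatter equation instead of
# re-deriving it from the scatter variant's name.
jax.vmap(f)(jnp.zeros((4, 3)))
```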
Make sure lower_jaxpr_to_fun always sees HloSharding in arg_shardings and result_shardings.
Also make sure physical_hlo_sharding only accepts HloSharding as its input.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 538342152
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.
The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.
However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
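A hedged sketch of what this enables (support for particular dtype combinations may vary by platform):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
y = jnp.ones((8, 2), dtype=jnp.bfloat16)

# preferred_element_type makes the output dtype differ from the inputs',
# which is the case the transpose rule must now handle.
out = lax.dot(x, y, preferred_element_type=jnp.float32)
out.dtype  # float32
```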
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior: np.prod([]) returns a float.
math.prod has been available since Python 3.8 and is a better solution here.
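Concretely:

```python
import math
import numpy as np

np.prod([])        # 1.0 -- float64, surprising for static shape arithmetic
math.prod([])      # 1   -- an int, the right identity for a shape product
math.prod((3, 4))  # 12
```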