--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.
One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.
PiperOrigin-RevId: 549301796
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.
I'll fix the TODOs, etc in follow up CLs.
Fixes https://github.com/google/jax/issues/10539
PiperOrigin-RevId: 547612861
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
The semantics are as follow:
* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings
* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.
This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.
PiperOrigin-RevId: 540705660
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).
TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076