An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype.
PiperOrigin-RevId: 551627521
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.
We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.
PiperOrigin-RevId: 551372130
In this version the serialized module contain a StableHLO module
boolean attribute `jax.uses_shape_polymorphism` that specifies
whether the module uses shape polymorphism. If it doesn't then
we do not need to do shape refinement.
Note that we are still keeping the default serialization version to
6, for forward compatibility. However, the serialization unit tests
now run at version 8.
Made Exported.mlir_module a method instead of a propery, to make it
more obvious that it is a derived artifact.
Previously, the serialization would use the specified serialization version
without checking if it supported by the serialzier.
This could result in invalid serializations
Also add some compatibility tests for all supported versions.
Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO.
JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type.
PiperOrigin-RevId: 550745955
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Starting with serialization version 7 we introduce shape
assertions that are checked at runtime. In the process of
rolling out version 7 we encoutered projects with failed
shape assertions and it became clear that we need better
error messages.
See the changes here in tests and README.md for example of
the updated assertions.
To produce these assertions we now pass multiple operands to
the shape assertion, and we introduce a CachedShapeEvaluator
to reduce the amount of duplicate code generated.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Previously, XlaCallModule was running the shape refinement pass for all
compilations, even if the module did not use shape polymorphism.
Currently shape refinement changes the structure of the module,
through inlining and constant folding all integer operations.
This complicates debugging because the HLO dump is very different
than the one from JAX native executions.
Starting with version 8, we run shape refinement only
if the module contains a boolean module attribute
jax.uses_shape_polymorphism=true. I think it makes sense
to put this flag as a module attribute, rather than
as a TF op attribute, because the same processing will
be needed when the module is executed from JAX.
This attribute is not yet populated by the JAX exporter.
As part of this change we moved the error check for the
number of invocation arguments from RefineDynamicShapes
to LoadAndPreprocessModule. This required adding a couple
more arguments to the loader constructor.
PiperOrigin-RevId: 549973693
The trick is to save the traceback as an XLA traceback, then turn it into a
python traceback only when throwing the error. No locals are leaked in the
process.
PiperOrigin-RevId: 549957746