The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.
Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.
Testing: revised unit test.
PiperOrigin-RevId: 564461564
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.
We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.
PiperOrigin-RevId: 562988740
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.
Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.
PiperOrigin-RevId: 562285255
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!
The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!
The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa
PiperOrigin-RevId: 561805840
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.
Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.
Testing: revised unit test.
PiperOrigin-RevId: 561803875
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.
Created a context manager to register/unregister event listeners.
PiperOrigin-RevId: 561771262
This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.
Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.
This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).
Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.
PiperOrigin-RevId: 561042402
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.
PiperOrigin-RevId: 560828993