In presence of shape polymorphism and multi-platorm lowering
we pass the global values for the dimension variables and the
platform index to all inner functions. At the moment, prior to
compilation we run a shape refinement pass to infer which of
the arguments of a function carry such global values.
This inference can yield false positives, e.g., when a
user-defined function is called with a constant int32 as the first argument.
With this change we do not need to infer anymore the arguments
that carry global constants. This is in preparation for a more
reliable implementation of shape refinement.
Previously we declared the lowering rule for call_exported to be
platform specific. This was correct, but in the case when the
caller function is lowered itself for multiple platforms this results
in multiple copies of the inner called Exported. Now instead we
make the call_exported rule be platform independent and make it
compute the platform index for the called module based on the
platform index in the caller module. This results in a single
copy of the HLO for the called module in the output.
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.
Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671
Enable experiments with jax2tf native serialization for
multiple platforms. This feature is not yet fully functional
but we need this change to enable further testing.
Cleanup some of the places that are specific to single-platform
serialization, e.g., `lowering_platform`, and generalize
them to multiple platforms (`lowering_platforms`).
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
Previously, when we call_exported of an Exported module
with shardings, we invoke the right HLO but the enclosing
JAX computation does not know about the shardings of the
called module. This results in errors when invoking the
calling module.
We change call_exported lowering rules to add sharding
constraints for the inputs and the outputs and we add
a check that we call the exported module on the same
number of devices as at export time.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 572587137
When take_ownership is true, the original buffer is marked as deleted and enforced that JAX won't attempt to read or write the buffer. This provides better error checking but at the cost of one more C++ API and two more C APIs. The same semantic can be achieved by not using take_ownership and being careful. Therefore we decided to remove take_ownership support in DLPack.
PiperOrigin-RevId: 572278488