Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.
Testing: updated tests.
PiperOrigin-RevId: 549439628
We can make this general enough in JAX slowly and carefully and would likely require a refactor of how device_assignment is chosen.
Fixes: https://github.com/google/jax/issues/15903
PiperOrigin-RevId: 530638856
To handle Tracers, ShapedArray, concrete Arrays, etc `global_array_to_host_local_array` and `host_local_array_to_global_array` are now primitives.
PiperOrigin-RevId: 528925663
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.
PiperOrigin-RevId: 527398394
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.
No functional change intended.
PiperOrigin-RevId: 524855263
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
Following are the changes:
* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
* _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd
* _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit
* _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal.
The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.
For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input.
PiperOrigin-RevId: 522991145
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed
PiperOrigin-RevId: 517469390
Fixesiree-org/iree-jax#57
An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.
I had to run the test with:
`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.
[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.
Fix type of buffer_from_pyval.
PiperOrigin-RevId: 515687258