Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.
PiperOrigin-RevId: 459566727
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.
* `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances.
* This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
* cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.
* Checking of sharding with `aval` has a handler system to deal with sharding instances.
* The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding.
* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.
* _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL.
* Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too.
* Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface.
* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
* `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach.
* MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.
PiperOrigin-RevId: 459548974
CPU / VMVX runtime is now called local-task. Updated to
separate compiler, runtime, and backend naming for single
specified configuration.
PiperOrigin-RevId: 459298179
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
Introduce ShapePolyVmapPrimitivesTest to contain all the tests
that vmap results in batch polymprphic code.
Also fix some warnings about eig, eigh, and qr taking only kwarg
arguments.
Otherwise it can lead to tracer leak errors. I'm not a 100% sure how
this works out, because the sublevel counting has changed since I read
it previously. This replicates the changes applied to
DynamicJaxprTrace.process_map since I last looked at it.
Re-organizing things this way in order to:
* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.
* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.
For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.
Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.
PiperOrigin-RevId: 458574166
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.
To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```
Issue #7323
PiperOrigin-RevId: 458551208