464 Commits

Author SHA1 Message Date
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Yash Katariya
b6e86c413a Remove dead code now that xmap is deleted
PiperOrigin-RevId: 655664512
2024-07-24 12:40:20 -07:00
jax authors
ac4ca35221 Merge pull request #22263 from hawkinsp:tuples
PiperOrigin-RevId: 653267867
2024-07-17 09:56:18 -07:00
Yash Katariya
bb7a6995f9 Remove the spmd_mode check. It's disabled in OSS since a long time.
PiperOrigin-RevId: 652591122
2024-07-15 13:58:23 -07:00
Yash Katariya
1e1bca0706 Check for layout mismatch between array's layout and layout specified via in_shardings to jit by only checking major_to_minor if _tiling is None. Otherwise, check the entire layout.
PiperOrigin-RevId: 651796471
2024-07-12 09:23:37 -07:00
Yash Katariya
ff3dc0f5fb Add check_compatible_aval checks to Layout. It checks if len(major_to_minor) == len(aval.shape).
PiperOrigin-RevId: 651777179
2024-07-12 08:10:43 -07:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Peter Hawkins
3d5784a343 Don't wrap singleton ir.Types during HLO lowering.
This is similar to https://github.com/google/jax/pull/22211, but for MLIR types instead of MLIR values.
2024-07-08 12:24:45 -04:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
jax authors
061ccd4e73 Merge pull request #22269 from superbobry:main
PiperOrigin-RevId: 649395181
2024-07-04 06:31:08 -07:00
Sergei Lebedev
ffa39c0858 Handle missing `debug_info in explain_tracing_cache_miss` 2024-07-04 14:07:10 +01:00
George Necula
a4a9499a40 [pallas] Improve some error messages and add API tests.
We make the following improvements:

  * pytree structural disequality messages now attempt to localize the
    mismatch using tree_util.KeyPath.
  * we generate a simpler error message for when `in_specs` is not
    a sequence, instead of the current PyTreeDef mismatch error.
  * we generate an error message for when the index map function
    in a BlockSpec returns an unexpected number of results.
  * added error localization to the existing shape polymorphism
    check that the block shapes are static.
  * We check that the kernel function returns None. Without this
    we used to get `body_fun output and input must have same type structure`
    in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
    and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
    on TPU.
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
2024-07-04 09:02:16 +02:00
jax authors
dffd72e290 Merge pull request #22211 from hawkinsp:singletons
PiperOrigin-RevId: 649135349
2024-07-03 11:07:00 -07:00
jax authors
e6ebd55532 Merge pull request #22237 from hawkinsp:qualname
PiperOrigin-RevId: 649090742
2024-07-03 08:55:36 -07:00
Yash Katariya
884487773e Read the layout set by with_sharding_constraint and set the top module level out_layout to AUTO if wsc layout is not None.
This will allow XLA to override the entry_computation_layout with the layout set via custom call (i.e. via wsc).

PiperOrigin-RevId: 648911765
2024-07-02 19:13:27 -07:00
Peter Hawkins
f5290ddff7 Prefer __qualname__ as a pjit_p name.
If applying `jit` to a class method, it is often important to know the class name in the jaxpr.
2024-07-02 14:53:50 -04:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Peter Hawkins
637a7cbcc1 pjit.py cleanups.
Refactoring only, NFC intended.

* add types to more places.
* don't unpack PjitInfo positionally, since it's a 23-tuple and that seems rather error prone.
* change _infer_params to produce a new PjitParams NamedTuple, rather than having callers unpack a 9-tuple positionally.
* inline _pjit_jaxpr into its caller, since it only has one caller and the wrapper doesn't really clarify anything.
* note the return type of transformation_with_aux is a Callable.

PiperOrigin-RevId: 645068326
2024-06-20 09:58:22 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
jax authors
d32404020b Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Yash Katariya
55d0f5ef8f Add lower to specialize making it a true Stage.
So now users can do:

```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)

lowered = specialized.lower()

compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396
2024-06-05 19:54:41 -07:00
Yash Katariya
228adb4a4a Add specialize on jax.jit and make it a Stage.
Eventually, we should use this in jax.make_jaxpr and delete all the duplicated code.

PiperOrigin-RevId: 640707223
2024-06-05 17:46:14 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Yash Katariya
9e3f290de3 Delete XLACompatibleSharding and replace with jax.sharding.Sharding.
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.

Why do this?

* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!

* Having these 2 types makes things very confusing. One example is:
  * `jax.jit` only accepts XLACompatibleShardings.
  * `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.

PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
George Necula
be1e40dc2e Copybara import of the project:
--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
2024-05-31 20:40:42 -07:00
Yash Katariya
54eaef9f62 Make sure that the sharding and unconstrained_dims in with_sharding_constraint are correct when wsc is vmapped.
In other words, if unconstrained_dims is specified, then the sharding should also contain P.UNCONSTRAINED under vmap.

PiperOrigin-RevId: 638843222
2024-05-30 17:44:51 -07:00
Yash Katariya
bfaf0b74e8 Improve the error message when users pass DeviceLocalLayout.AUTO to jax.jit and a jax.Array as an argument.
PiperOrigin-RevId: 638797194
2024-05-30 15:07:01 -07:00
Matthew Johnson
3984d822ba add error checks for vmap spmd_axis_name 2024-05-30 20:48:11 +00:00
Yash Katariya
73a8e7f0a8 Rename is_compatible_aval to check_compatible_aval since it returns None and not a bool.
PiperOrigin-RevId: 638431968
2024-05-29 15:29:32 -07:00
jax authors
fafe740400 Merge pull request #21478 from google:pytree-attrs
PiperOrigin-RevId: 638359249
2024-05-29 11:41:42 -07:00
Matthew Johnson
a4622b6a29 fix weak key cache stuff
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-05-29 17:53:56 +00:00
jax authors
26f9820417 [JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall the idea is to collect profile data for each module given amount of times (which can be configured) then recompile the module with the aggregated profile data.

1. We need to track how many times each module were profiled and collect profiling results. For this i added a ProfileSessionRunner class at profile.py. The class can track how many times an instance of it was called to profile a session and also can aggregate profile results.

2. We need associate profiling session to the module at the interpreter. To do this i added a dictionary to pjit.py which associates Jaxpr with profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly deal with fast path at the interpreter level, so JAX won't use HLO directly if PGLE need to be collected, but also JAX will not recompiled the module only for PGLE. See changes in pjit.py and in lru_cache.h

5. Once FDO is collected we need to share it between hosts to keep deterministic compilation.

PiperOrigin-RevId: 638197166
2024-05-29 01:50:03 -07:00
Dougal
122924fdf3 Make attrs work with pytrees
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-28 23:23:51 -04:00
George Necula
bb4c073574 Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation of a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
with a `Lowered` object and will not include anymore its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializating. The Jaxpr is useful for exporting also the VJP of the function
and obtaining an `Exported` that can be differentiated.
2024-05-29 05:04:17 +03:00
Yash Katariya
4fae9aa160 Support eager unified memory computations
PiperOrigin-RevId: 638073121
2024-05-28 16:59:26 -07:00
Yash Katariya
ff3db9b3a1 Update the deprecation message of backend and device argument of jit to be more actionable.
PiperOrigin-RevId: 637899890
2024-05-28 08:00:23 -07:00
Matthew Johnson
0a693faf48 add pjit forwarding rule
Co-authored-by: Roy Frostig <frostig@google.com>
2024-05-25 17:46:01 +00:00
jax authors
42fc69b26e Internal cleanup
PiperOrigin-RevId: 636518124
2024-05-23 05:35:53 -07:00
Yash Katariya
711190155d Initialize JaxprEqnContext only in new_jaxpr_eqn and new_eqn_recipe with the current active compute type if no ctx is specified.
PiperOrigin-RevId: 636309959
2024-05-22 15:16:58 -07:00