325 Commits

Author SHA1 Message Date
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Yash Katariya
82c9da020a Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.
This allows us to get more cache hits globally. For example:

Before:

```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
```

After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit
```

Also, we can remove the hack (which I didn't like) in multihost_utils.py.

PiperOrigin-RevId: 665574475
2024-08-20 16:18:58 -07:00
Yash Katariya
1ab6279d4f Skip the global jit cpp cache if in/out_layouts are not None
PiperOrigin-RevId: 665085182
2024-08-19 18:43:23 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint is via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.

**Why do this?**

There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.

So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature.

**Okay, so how do you fix this?**

As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)

The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.

**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards change. So your current code will continue to work as is but you can opt-into this new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
Yash Katariya
958234a9c1 Thread the mesh context manager to the place where we recover out_shardings back from GSPMDShardings. Before if you had a program like this:
```
with mesh:
  out = pjit(lambda: 1)()
```

The sharding of `out` was a `GSPMDSharding` which is not ideal. This change fixes that and returns a `NamedSharding` instead.

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
2024-08-02 11:04:48 -07:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Yash Katariya
2eb1888c98 Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
Ram Rachum
0d92d31063 Show elapsed time in nanoseconds 2024-07-25 22:20:25 +03:00
Bart Chrzaszcz
b00f978f70 #sdy Support with_sharding_constraint lowering through Shardy.
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
Yash Katariya
51e27923e8 Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Yash Katariya
b6e86c413a Remove dead code now that xmap is deleted
PiperOrigin-RevId: 655664512
2024-07-24 12:40:20 -07:00
jax authors
ac4ca35221 Merge pull request #22263 from hawkinsp:tuples
PiperOrigin-RevId: 653267867
2024-07-17 09:56:18 -07:00
Yash Katariya
bb7a6995f9 Remove the spmd_mode check. It's disabled in OSS since a long time.
PiperOrigin-RevId: 652591122
2024-07-15 13:58:23 -07:00
Yash Katariya
1e1bca0706 Check for layout mismatch between array's layout and layout specified via in_shardings to jit by only checking major_to_minor if _tiling is None. Otherwise, check the entire layout.
PiperOrigin-RevId: 651796471
2024-07-12 09:23:37 -07:00
Yash Katariya
ff3dc0f5fb Add check_compatible_aval checks to Layout. It checks if len(major_to_minor) == len(aval.shape).
PiperOrigin-RevId: 651777179
2024-07-12 08:10:43 -07:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Peter Hawkins
3d5784a343 Don't wrap singleton ir.Types during HLO lowering.
This is similar to https://github.com/google/jax/pull/22211, but for MLIR types instead of MLIR values.
2024-07-08 12:24:45 -04:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
jax authors
061ccd4e73 Merge pull request #22269 from superbobry:main
PiperOrigin-RevId: 649395181
2024-07-04 06:31:08 -07:00
Sergei Lebedev
ffa39c0858 Handle missing `debug_info in explain_tracing_cache_miss` 2024-07-04 14:07:10 +01:00
George Necula
a4a9499a40 [pallas] Improve some error messages and add API tests.
We make the following improvements:

  * pytree structural disequality messages now attempt to localize the
    mismatch using tree_util.KeyPath.
  * we generate a simpler error message for when `in_specs` is not
    a sequence, instead of the current PyTreeDef mismatch error.
  * we generate an error message for when the index map function
    in a BlockSpec returns an unexpected number of results.
  * added error localization to the existing shape polymorphism
    check that the block shapes are static.
  * We check that the kernel function returns None. Without this
    we used to get `body_fun output and input must have same type structure`
    in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
    and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
    on TPU.
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
2024-07-04 09:02:16 +02:00
jax authors
dffd72e290 Merge pull request #22211 from hawkinsp:singletons
PiperOrigin-RevId: 649135349
2024-07-03 11:07:00 -07:00
jax authors
e6ebd55532 Merge pull request #22237 from hawkinsp:qualname
PiperOrigin-RevId: 649090742
2024-07-03 08:55:36 -07:00
Yash Katariya
884487773e Read the layout set by with_sharding_constraint and set the top module level out_layout to AUTO if wsc layout is not None.
This will allow XLA to override the entry_computation_layout with the layout set via custom call (i.e. via wsc).

PiperOrigin-RevId: 648911765
2024-07-02 19:13:27 -07:00
Peter Hawkins
f5290ddff7 Prefer __qualname__ as a pjit_p name.
If applying `jit` to a class method, it is often important to know the class name in the jaxpr.
2024-07-02 14:53:50 -04:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Peter Hawkins
637a7cbcc1 pjit.py cleanups.
Refactoring only, NFC intended.

* add types to more places.
* don't unpack PjitInfo positionally, since it's a 23-tuple and that seems rather error prone.
* change _infer_params to produce a new PjitParams NamedTuple, rather than having callers unpack a 9-tuple positionally.
* inline _pjit_jaxpr into its caller, since it only has one caller and the wrapper doesn't really clarify anything.
* note the return type of transformation_with_aux is a Callable.

PiperOrigin-RevId: 645068326
2024-06-20 09:58:22 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
jax authors
d32404020b Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Yash Katariya
55d0f5ef8f Add lower to specialize making it a true Stage.
So now users can do:

```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)

lowered = specialized.lower()

compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396
2024-06-05 19:54:41 -07:00
Yash Katariya
228adb4a4a Add specialize on jax.jit and make it a Stage.
Eventually, we should use this in jax.make_jaxpr and delete all the duplicated code.

PiperOrigin-RevId: 640707223
2024-06-05 17:46:14 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Yash Katariya
9e3f290de3 Delete XLACompatibleSharding and replace with jax.sharding.Sharding.
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.

Why do this?

* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!

* Having these 2 types makes things very confusing. One example is:
  * `jax.jit` only accepts XLACompatibleShardings.
  * `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.

PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
George Necula
be1e40dc2e Copybara import of the project:
--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
2024-05-31 20:40:42 -07:00
Yash Katariya
54eaef9f62 Make sure that the sharding and unconstrained_dims in with_sharding_constraint are correct when wsc is vmapped.
In other words, if unconstrained_dims is specified, then the sharding should also contain P.UNCONSTRAINED under vmap.

PiperOrigin-RevId: 638843222
2024-05-30 17:44:51 -07:00
Yash Katariya
bfaf0b74e8 Improve the error message when users pass DeviceLocalLayout.AUTO to jax.jit and a jax.Array as an argument.
PiperOrigin-RevId: 638797194
2024-05-30 15:07:01 -07:00
Matthew Johnson
3984d822ba add error checks for vmap spmd_axis_name 2024-05-30 20:48:11 +00:00