200 Commits

Author SHA1 Message Date
Yash Katariya
90892f533a Check jax.Sharding's number of devices, rather than py_array.num_shards (which reads the IFRT sharding's num_devices), against global_devices when deciding whether to fall back to the Python shard_arg.
This is because the IFRT sharding's `num_shards` method is broken: in some cases it does not return the global number of shards, which causes the JAX program to fall back to Python unnecessarily.
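A rough sketch of the decision being described, phrased against the public Sharding API (the helper and its callers are hypothetical, not the actual dispatch code):
```
import jax

def should_fall_back_to_python_shard_arg(arr: jax.Array, global_devices) -> bool:
  # Hypothetical helper: compare the jax.Sharding's device count against the
  # global device list, instead of trusting num_shards from the IFRT sharding.
  return arr.sharding.num_devices != len(global_devices)
```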

PiperOrigin-RevId: 673067095
2024-09-10 12:43:52 -07:00
Peter Hawkins
1b2ba9d1c2 Disable two lax_scipy_test testcases that fail on TPU v6e.
PiperOrigin-RevId: 672973757
2024-09-10 08:26:27 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
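For illustration, a minimal sketch of the new public API (assuming at least four devices are available; keyword handling may differ slightly across versions):
```
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# jax.make_mesh picks a device assignment for the requested axis sizes/names.
mesh = jax.make_mesh((2, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
```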
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint accepts it via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we never need to talk about concrete devices, so the semantics stay the same as today, i.e. we can lower a NamedSharding with an abstract mesh using only the mesh axis names and sizes plus a PartitionSpec. The only restriction is that the number of devices needs to be consistent throughout the program while we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit, i.e. in eager mode, if a `shard_map` or `with_sharding_constraint` contains an AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.

**Why do this?**

There are cases where you want to change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation, because they embed concrete devices in their signature.

So to fix the error, you need to change the mesh in `wsc` and `shmap`, which leads to a tracing cache miss (because the function id is now different) and consequently a lowering-to-StableHLO cache miss. For example:

```
import numpy as np
import jax
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map`, since it takes a mesh with concrete devices in its signature.

**Okay, so how do you fix this?**

As mentioned above, we need the above program to work and to get tracing and lowering cache hits (**cache hits are the most important part here**).

The approach in this change allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream, and we get tracing and lowering cache hits since we no longer encode the concrete devices, just the axis names and axis sizes of the mesh.

**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Create an abstract mesh from mesh1; since both meshes have the same shape (names
# and axis sizes), this is ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.NamedSharding`, but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
from jax.experimental.shard_map import shard_map

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Create an abstract mesh from mesh1; since both meshes have the same shape (names
# and axis sizes), this is ok.
abstract_mesh = jax.sharding.AbstractMesh(mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(x)
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards-compatible change, so your current code will continue to work as is, but you can opt into this new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
jax authors
c2c04e054e Merge pull request #22608 from kaixih:fix_cuda_version_check
PiperOrigin-RevId: 659664879
2024-08-05 13:34:43 -07:00
kaixih
09b88430e9 Fix CUDA version checks 2024-08-05 20:09:17 +00:00
Pearu Peterson
780b10b4c4 Update complex functions accuracy tests 2024-08-02 23:31:51 +03:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
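A minimal sketch of such a helper, built on standard hypothesis profiles (the helper name, profile names, and environment variable below are illustrative, not the exact jtu API):
```
import os
from hypothesis import HealthCheck, settings

def setup_hypothesis(max_examples: int = 30) -> None:
  # "deterministic" is meant for CI; "interactive" trims the example count
  # for fast local iteration while debugging a failing property.
  settings.register_profile(
      "deterministic",
      database=None, derandomize=True, deadline=None,
      max_examples=max_examples, print_blob=True)
  settings.register_profile(
      "interactive",
      parent=settings.get_profile("deterministic"),
      max_examples=1,
      suppress_health_check=[HealthCheck.too_slow])
  settings.load_profile(os.getenv("JAX_HYPOTHESIS_PROFILE", "deterministic"))
```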
2024-07-15 17:05:32 +02:00
Sergei Lebedev
0dff794f68 Added test assertions for `pl.debug_print` on TPU
PiperOrigin-RevId: 650651472
2024-07-09 09:15:16 -07:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
Sergei Lebedev
1ff07c2fbe Added some documentation to `jtu.capture_stdout`
PiperOrigin-RevId: 648830525
2024-07-02 13:46:19 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
96cf5d53c8 Merge pull request #21916 from ROCm:ci_pjrt
PiperOrigin-RevId: 646793145
2024-06-26 02:43:21 -07:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```
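For illustration only, the general shape of the caching idea, sketched as generic memoization keyed on the abstract argument signature (this is not the actual pjit code):
```
import functools
import jax

# Illustrative only: cache per-signature work, keyed on the function identity
# and the abstract types (shape/dtype) of the dynamic arguments.
@functools.lru_cache(maxsize=4096)
def _cached_infer(fun_id, arg_signature):
  # In the real change, most of _infer_params' work happens once per key.
  return tuple(jax.ShapeDtypeStruct(shape, dtype) for shape, dtype in arg_signature)

def infer_params(fun, args):
  signature = tuple((a.shape, a.dtype) for a in args)
  return _cached_infer(id(fun), signature)
```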

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
jax authors
c8cdf303fb Merge pull request #20761 from superbobry:config
PiperOrigin-RevId: 644435642
2024-06-18 10:33:44 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Sergei Lebedev
ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
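A toy sketch of the flag/state distinction (not the actual jax config internals): a flag is set at most once at startup, while a state carries a process-wide default that can be overridden per thread via a context manager.
```
import contextlib
import threading

class State:
  """Toy config state: a process-wide default plus per-thread overrides."""

  def __init__(self, default):
    self._default = default
    self._local = threading.local()

  @property
  def value(self):
    return getattr(self._local, 'value', self._default)

  @contextlib.contextmanager
  def __call__(self, new_value):
    # Override for the current thread only, restoring the old value on exit.
    old = self.value
    self._local.value = new_value
    try:
      yield
    finally:
      self._local.value = old

debug_logging = State(default=False)  # a flag, by contrast, would be set once
with debug_logging(True):
  assert debug_logging.value
```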
2024-06-18 11:48:57 +01:00
jax authors
f819c344ee Merge pull request #21645 from andportnoy:aportnoy/cuda-custom-call
PiperOrigin-RevId: 644124155
2024-06-17 13:56:52 -07:00
Andrey Portnoy
ec5c4f5a10 Add CUDA custom call example as a JAX test 2024-06-17 15:21:49 -04:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.
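A simplified Python-level sketch of the grouping step (the real grouping is done in C++; the helper below is illustrative):
```
from collections import defaultdict

def group_for_batched_copy(arrays, dst_shardings):
  """Group (array, destination sharding) pairs that can share one bulk copy."""
  groups = defaultdict(list)
  for arr, dst in zip(arrays, dst_shardings):
    key = (
        tuple(sorted(d.id for d in arr.sharding.device_set)),  # source devices
        tuple(sorted(d.id for d in dst.device_set)),            # destination devices
        arr.sharding.memory_kind,                               # source memory kind
        dst.memory_kind,                                        # destination memory kind
    )
    groups[key].append((arr, dst))
  return groups
```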

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
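For reference, the standard contextlib pattern this refers to: enter any number of contexts in setUp and unwind them all with a single cleanup (the jax context manager in the comment is just one example of what might be entered).
```
import contextlib
import tempfile
import unittest

class ExampleTest(unittest.TestCase):
  def setUp(self):
    super().setUp()
    stack = contextlib.ExitStack()
    self.addCleanup(stack.close)  # one cleanup call unwinds every context below
    self.tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
    # Additional context managers can be stacked the same way, e.g.
    # stack.enter_context(jax.default_matmul_precision('float32')).
```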
Mark Sandler
2c246df439 Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Christos Perivolaropoulos
18e55d567f [test_utils] Fix the encoding of capture_stdout so it works on windows.
PiperOrigin-RevId: 640910749
2024-06-06 08:43:25 -07:00
Peter Hawkins
dfe6128509 Reverts da816d34eaad6a1c6536959ccb4bfee4466c037d
PiperOrigin-RevId: 640886105
2024-06-06 07:10:09 -07:00
Mark Sandler
da816d34ea Makes global_shape optional for jax.make_array_from_process_local_data.
PiperOrigin-RevId: 640695090
2024-06-05 16:58:08 -07:00
Jake VanderPlas
f04a2279a5 shape_poly_test: adjust configs via jtu.global_config_context 2024-06-05 10:45:56 -07:00
Jake VanderPlas
2333d5c7c3 Add validation that tests do not change global configs 2024-06-04 09:58:50 -07:00
Christos Perivolaropoulos
9939cc9974 test_util.capture_stdout redirects using file descriptors rather than mocking the python interface.
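The file-descriptor redirection technique looks roughly like this (a simplified sketch, not the actual jtu.capture_stdout implementation):
```
import contextlib
import os
import sys
import tempfile

@contextlib.contextmanager
def capture_stdout_fd():
  """Capture stdout at the OS fd level, so output from C/C++ code is seen too."""
  captured = []
  fd = sys.stdout.fileno()
  saved = os.dup(fd)                    # keep a copy of the original stdout fd
  with tempfile.TemporaryFile(mode='w+') as tmp:
    sys.stdout.flush()
    os.dup2(tmp.fileno(), fd)           # point fd 1 at the temp file
    try:
      yield lambda: captured[0]         # call this after the with-block exits
    finally:
      sys.stdout.flush()
      os.dup2(saved, fd)                # restore the original stdout
      os.close(saved)
      tmp.seek(0)
      captured.append(tmp.read())
```
Used as `with capture_stdout_fd() as get_output: ...`, reading `get_output()` once the block has exited; because the redirection happens at the fd level rather than by patching sys.stdout, it also captures output written by native code.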
PiperOrigin-RevId: 640183718
2024-06-04 09:41:47 -07:00
Jake VanderPlas
570325d6f5 JaxTestCase: set default configs during class setup 2024-05-29 10:24:55 -07:00
jax authors
26f9820417 [JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall the idea is to collect profile data for each module a given number of times (which can be configured) and then recompile the module with the aggregated profile data.

1. We need to track how many times each module was profiled and collect the profiling results. For this I added a ProfileSessionRunner class in profile.py. The class can track how many times an instance of it was called to profile a session and can also aggregate the profile results (a toy sketch of this bookkeeping appears after this list).

2. We need to associate the profiling session with the module at the interpreter level. To do this I added a dictionary to pjit.py which associates a Jaxpr with a profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly handle the fast path at the interpreter level, so that JAX won't use the HLO directly while PGLE data still needs to be collected, but also won't recompile the module solely for PGLE. See changes in pjit.py and in lru_cache.h.

5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.
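A toy sketch of the ProfileSessionRunner bookkeeping described in item (1) above (class and method names are illustrative, not the actual profile.py API):
```
class ProfileSessionRunner:
  """Toy version: profile a module N times, then hand back aggregated data."""

  def __init__(self, times_to_profile: int):
    self.times_to_profile = times_to_profile
    self.profiles = []

  def should_profile(self) -> bool:
    return len(self.profiles) < self.times_to_profile

  def record(self, profile_data: bytes) -> None:
    self.profiles.append(profile_data)

  def ready_to_recompile(self) -> bool:
    return len(self.profiles) >= self.times_to_profile

  def aggregated_profile(self) -> bytes:
    # Placeholder aggregation; the real code merges the collected FDO profiles.
    return b''.join(self.profiles)
```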

PiperOrigin-RevId: 638197166
2024-05-29 01:50:03 -07:00
Sergei Lebedev
2473ebf508 Removed mentions of iree from the test suite 2024-05-24 10:31:57 +01:00
Mark Sandler
8f045cafd2 Add jax.make_array_from_process_local_data to create a distributed tensor from host data, plus supporting scaffolding in sharding to figure out the dimensions of the host data that are required.
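A hedged usage sketch (assumes a multi-process setup where each process holds one slice of the batch; argument order follows the public API but may vary by version):
```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all devices; each process contributes a slice of 'batch'.
mesh = Mesh(np.asarray(jax.devices()), ('batch',))
sharding = NamedSharding(mesh, P('batch'))

local_batch = np.ones((8, 128), dtype=np.float32)     # this process's data
global_shape = (8 * jax.process_count(), 128)         # full logical shape

global_arr = jax.make_array_from_process_local_data(
    sharding, local_batch, global_shape)
```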
PiperOrigin-RevId: 634205261
2024-05-15 22:06:45 -07:00
jax authors
962f084543 Merge pull request #21137 from superbobry:pallas
PiperOrigin-RevId: 631923082
2024-05-08 14:20:10 -07:00
jax authors
65d4c688e0 Generic reduce window jvp
The problem is that we want to generically jvp and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely ensuring that the user-provided fn is JAX-capturable. All that is left, then, is to write jvp and transpose fns that utilize the JAX utils correctly.

However, this is not so straightforward, because in order to get the transpose of a reduction window we need to be able to use both the tangents and the primals. The current reduce_fn operates on (x, y) - but what we actually need, under jvp, is to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down the notion of a jvp-specific reduction_fn (captured via the usual machinery of as_fun: `as_fun(jvp_fn(closed(user_reduction_jaxpr)))`).

For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together. This means a good deal of changes to safety checks and assurances downstream (as well as unpacking), since the shape of the operand has changed from [K, ...Kt] to [K, ...Kt, 2], where the last dim holds the stacked primal and tangent values.
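A hedged sketch of the stacking idea using public JAX APIs (an illustrative helper, not the actual lax internals): build a jvp reduction that consumes values whose last axis stacks primal and tangent.
```
import jax
import jax.numpy as jnp

def make_jvp_reduce_fn(reduce_fn):
  """Build a reduction over values whose last axis stacks (primal, tangent)."""
  def jvp_reduce(x_stacked, y_stacked):
    x_p, x_t = x_stacked[..., 0], x_stacked[..., 1]
    y_p, y_t = y_stacked[..., 0], y_stacked[..., 1]
    out_p, out_t = jax.jvp(reduce_fn, (x_p, y_p), (x_t, y_t))
    return jnp.stack([out_p, out_t], axis=-1)
  return jvp_reduce

# e.g. a max reduction whose tangent is carried along in the same pass:
jvp_max = make_jvp_reduce_fn(jnp.maximum)
```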

In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition

PiperOrigin-RevId: 631916764
2024-05-08 14:00:39 -07:00
Sergei Lebedev
4b62425b42 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least
This makes it clear that the predicate is only supposed to be used for NVIDIA
GPUs at the moment. 2024-05-08 21:41:50 +01:00
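Roughly what such a predicate can look like, using the compute_capability attribute exposed on CUDA devices (a sketch, not the exact jtu implementation):
```
import jax

def is_cuda_compute_capability_at_least(capability: str) -> bool:
  """e.g. is_cuda_compute_capability_at_least('8.0') for Ampere or newer."""
  device = jax.devices()[0]
  if device.platform != 'gpu':
    return False
  current = tuple(int(p) for p in device.compute_capability.split('.'))
  return current >= tuple(int(p) for p in capability.split('.'))
```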
2024-05-08 21:41:50 +01:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Pearu Peterson
ee5c134e66 Workaround mpmath 1.3 issues in asinh evaluation at infinities 2024-04-24 23:52:01 +03:00
Pearu Peterson
e8ff7028f4 Workaround mpmath 1.3 issues in asin and asinh evaluation at infinities and on branch cuts. 2024-04-23 21:01:43 +03:00
Sergei Lebedev
a13efc2815 Added int4 and uint4 to dtype-specific tests
I probably missed some cases, so this PR is really just the first step in
making sure we have good *int4 coverage.
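For context, the *int4 dtypes are exposed through jax.numpy (via ml_dtypes), assuming a jaxlib recent enough to support them:
```
import jax.numpy as jnp

x = jnp.array([0, 1, 2, 3], dtype=jnp.int4)
y = jnp.array([0, 1, 2, 3], dtype=jnp.uint4)
print(x.dtype, y.dtype)  # int4 uint4
```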
2024-04-18 15:20:20 +01:00
Pearu Peterson
fc04ba983c Workaround mpmath 1.3 bugs in tan and tanh evaluation at infinities 2024-04-10 18:26:07 +03:00
Pearu Peterson
2ef5bc6075 Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities. 2024-04-04 11:19:57 +03:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
Pearu Peterson
fdb5015909 Evaluate the correctness of JAX complex functions using mpmath as a reference 2024-03-21 23:35:29 +02:00
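A minimal sketch of the reference-comparison idea, using mpmath directly (not the harness added in this change; sample points and tolerances are illustrative):
```
import numpy as np
import mpmath
import jax.numpy as jnp

mpmath.mp.prec = 113  # high-precision reference values

def reference_asin(z: complex) -> complex:
  return complex(mpmath.mp.asin(mpmath.mpc(z.real, z.imag)))

samples = np.array([0.5 + 0.5j, -2.0 + 0.1j, 1.0 + 0.0j], dtype=np.complex64)
expected = np.array([reference_asin(complex(z)) for z in samples])
actual = np.asarray(jnp.arcsin(samples))
np.testing.assert_allclose(actual, expected, rtol=1e-4)
```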