439 Commits

Author SHA1 Message Date
Yash Katariya
837f0bbf6f Cache the _check_sharding check in device_put. If aval and sharding are the same, no need to check multiple times
PiperOrigin-RevId: 626244240
2024-04-18 21:26:35 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
jax authors
bb8cf34a31 Document the fact that jax.clear_caches() doesn't affect the persistent cache.
PiperOrigin-RevId: 626019057
2024-04-18 06:52:40 -07:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
Matthew Johnson
83a200a42f simple fix to make_jaxpr docstring
maybe it was accidentally copied from xla_computation before?
2024-04-11 21:50:51 -07:00
Sai-Suraj-27
5564521308 Prefer raising of TypeError for invalid types instead of ValueError. 2024-04-08 13:08:24 +05:30
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Peter Hawkins
d3e03fff5d Refactorings to the jit implementation.
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
2024-03-21 05:37:32 -07:00
Junwhan Ahn
f569031456 Reverts 55394a0914dc0583427a4ceb73dac56348911d15
PiperOrigin-RevId: 616201321
2024-03-15 11:56:45 -07:00
Peter Hawkins
642f20de1c [JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.
Bytecode is considerably more compact.

PiperOrigin-RevId: 615386276
2024-03-13 06:02:18 -07:00
Roy Frostig
98f790f5d5 update package/API reference docs to new-style typed PRNG keys 2024-03-07 12:40:09 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Matthew Johnson
b0b88d87d3 [attrs] add linearize and vjp support 2024-02-23 16:43:49 -08:00
Peter Hawkins
b5e4ba4900 Don't call inspect.signature() each time we trace a jit().
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
2024-02-15 13:49:27 -08:00
Jake VanderPlas
82611eb8ae document that under disable_jit, individual primitives are still compiled 2024-02-05 12:01:33 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Yash Katariya
f04f305489 Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Matthew Johnson
0d4f200e08 Allow unhashable callables in jax.eval_shape.
PiperOrigin-RevId: 599691923
2024-01-18 19:16:48 -08:00
Yash Katariya
51ef738c86 Use jit's jaxpr creation function for eval_shape to maximize tracing cache hits.
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.

The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`

PiperOrigin-RevId: 599602407
2024-01-18 13:11:44 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Tom Cobley
ebc7af95df Fix typo in pmap docstring
Docstring states:
>  If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.

However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`.

PiperOrigin-RevId: 595731210
2024-01-04 09:50:00 -08:00
Axel Donath
e8330b5fc5 Add eval_shape example for function with static arguments
Improve wording and formating of dynamic eval_shape example
2023-12-18 19:19:48 -05:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Jake VanderPlas
35b84402c0 Deprecate arr.device_buffer and arr.device_buffers 2023-12-06 10:20:29 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
jax authors
57e19db104 Merge pull request #18736 from mattjj:device-put-fixes
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c fix grad device_put src inference, and a small device_put bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
439b89e47f Remove DefaultLayout and make None same as DefaultLayout
PiperOrigin-RevId: 583221970
2023-11-16 18:01:27 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Jake VanderPlas
0bcd64ade3 jax.vmap: improve docs & error for structured in_axes 2023-11-15 11:56:53 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
Etienne Pot
7cf66dfe4b Fix typing annotations for jax.named_call
PiperOrigin-RevId: 582235119
2023-11-14 01:47:42 -08:00
Junwhan Ahn
55394a0914 Roll back the optimized version of jax.block_until_ready due to test breakage
Reverts 6cc6d093643c0265c7de4027f79879f6945e0342

PiperOrigin-RevId: 581577789
2023-11-11 12:15:45 -08:00
Junwhan Ahn
6cc6d09364 Implement more efficient jax.block_until_ready(x) in C++
The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes GIL to be acquired/released repeatedly.

To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.

PiperOrigin-RevId: 581302290
2023-11-10 10:34:34 -08:00
jax authors
62741d9744 Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Etienne Pot
81ac67f381 Fix typing annotations for @jax.named_call
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Jake VanderPlas
709b05f12f jax.make_jaxpr: fix __name__ & related attributes 2023-10-09 15:12:28 -07:00
Matthew Willson
17d89ad166 Fix jax.device_put so it doesn't use tree_map for _check_sharding.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.

PiperOrigin-RevId: 570189648
2023-10-02 15:01:03 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
2023-09-29 08:23:05 +03:00
Yash Katariya
603c879fa0 Run _check_sharding checks during api.device_put instead of in the impl rule so that we don't have to repeat these checks in each rule of device_put.
The same is done for jit and with_sharding_constraint.

PiperOrigin-RevId: 561380348
2023-08-30 10:27:37 -07:00
Peter Hawkins
92128d4083 Remove backward compatibility code related to pytree registries.
We always have a default_registry now, so we don't need to protect code that uses it with conditionals. A number of type suppressions are also stale.

PiperOrigin-RevId: 560849610
2023-08-28 16:26:25 -07:00
Yash Katariya
aeb62cc006 Add TransferToMemoryKind as a private API to allow device_put to transfer to different memories without specifying the sharding and allowing the SPMD partitioner to choose the sharding for the intermediate.
Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
2023-08-22 22:11:38 -07:00