425 Commits

Author SHA1 Message Date
Peter Hawkins
b5e4ba4900 Don't call inspect.signature() each time we trace a jit().
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
2024-02-15 13:49:27 -08:00
Jake VanderPlas
82611eb8ae document that under disable_jit, individual primitives are still compiled 2024-02-05 12:01:33 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Yash Katariya
f04f305489 Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Matthew Johnson
0d4f200e08 Allow unhashable callables in jax.eval_shape.
PiperOrigin-RevId: 599691923
2024-01-18 19:16:48 -08:00
Yash Katariya
51ef738c86 Use jit's jaxpr creation function for eval_shape to maximize tracing cache hits.
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.

The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`

PiperOrigin-RevId: 599602407
2024-01-18 13:11:44 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Tom Cobley
ebc7af95df Fix typo in pmap docstring
Docstring states:
>  If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.

However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`.

PiperOrigin-RevId: 595731210
2024-01-04 09:50:00 -08:00
Axel Donath
e8330b5fc5 Add eval_shape example for function with static arguments
Improve wording and formating of dynamic eval_shape example
2023-12-18 19:19:48 -05:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Jake VanderPlas
35b84402c0 Deprecate arr.device_buffer and arr.device_buffers 2023-12-06 10:20:29 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
jax authors
57e19db104 Merge pull request #18736 from mattjj:device-put-fixes
PiperOrigin-RevId: 586490689
2023-11-29 16:51:15 -08:00
Matthew Johnson
c9ab0bfd3c fix grad device_put src inference, and a small device_put bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
439b89e47f Remove DefaultLayout and make None same as DefaultLayout
PiperOrigin-RevId: 583221970
2023-11-16 18:01:27 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Jake VanderPlas
0bcd64ade3 jax.vmap: improve docs & error for structured in_axes 2023-11-15 11:56:53 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
Etienne Pot
7cf66dfe4b Fix typing annotations for jax.named_call
PiperOrigin-RevId: 582235119
2023-11-14 01:47:42 -08:00
Junwhan Ahn
55394a0914 Roll back the optimized version of jax.block_until_ready due to test breakage
Reverts 6cc6d093643c0265c7de4027f79879f6945e0342

PiperOrigin-RevId: 581577789
2023-11-11 12:15:45 -08:00
Junwhan Ahn
6cc6d09364 Implement more efficient jax.block_until_ready(x) in C++
The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes GIL to be acquired/released repeatedly.

To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.

PiperOrigin-RevId: 581302290
2023-11-10 10:34:34 -08:00
jax authors
62741d9744 Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Etienne Pot
81ac67f381 Fix typing annotations for @jax.named_call
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Jake VanderPlas
709b05f12f jax.make_jaxpr: fix __name__ & related attributes 2023-10-09 15:12:28 -07:00
Matthew Willson
17d89ad166 Fix jax.device_put so it doesn't use tree_map for _check_sharding.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.

PiperOrigin-RevId: 570189648
2023-10-02 15:01:03 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
2023-09-29 08:23:05 +03:00
Yash Katariya
603c879fa0 Run _check_sharding checks during api.device_put instead of in the impl rule so that we don't have to repeat these checks in each rule of device_put.
The same is done for jit and with_sharding_constraint.

PiperOrigin-RevId: 561380348
2023-08-30 10:27:37 -07:00
Peter Hawkins
92128d4083 Remove backward compatibility code related to pytree registries.
We always have a default_registry now, so we don't need to protect code that uses it with conditionals. A number of type suppressions are also stale.

PiperOrigin-RevId: 560849610
2023-08-28 16:26:25 -07:00
Yash Katariya
aeb62cc006 Add TransferToMemoryKind as a private API to allow device_put to transfer to different memories without specifying the sharding and allowing the SPMD partitioner to choose the sharding for the intermediate.
Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
2023-08-22 22:11:38 -07:00
Jake VanderPlas
9aca944891 Fix type annotation for tree_util.default_registry 2023-08-16 15:07:48 -07:00
Jake VanderPlas
6097d8fe23 CI: bump mypy to version 1.5.0 2023-08-16 14:26:04 -07:00
Peter Hawkins
e58f1ba86e Move some utilities out of dispatch.py next to their users, add more types.
Internal cleanups only, no user-visible changes intended.

PiperOrigin-RevId: 554876522
2023-08-08 10:52:11 -07:00
Peter Hawkins
a6a8f4850c [JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path.
The pmap() fast path doesn't even look the ShardingSpec or the out_indices since the jax.Sharding rework.

PiperOrigin-RevId: 553206145
2023-08-02 11:29:05 -07:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
treyra
9ec6ebb7e0 Fixed _mapped_axis_size raising an uncaught TypeError 2023-07-06 18:51:44 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Yash Katariya
744a64fce6 Make sharding on ShapeDtypeStruct a property that always exists. The previous behavior was it only existed if sharding was not None.
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
2023-06-26 21:46:50 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00