355 Commits

Author SHA1 Message Date
Yash Katariya
970f4c9d4d Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
2023-08-25 10:59:48 -07:00
Yash Katariya
0501a15fd5 Print str_short of the arg and remove printing the value of the arg.
PiperOrigin-RevId: 559524941
2023-08-23 13:31:35 -07:00
jax authors
6d686135de Force inline all the calls in the module created by a custom call.
PiperOrigin-RevId: 558196979
2023-08-18 11:13:45 -07:00
Parker Schuh
03575c4b33 Pad generated sharding specs with None up to ndims to simplify comparing dims
across different partitioned arguments.
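
A minimal sketch of the padding idea, not the actual JAX implementation. The helper name `pad_spec` and the tuple representation of a spec are assumptions made for illustration only:

```python
def pad_spec(spec, ndim):
    """Pad a per-dimension spec with None so that it has exactly ndim entries.

    `spec` is a hypothetical stand-in for a generated sharding spec:
    one entry per leading dimension, None meaning "not partitioned".
    """
    spec = tuple(spec)
    if len(spec) > ndim:
        raise ValueError(f"spec {spec} has more entries than ndim={ndim}")
    return spec + (None,) * (ndim - len(spec))


# Two specs that mean the same thing for a rank-3 value compare equal
# once both are padded to the full rank.
a = pad_spec(("x",), 3)
b = pad_spec(("x", None), 3)
same = a == b
```

Once every spec has the same length, comparing dimensions across arguments reduces to an element-wise comparison.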

PiperOrigin-RevId: 555712119
2023-08-10 17:02:31 -07:00
Parker Schuh
74bcd65bbd Make mesh available to custom_partitioning lowering rules.
PiperOrigin-RevId: 555319896
2023-08-09 17:08:57 -07:00
Ce Zheng
b80498874a [XLA:Client] Make HloSharding::iota_tile actually produce V2 shardings.
PiperOrigin-RevId: 554631780
2023-08-07 16:46:53 -07:00
Yash Katariya
853c470292 Improve the repr of NamedSharding and error message of device_put
PiperOrigin-RevId: 552841710
2023-08-01 10:17:20 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replacement is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names, which will then be used to find the donation_vector. This is fine because currently the same thing happens with static_argnums and static_argnames.
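
A rough sketch of what inferring argument positions from names can look like, using only `inspect` from the standard library. The helper `infer_argnums` and the example function `update` are hypothetical and are not JAX's internal code:

```python
import inspect


def infer_argnums(fun, argnames):
    """Return the positional indices of the parameters named in argnames.

    Hypothetical helper: only parameters that can be passed positionally
    are considered, and unknown names are silently ignored.
    """
    names = set(argnames)
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = inspect.signature(fun).parameters.values()
    return tuple(
        i for i, p in enumerate(params)
        if p.kind in positional_kinds and p.name in names
    )


def update(params, grads, lr):
    return params - lr * grads


# Naming "params" for donation corresponds to position 0.
donate_argnums = infer_argnums(update, ("params",))
```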

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Yash Katariya
744a64fce6 Make sharding on ShapeDtypeStruct a property that always exists. The previous behavior was it only existed if sharding was not None.
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
2023-06-26 21:46:50 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follows:

* If the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings.

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.
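
The two rules above can be sketched as a small decision function. The names `UNSPECIFIED`, `REPLICATED`, and `resolve_sharding` are stand-ins for illustration and do not correspond to JAX internals:

```python
# Stand-in sentinels; the real objects live inside JAX.
UNSPECIFIED = object()
REPLICATED = "replicated"


def resolve_sharding(sharding, mesh_context_active):
    """Resolve a user-supplied in_shardings/out_shardings entry.

    None means UNSPECIFIED without a mesh context manager and
    fully replicated with one; anything else is passed through.
    """
    if sharding is None:
        return REPLICATED if mesh_context_active else UNSPECIFIED
    return sharding


without_mesh = resolve_sharding(None, mesh_context_active=False)
with_mesh = resolve_sharding(None, mesh_context_active=True)
```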

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
Yash Katariya
38b9bf8cac Raise a good error message when a ShapeDtypeStruct is closed over as a const which is not a valid arg during execution.
PiperOrigin-RevId: 540296131
2023-06-14 09:40:37 -07:00
Yash Katariya
c7b372a8df Convert OpSharding to HloSharding in the constructor of GSPMDSharding. Also make op_sharding_to_indices work only with HloSharding.
PiperOrigin-RevId: 538165282
2023-06-06 06:35:29 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
5cf0e042eb Change the _most_recent_executable logic to store a weakref dict of jaxpr -> executable so that with the inner cpp cache and outer cpp cache, we extract the correct executable.
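
A minimal sketch of a weakref-keyed mapping from a jaxpr-like object to its executable, assuming plain Python stand-ins. `FakeJaxpr` and `ExecutableCache` are hypothetical names; only the idea of a weakref dict keyed by jaxpr comes from the commit message:

```python
import gc
import weakref


class FakeJaxpr:
    """Stand-in for a jaxpr: hashable by identity and weak-referenceable."""

    def __init__(self, name):
        self.name = name


class ExecutableCache:
    """Maps jaxpr -> executable without keeping the jaxpr alive."""

    def __init__(self):
        self._by_jaxpr = weakref.WeakKeyDictionary()

    def store(self, jaxpr, executable):
        self._by_jaxpr[jaxpr] = executable

    def lookup(self, jaxpr):
        return self._by_jaxpr.get(jaxpr)

    def __len__(self):
        return len(self._by_jaxpr)


cache = ExecutableCache()
inner, outer = FakeJaxpr("inner"), FakeJaxpr("outer")
cache.store(inner, "inner_executable")
cache.store(outer, "outer_executable")

# Each jaxpr resolves to its own executable, regardless of insertion order.
found = cache.lookup(inner)

# Dropping the last strong reference to a jaxpr removes its entry.
del outer
gc.collect()
remaining = len(cache)
```

Keying by the jaxpr rather than remembering a single "most recent" value is what lets two layers of caching each retrieve the executable that belongs to them.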
PiperOrigin-RevId: 537908874
2023-06-05 10:07:05 -07:00
André Susano Pinto
cfabad5886 Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.
_get_arg_names was throwing IndexError when handling functions with variadic args.

PiperOrigin-RevId: 537308439
2023-06-02 07:43:59 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Ce Zheng
8e397f7f08 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors.
PiperOrigin-RevId: 535079635
2023-05-24 20:26:28 -07:00
Ce Zheng
4f1f5e4516 [XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding.
PiperOrigin-RevId: 534600886
2023-05-23 16:37:42 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Sholto Douglas
e0b5003880 Allow unconstrained dimensions when using NamedShardings.
PiperOrigin-RevId: 533752415
2023-05-20 16:28:12 -07:00
Parker Schuh
56ca8af9bb Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
2023-05-19 16:58:54 -07:00
Yash Katariya
4a5c6f8200 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve the tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Parker Schuh
08169291a4 Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate
arguments for shape and sharding.

PiperOrigin-RevId: 533257532
2023-05-18 14:48:07 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

Eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing; it was decided that this is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
Yash Katariya
8b9e6bcbd4 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve the tracing time.
PiperOrigin-RevId: 532155068
2023-05-15 10:32:24 -07:00
Yash Katariya
befa29b566 Fix the cache on to_gspmd_sharding to depend on if device/backend is set on pjit/jit.
Before, if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding, which didn't have the dynamic attribute on it.

This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.

PiperOrigin-RevId: 530712918
2023-05-09 14:24:21 -07:00
Yash Katariya
1629c6c76b Make jax.jit work with vmap(..., spmd_axis_name) when there is no mesh context manager.
This will only work if the input Array's sharding is a NamedSharding.

Fixes https://github.com/google/jax/issues/15886

PiperOrigin-RevId: 529758233
2023-05-05 10:48:33 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Yash Katariya
bffddf76cb Improve the error raised when wsc is passed a PartitionSpec without a mesh context manager
PiperOrigin-RevId: 529260748
2023-05-03 19:35:51 -07:00
Yash Katariya
9515ccf376 Fix pjit + vmap when device is passed as an argument to pjit/jit
PiperOrigin-RevId: 529155035
2023-05-03 11:55:23 -07:00
Rahul Joshi
9d750ae97d Fix pjit outfeed test to avoid potential deadlocks.
PiperOrigin-RevId: 529076350
2023-05-03 06:51:26 -07:00
Yash Katariya
40349a8612 Normalize 1 length tuples to a string while getting PartitionSpec from array mapping.
Fixes https://github.com/google/jax/issues/15782
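
A sketch of the normalization, under the assumption that an array mapping maps each mesh axis name to the array dimension it partitions. The function name `array_mapping_to_spec` and the plain-tuple result are illustrative only and are not the JAX implementation:

```python
def array_mapping_to_spec(array_mapping):
    """Build a per-dimension spec from {mesh_axis_name: array_dim}.

    A dimension partitioned over a single mesh axis is emitted as the
    bare string "x" rather than the one-element tuple ("x",).
    """
    if not array_mapping:
        return ()
    ndim = max(array_mapping.values()) + 1
    per_dim = [[] for _ in range(ndim)]
    for axis_name, dim in array_mapping.items():
        per_dim[dim].append(axis_name)

    spec = []
    for axes in per_dim:
        if not axes:
            spec.append(None)          # dimension is not partitioned
        elif len(axes) == 1:
            spec.append(axes[0])       # normalize ("x",) to "x"
        else:
            spec.append(tuple(axes))   # several mesh axes on one dimension
    return tuple(spec)


single = array_mapping_to_spec({"x": 0})
mixed = array_mapping_to_spec({"x": 0, "y": 0, "z": 2})
```

With this normalization, a spec written by hand with a bare string and one recovered from the mapping compare equal.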

PiperOrigin-RevId: 528796985
2023-05-02 08:55:40 -07:00
Yash Katariya
c52e48b6c0 Only return the same input Sharding object if the original aval's ndim and out_aval's ndim are the same.
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, 'y': 2}, P('x'))` is not compatible with an input with `ndim == 0`.

PiperOrigin-RevId: 528621971
2023-05-01 17:39:51 -07:00
Yash Katariya
4a3fb238f6 Return the same sharding object if the output OpSharding matches the input OpSharding.
Fixes https://github.com/google/jax/issues/15782

PiperOrigin-RevId: 528531594
2023-05-01 11:46:57 -07:00
Yash Katariya
86c1f5bcee Preserve the sharding type of physical sharding on logical sharding when .sharding is accessed on a PRNGKeyArray
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Jake VanderPlas
3108f05eee fix pjit_test:testWithCustomPRNGKey
2023-04-25 10:52:15 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Parker Schuh
87c328864b Improve testing for custom_partitioning.
Add a test to demonstrate how to force XLA to choose
a different sharding.

Also it is possible to return the wrong
shape from a partition function. We should error in this case.

PiperOrigin-RevId: 525606690
2023-04-19 18:26:51 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
Yash Katariya
0a19638490 Plumb debug_info to meshExecutable as an optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
2023-04-19 12:35:49 -07:00