1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 08:16:07 +00:00

109 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Adam Paszke
ffd34d5ad7 Allow collectives in manually sharded computations
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.

PiperOrigin-RevId: 467895089
2022-08-16 04:54:14 -07:00
George Necula
3d9c8fbe6f [dynamic-shapes] Ensure that the axis_size_env is passed to sub lowering contexts 2022-07-12 12:44:23 +03:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Peter Hawkins
cc0f51603d [JAX] Don't expand complex all-reduce ops into real/complex pairs.
[XLA:CPU] Implement complex all-reductions for sum and product.

Fixes https://github.com/google/jax/issues/11133 by making XLA implement the all-reduction whenever we build one, not just the one path on which we happened to have a workaround.

PiperOrigin-RevId: 455687275
2022-06-17 13:45:24 -07:00
Xin Zhou
c017d09767 [mhlo] Add result type inference for mhlo.alltoall.
PiperOrigin-RevId: 449591261
2022-05-18 15:24:22 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Eugene Burmako
90f926ac6b [MHLO] Migrate mhlo.all_reduce to HLO_CompatibleOperandsAndResultType
This runs into the currently unsupported feature in Python bindings which prevents it from taking advantage of the type inference functionality provided by HLO_CompatibleOperandsAndResultType.

PiperOrigin-RevId: 447844880
2022-05-10 15:39:18 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Anudhyan Boral
a147046d18 Add unary xeinsum and allow named axis reductions for unary and binary xeinsums 2022-04-26 09:55:42 +00:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Peter Hawkins
cbdcdf7401 [MHLO] Add MHLO lowerings for parallel collectives.
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
Peter Hawkins
f7ba328e7a Ensure that pdot parameters are hashable.
PiperOrigin-RevId: 416804949
2021-12-16 07:22:59 -08:00
Peter Hawkins
b0646557ee Change primitive arguments to parallel and sparse primitives to make parameters hashable.
An upcoming change adds a cache keyed on (primitive, params), but to do that, we need the params to be hashable.

PiperOrigin-RevId: 416793521
2021-12-16 06:13:32 -08:00
Matthew Johnson
c555f5f0e4 handle trivial case for ppermute batching rule
fixes 
2021-12-14 10:42:05 -08:00
Peter Hawkins
1f2d8c0c07 In CPU all_gather lowering, make sure the outputs are bools if the inputs are bools.
PiperOrigin-RevId: 414045093
2021-12-03 16:12:03 -08:00
Peter Hawkins
42647e013f [MLIR] Make jit(pmap(...)) work in the MLIR lowering.
This is redundant with the XLA lowering, but it's probably not the end of the world as a temporary state. An alternative would have been to port the _xla_shard/_xla_unshard primitives to the LAX level and to use xla.lower_fun, but it's not immediately obvious to me how to access ReplicaId() without defining a new primitive. lax.axis_index() is similar but not identical.

Add an axis_env argument to xla.primitive_subcomputation for use by the MLIR fallback path.

PiperOrigin-RevId: 413124116
2021-11-30 05:34:34 -08:00
Peter Hawkins
db0e3fbea9 Reenable pytype checking for jax._src.lax.lax.
pytype checking for this module is no longer excessively slow after the module was split.

PiperOrigin-RevId: 412098920
2021-11-24 11:15:40 -08:00
Peter Hawkins
83d8c6c238 Split slice/update_slice/gather/scatter out of jax._src.lax.lax into jax._src.lax.slicing.
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.

PiperOrigin-RevId: 411921794
2021-11-23 16:35:18 -08:00
Matthew Johnson
2cb235809a make vmap ppermute consistent with pmap/docstring
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
2021-11-18 14:02:49 -08:00
Peter Hawkins
606ca26e05 Simplify implementation of lower_fun in XLA and MLIR lowering paths.
* Always propagate the axis environment, and remove the parallel argument to lower_fun() because it is no longer needed.
* Don't update the name_stack in the MLIR version. The XLA version no longer does this.
* Simplify the call signature of the MLIR version by always accepting avals_out but noting that it is ignored so it is legal to pass, say, None.

PiperOrigin-RevId: 410875100
2021-11-18 12:45:21 -08:00
Peter Hawkins
3f34ee4250 Make a non-None platform mandatory during XLA translation.
The main change is plumbing a platform into calls to xla.primitive_subcomputation.

PiperOrigin-RevId: 410130715
2021-11-15 18:26:38 -08:00
jax authors
f3b1a3010a Merge pull request from sharadmv:all-gather-grad
PiperOrigin-RevId: 405924058
2021-10-27 10:40:07 -07:00
Sharad Vikram
ae9e69814a Broadcast unmapped values in all_to_all batching rule
Fixes .

Co-authored-by: Sharad Vikram<sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-10-27 10:10:41 -07:00
Peter Hawkins
1a73743610 Move xla_bridge.constant to jax.interpreter.xla.pyval_to_ir_constant.
This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.

PiperOrigin-RevId: 404270649
2021-10-19 08:40:51 -07:00
Peter Hawkins
185d7a9fd9 Delete xla_bridge.xla.dtype_to_etype, replace it with jax.interpreters.xla.dtype_to_primitive_type.
The new version does *not* canonicalize dtypes. We should be canonicalizing dtypes as part of tracing to a jaxpr, not in any way as part of XLA lowering. In all cases as best I can tell the dtypes from the callers are already canonical anyway.

jax.interpreters.xla is also a better location: I'm not even sure why we have a bunch of random things in xla_bridge any more, so it makes sense to consolidate them in xla.py along with the other registrations for things like avals.

Also delete the unused function xla_bridge.supported_numpy_dtypes.

PiperOrigin-RevId: 404246574
2021-10-19 06:49:02 -07:00
Peter Hawkins
fa89f34be6 [JAX] Port lax translation rules to updated XLA translation rule API.
PiperOrigin-RevId: 404114709
2021-10-18 18:07:29 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Adam Paszke
49d9affce0 Enable batcher and batched collective rules for tiled all gathers
Fixes .
2021-10-15 14:37:38 +00:00
James Bradbury
86022adf2f
Use all_gather+reduce_scatter HLOs on TPU
The all-gather and reduce-scatter HLOs were wired through for GPU but not TPU, but they should also work there (and be more performant than the all-reduce based fallback).
2021-10-05 17:17:48 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Sharad Vikram
8ae58be90a Fix for singleton axis name in axis_index_translation_rule 2021-09-16 12:31:13 -07:00
jax authors
c365d7f91c Merge pull request from hawkinsp:api3
PiperOrigin-RevId: 396578038
2021-09-14 06:04:28 -07:00
jax authors
f172f9337e Merge pull request from jakevdp:parallel-take
PiperOrigin-RevId: 396493605
2021-09-13 19:04:37 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Jake VanderPlas
04afc0cc79 ppermute: avoid passing lists of indices to jnp.take 2021-09-13 17:13:10 -07:00
Sharad Vikram
ebd8d95847 Add precision param for pdot 2021-09-13 16:28:31 -07:00
Sharad Vikram
cc3e197991 Combine initial_style_batchers with collective_rules 2021-09-09 11:23:51 -07:00
jax authors
7500c7e969 Merge pull request from google:rejames5
PiperOrigin-RevId: 394333280
2021-09-01 15:56:42 -07:00
Ningning Xie
f38d3e8735 Allow axis index groups to have different sizes for AllReduce.
PiperOrigin-RevId: 394297426
2021-09-01 13:10:17 -07:00
Matthew Johnson
542641ca87 rejames/reblake implementation 2021-08-25 20:46:32 -07:00
jax authors
606cbe036a Merge pull request from slowy07:fixing
PiperOrigin-RevId: 388774232
2021-08-04 13:43:58 -07:00
Ningning Xie
e7f03073dd ReduceScatter translation and abstract eval.
PiperOrigin-RevId: 387152857
2021-07-27 11:13:18 -07:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Adam Paszke
64510bd5b6 Add axis and tiled options to lax.all_gather.
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.

PiperOrigin-RevId: 384897270
2021-07-15 04:22:36 -07:00
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00