184 Commits

Author SHA1 Message Date
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
jax authors
17e259b1fe fix typo: device_fun(c) -> device_fun(x)
PiperOrigin-RevId: 570289287
2023-10-02 23:38:48 -07:00
jax authors
9f91df725d Merge pull request #17733 from superbobry:dict-literals
PiperOrigin-RevId: 568560963
2023-09-26 09:22:37 -07:00
Sergei Lebedev
eca10f5a3d ENH Use {} and () instead of dict() and tuple() 2023-09-25 11:53:33 +01:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
George Necula
efaea8ed32 [callback] Enable device_index support in terms of callback sharding support.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561856023
2023-08-31 22:31:35 -07:00
George Necula
e0a6230214 [host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
2023-08-31 22:08:23 -07:00
jax authors
f19e748303 Merge pull request #17016 from mattjj:royroyroy
PiperOrigin-RevId: 559524338
2023-08-23 13:22:37 -07:00
George Necula
8891503f87 [callback] Add workaround for TPU host_callback not supporting empty arrays.
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).

Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.

This is part of the work to replace uses of host_callback with io_callback.

PiperOrigin-RevId: 559061336
2023-08-22 03:47:18 -07:00
Matthew Johnson
78199bbc32 [custom-jvp/vjp] for symbolic zeros, ensure rules can be run more than once
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-21 15:28:43 -07:00
George Necula
ad15a38ec1 [host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Peter Hawkins
befce6d2c8 [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python.

PiperOrigin-RevId: 518852909
2023-03-23 07:31:06 -07:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
3838d7612a Merge pull request #14504 from skye:host_callback_pjrt_error
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
Skye Wanderman-Milne
d9f628c972 Raise a user-friendly error message if in/outfeed-based host_callback stuff is used with PJRT C API.
Prior to this change, it would crash horribly instead.

I manually tested by running the following on a Cloud TPU v4-8:
```
JAX_USE_PJRT_C_API_ON_TPU=1 python3 -m pytest tests/host_callback_test.py --tb=no
```
And verifying that all errors were the new error message.

The new error message is:
`host_callback functionality isn't supported with the new Cloud TPU
runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html
and
https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
for alternatives. Please file a feature request at
https://github.com/google/jax/issues if none of the alternatives are
sufficent.`
2023-02-16 00:12:25 +00:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
George Necula
30cf057bf3 [host_callback] Add device_index to hcb.call and add tests
The device_index feature works only with outfeed, add an
error message.

PiperOrigin-RevId: 502951721
2023-01-18 12:41:11 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
c36c25aaf2 Use in_shardings and out_shardings because those are the things available in pjit's params.
PiperOrigin-RevId: 499296361
2023-01-03 13:07:02 -08:00
Jake VanderPlas
fe4c9584f7 doc: fix host callback module crossref 2022-12-27 15:59:32 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Chang Lan
9c4e2fa8fa Make the device assignment of outfeed configurable
PiperOrigin-RevId: 496574960
2022-12-19 22:53:15 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Jake VanderPlas
904398a43d [x64] better type safety for host_callback 2022-12-01 11:47:07 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Dan Zheng
9b0c4e5b9c Fix typo.
decice -> device
2022-10-14 22:12:08 -07:00