174 Commits

Author SHA1 Message Date
jax authors
f19e748303 Merge pull request #17016 from mattjj:royroyroy
PiperOrigin-RevId: 559524338
2023-08-23 13:22:37 -07:00
George Necula
8891503f87 [callback] Add workaround for TPU host_callback not supporting empty arrays.
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).

Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.

This is part of the work to replace uses of host_callback with io_callback.

PiperOrigin-RevId: 559061336
2023-08-22 03:47:18 -07:00
Matthew Johnson
78199bbc32 [custom-jvp/vjp] for symbolic zeros, ensure rules can be run more than once
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-21 15:28:43 -07:00
George Necula
ad15a38ec1 [host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Peter Hawkins
befce6d2c8 [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python.

PiperOrigin-RevId: 518852909
2023-03-23 07:31:06 -07:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
3838d7612a Merge pull request #14504 from skye:host_callback_pjrt_error
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
Skye Wanderman-Milne
d9f628c972 Raise a user-friendly error message if in/outfeed-based host_callback stuff is used with PJRT C API.
Prior to this change, it would crash horribly instead.

I manually tested by running the following on a Cloud TPU v4-8:
```
JAX_USE_PJRT_C_API_ON_TPU=1 python3 -m pytest tests/host_callback_test.py --tb=no
```
And verifying that all errors were the new error message.

The new error message is:
`host_callback functionality isn't supported with the new Cloud TPU
runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html
and
https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
for alternatives. Please file a feature request at
https://github.com/google/jax/issues if none of the alternatives are
sufficent.`
2023-02-16 00:12:25 +00:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
George Necula
30cf057bf3 [host_callback] Add device_index to hcb.call and add tests
The device_index feature works only with outfeed, add an
error message.

PiperOrigin-RevId: 502951721
2023-01-18 12:41:11 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
c36c25aaf2 Use in_shardings and out_shardings because those are the things available in pjit's params.
PiperOrigin-RevId: 499296361
2023-01-03 13:07:02 -08:00
Jake VanderPlas
fe4c9584f7 doc: fix host callback module crossref 2022-12-27 15:59:32 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Chang Lan
9c4e2fa8fa Make the device assignment of outfeed configurable
PiperOrigin-RevId: 496574960
2022-12-19 22:53:15 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Jake VanderPlas
904398a43d [x64] better type safety for host_callback 2022-12-01 11:47:07 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Dan Zheng
9b0c4e5b9c Fix typo.
decice -> device
2022-10-14 22:12:08 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00