161 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Peter Hawkins
befce6d2c8 [XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python.

PiperOrigin-RevId: 518852909
2023-03-23 07:31:06 -07:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
3838d7612a Merge pull request #14504 from skye:host_callback_pjrt_error
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
Skye Wanderman-Milne
d9f628c972 Raise a user-friendly error message if in/outfeed-based host_callback stuff is used with PJRT C API.
Prior to this change, it would crash horribly instead.

I manually tested by running the following on a Cloud TPU v4-8:
```
JAX_USE_PJRT_C_API_ON_TPU=1 python3 -m pytest tests/host_callback_test.py --tb=no
```
And verifying that all errors were the new error message.

The new error message is:
`host_callback functionality isn't supported with the new Cloud TPU
runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html
and
https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
for alternatives. Please file a feature request at
https://github.com/google/jax/issues if none of the alternatives are
sufficent.`
2023-02-16 00:12:25 +00:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
George Necula
30cf057bf3 [host_callback] Add device_index to hcb.call and add tests
The device_index feature works only with outfeed, add an
error message.

PiperOrigin-RevId: 502951721
2023-01-18 12:41:11 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
c36c25aaf2 Use in_shardings and out_shardings because those are the things available in pjit's params.
PiperOrigin-RevId: 499296361
2023-01-03 13:07:02 -08:00
Jake VanderPlas
fe4c9584f7 doc: fix host callback module crossref 2022-12-27 15:59:32 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Chang Lan
9c4e2fa8fa Make the device assignment of outfeed configurable
PiperOrigin-RevId: 496574960
2022-12-19 22:53:15 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Jake VanderPlas
904398a43d [x64] better type safety for host_callback 2022-12-01 11:47:07 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Dan Zheng
9b0c4e5b9c Fix typo.
decice -> device
2022-10-14 22:12:08 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00
Sharad Vikram
375ef0bc63 Making sharding an arg to MHLO callback lowering 2022-08-03 11:04:14 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
bb56f40947 Internal change
PiperOrigin-RevId: 447549479
2022-05-09 13:26:30 -07:00
Matthew Johnson
705c07ae6d remove count attribute and total_ordering from core.Var
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)

But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.

In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:

```python
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```

Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!

So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
2022-05-09 09:31:23 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Sharad Vikram
5d68280e58 Add an emit_python_callback helper function
PiperOrigin-RevId: 444633097
2022-04-26 12:19:50 -07:00