135 Commits

Author SHA1 Message Date
Dan Zheng
9b0c4e5b9c Fix typo.
decice -> device
2022-10-14 22:12:08 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00
Sharad Vikram
375ef0bc63 Making sharding an arg to MHLO callback lowering 2022-08-03 11:04:14 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
bb56f40947 Internal change
PiperOrigin-RevId: 447549479
2022-05-09 13:26:30 -07:00
Matthew Johnson
705c07ae6d remove count attribute and total_ordering from core.Var
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)

But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.

In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:

```python
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```

Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!

So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
2022-05-09 09:31:23 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Sharad Vikram
5d68280e58 Add an emit_python_callback helper function
PiperOrigin-RevId: 444633097
2022-04-26 12:19:50 -07:00
Sharad Vikram
098f2126ae Add CustomCall MLIR lowering for HCB outside_call primitive
PiperOrigin-RevId: 444436860
2022-04-25 19:38:39 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Peter Hawkins
208e83ceb7 Avoid retracing when a host_callback.call is called multiple times with the same function.
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.

This issue was found in passing while debugging #9970.
2022-04-01 14:41:14 -04:00
Yash Katariya
687a7630ee Deprecate maps.mesh and replace it with maps.Mesh.
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Antal Szava
d3ed24f910 fix id_tap jit example 2022-01-11 10:43:26 -05:00
Piotr Padlewski
6bf7b0d325
Fix host_callback docs
There was a missing ':' causing invalid rendering of the docs.
2022-01-07 12:06:23 +01:00
jax authors
04ca2d3cde Merge pull request #8934 from juesato:hcb_docs
PiperOrigin-RevId: 416560873
2021-12-15 08:05:32 -08:00
George Necula
3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
2021-12-15 10:32:15 +02:00
Jonathan Uesato
b28a1c5240 Tweak documentation on error handling for host_callback.call() 2021-12-14 07:16:10 -08:00
George Necula
f08156ab7c [hcb] Simplifications to the host_calback API
* dropping support for special AD handling for hcb.id_tap and id_print.
  From now on, only the primals are tapped. The old behavior can be
  obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
  environment variale, or the --flax_host_callback_ad_transforms flag.
  Additionally, added documentation for how to implement the old behavior
  using JAX custom AD APIs.

This allows us to make some significant cleanup in the internals.
2021-12-11 08:24:56 +01:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
jax authors
737d021fd2 Merge pull request #8329 from AdrienCorenflos:patch-1
PiperOrigin-RevId: 413223979
2021-11-30 12:57:58 -08:00
Peter Hawkins
31ff340020 Remove type annotations for id_tap.
The current annotation does not accurately describe the tap_with_device case.

PiperOrigin-RevId: 412982303
2021-11-29 14:25:10 -08:00
George Necula
0915f6d6fa mend 2021-11-24 11:57:28 +02:00
George Necula
277a1d775e [hcb] Cleanup to account for changes in minimum jaxlib version
We can assume now that jaxlib has the support for CustomCall.
2021-11-24 11:47:11 +02:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
AdrienCorenflos
b056f79654
Fix typo in doc for barrier_wait 2021-10-21 16:12:33 +03:00
Peter Hawkins
e783cbcb72 Port remaining translation rules inside JAX to new style.
PiperOrigin-RevId: 404288551
2021-10-19 09:48:37 -07:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
efdc3cc794 [JAX] Fix more pylint errors.
* duplicate-string-formatting-argument: use f-strings.
* logging-format-interpolation: use interpolation. Some of these are real but minor performance problems.
* bad-string-format-type: don't use the wrong format type.

PiperOrigin-RevId: 400843759
2021-10-04 16:37:15 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
George Necula
d172ba7b7e [host_callback] Fix an assertion failure for grad(remat(host_callback))
Fixes: #5878
2021-09-17 16:59:36 +02:00
Sergei Lebedev
af41a959d3 Most of JAX now uses concrete types for things defined in jaxlib.xla_client
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
2021-08-16 20:33:36 +01:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00