Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).
Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.
This is part of the work to replace uses of host_callback with io_callback.
PiperOrigin-RevId: 559061336
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.
This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.
PiperOrigin-RevId: 557520668
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.
One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.
PiperOrigin-RevId: 549301796
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python.
PiperOrigin-RevId: 518852909
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:
remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
PiperOrigin-RevId: 468373797
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!