This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.
PiperOrigin-RevId: 570058363
An upcoming XLA change will reject programs containing int4 on CPU and GPU, because the XLA support is buggy and incomplete. When the XLA supports this we can reenable these tests.
Issue https://github.com/google/jax/issues/17672
PiperOrigin-RevId: 570042917
We change the lowering rule selection code to work when
`ModuleContext.lowering_parameters.platforms` contains multiple
string, and emit conditional
code to select the lowering based on the platform index argument.
These changes will not affect the normal JAX lowering paths (when
`ModuleContext.lowering_parameters.platforms` is `None`). It will
also not affect the JAX native serialization paths for single
platform lowering.
These changes should work for most primitives, with the exception
of the few ones that actually access `ModuleContext.platform` inside
the lowering rules (most primitives just register different
rules for different platforms, which is taken into account by
these changes).
Previous PR in this series: #17316.
* Use pathlib.Path object-oriented paths.
* Change copy_files() helper to copy many files in one call.
* Make copy_files() also make the output directory, if needed.
* Format file with pyink --pyink-indentation=2
Similarly for the GpuName() constant.
While most of the time we treat CUDA and ROCm GPUs identically, we sometimes want to distinguish between CUDA and ROCm (e.g., for DLPack exports) and it's helpful if this is encoded in the platform ID.
PiperOrigin-RevId: 569513495
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.
We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.
Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
`Array` is structurally a `Sequence[Array]`, so the first overload always
matches under pytype, which defines `collections.abc.Sequence` as a
`Protocol`.
See
b8f91a37e5/pytype/stubs/builtins/typing.pytd (L149).
This assumes less about whether the thread that destructs `CacheEntry` has GIL or not, which is difficult to reason about due to the `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.
The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:
* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.
PiperOrigin-RevId: 569062402