This is doubly non-public: nothing under `jax.lib` is public, and also the object itself has a preceding underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in adaf1b2b75).
PiperOrigin-RevId: 704825268
This type is unused by JAX, so there is no replacement.
(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)
PiperOrigin-RevId: 684451556
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I feld that a deprecation warning might be helpful.
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
These were accidental exports and have public equivalents under the top-level jax namespace. The deprecation policy does not apply to names under jax.lib, which is intended to be private.
PiperOrigin-RevId: 506088434
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.
PiperOrigin-RevId: 404051961
* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution, however I doubt users are writing that because it didn't work to mix CPU and tpu_driver before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.
PiperOrigin-RevId: 376883261
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.
This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.
Example output:
```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
process_index: 0
device_count: 8
local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
devices: None
abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
dtype=object), ('x',))
abstract args: []
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
[TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
dtype=object), ('x', 'y'))
abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```