PjRt GPU and CPU have recently gained memory space support, with just one memory space per device, so we enable the relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only `ShardingMemoriesTest` is enabled for now.
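As a rough illustration (not part of the change itself), the memory spaces can be inspected through the device API; `default_memory()` and `addressable_memories()` are the assumed entry points here:
```python
# A minimal sketch: with exactly one memory space per CPU/GPU device,
# the default memory should be the only addressable memory.
import jax

dev = jax.devices()[0]
print(dev.default_memory())         # the device's default memory space
print(dev.addressable_memories())   # expected to contain a single entry here
```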
PiperOrigin-RevId: 633335638
A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules.
PiperOrigin-RevId: 633297839
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one can opt out of the change and recover the old behavior.
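A minimal sketch of the opt-out, using only the flag quoted above and `block_until_ready` (the toy computation is illustrative):
```python
# A minimal sketch: opt out of async dispatch on CPU to recover the old
# synchronous behavior; otherwise, block explicitly when timing work.
import jax
import jax.numpy as jnp

jax.config.update('jax_cpu_enable_async_dispatch', False)

y = jax.jit(lambda x: x * 2)(jnp.arange(4))
y.block_until_ready()  # safe to call whether dispatch is async or not
```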
PiperOrigin-RevId: 633264117
We now allow a module exported for 1 device and not using any sharding annotations
to be called from a computation that uses multiple devices. Such exported modules
can be trivially parallelized point-wise.
In preparation for a better integration of jax.experimental.export with
the AOT APIs, we make several simplifications:
* always turn on the generation of shape assertions in the presence of shape
polymorphism. Previously, shape assertions were turned on unless the
serialization version was less than 7 (possible only before March 27th, 2024,
when the minimum serialization version was bumped to 9), or unless the
user explicitly specified that shape assertions should be turned off. It is
not safe to turn off shape assertions, and I am not aware of an instance where
somebody had to turn them off, except for temporary debugging. We keep the
`DisabledSafetyCheck.shape_assertions` API for now, for backwards compatibility,
but it has no effect and emits a deprecation warning (see the sketch after this list).
* remove the code that was conditional on the serialization version
being less than 9, e.g., the lowering in the presence of effects.
* remove a safety check that ensured that when `export` is used on JAX
callables, i.e., not the result of `jax.jit`, the code does not
contain non-replicated sharding annotations. This usage of `export` is
rare and will be removed once `export` is integrated with the AOT
APIs.
* remove code that was needed only with older jaxlib versions to `replace_tokens_with_dummy`.
This dependency is added implicitly by Google-internal infra, but we need
it to be explicit for Bazel builds to avoid ImportErrors at lowering time.
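Below is a minimal sketch of the deprecated knob, assuming the `jax.experimental.export` API of this era (`export.export`, `export.symbolic_shape`, and `export.DisabledSafetyCheck` are the assumed entry points): shape assertions are now always generated, and disabling them is a no-op.
```python
# A minimal sketch (assumed API names, see the lead-in): export a
# shape-polymorphic function; shape assertions are always generated now.
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  return jnp.sin(x)

# 'b' is a symbolic (polymorphic) batch dimension.
args_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exp = export.export(jax.jit(f))(args_spec)

# Deprecated: kept for backwards compatibility, has no effect beyond a
# DeprecationWarning.
exp_legacy = export.export(
    jax.jit(f),
    disabled_checks=[export.DisabledSafetyCheck.shape_assertions()],
)(args_spec)
```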
PiperOrigin-RevId: 633147268
This allows the TPU profiler to work with other plugin backends.
Tested on a GPU VM:
$ pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install -e .
$ TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py
Running tests under Python 3.10.12: /usr/bin/python
[ RUN ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
NOT_FOUND: WARNING: could not determine TPU accelerator type. Set env var `TPU_ACCELERATOR_TYPE` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
NOT_FOUND: WARNING: could not determine TPU worker number. Set env var `TPU_WORKER_ID` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
NOT_FOUND: WARNING: could not determine TPU worker hostnames or internal IP addresses. Set env var `TPU_WORKER_HOSTNAMES` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
learning/45eac/tfrc/runtime/common_lib.cc:341
I0510 00:32:03.063246 130900437979136 cross_aot_test.py:58] Expected to fail to get topology
I0510 00:32:03.079923 130900437979136 xla_bridge.py:884] Unable to initialize backend 'cuda':
I0510 00:32:03.080080 130900437979136 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0510 00:32:03.089399 130900437979136 xla_bridge.py:884] Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No ba16c7433 device found.
W0510 00:32:03.089633 130900437979136 xla_bridge.py:931] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
/home/jieying/.local/lib/python3.10/site-packages/tensorflow/__init__.py:30: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
import distutils as _distutils
2024-05-10 00:32:03.359597: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-10 00:32:03.359652: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-10 00:32:03.361368: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-10 00:32:04.562557: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[ OK ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
----------------------------------------------------------------------
Ran 1 test in 2.549s
OK
In `tests/cross_aot_test.py`:
```python
class JaxAotTest(jtu.JaxTestCase):

  def test_tpu_profiler_registered_get_topology_from_devices(self):
    try:
      _ = topologies.get_topology_desc(
          topology_name='fake_topology',
          platform='tpu',
      )
    except xla_extension.XlaRuntimeError:
      logging.info('Expected to fail to get topology')

    with tempfile.TemporaryDirectory() as tmpdir:
      try:
        jax.profiler.start_trace(tmpdir)
        jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
            jnp.ones(jax.local_device_count())
        )
      finally:
        jax.profiler.stop_trace()
      proto_path = glob.glob(
          os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True
      )
      self.assertLen(proto_path, 1)
      with open(proto_path[0], 'rb') as f:
        proto = f.read()

      # Sanity check that serialized proto contains host, and Python traces
      # without deserializing.
      self.assertIn(b'/host:metadata', proto)
      if jtu.test_device_matches(['tpu']):
        self.assertNotIn(b'/device:TPU', proto)
      self.assertIn(b'pxla.py', proto)
```
PiperOrigin-RevId: 633076007
Previously, TileTransform.transform_index() would transform indices like so:
x, y, z -> x, y // t_x, z // t_y, 0, 0
But it actually should be:
x, y, z -> x, y // t_x, z // t_y, y % t_x, z % t_y
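A plain-Python sketch of the corrected transform (the standalone `transform_index` below is illustrative, not the Mosaic implementation):
```python
# A minimal sketch of the corrected index transform for a 2D tiling
# (t_x, t_y) applied to the last two dimensions: the trailing indices are
# the within-tile offsets, not constants 0.
def transform_index(idx, tiling):
  x, y, z = idx
  t_x, t_y = tiling
  return (x, y // t_x, z // t_y, y % t_x, z % t_y)

assert transform_index((1, 10, 7), (4, 4)) == (1, 2, 1, 2, 3)
```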
PiperOrigin-RevId: 633052067
This hopefully should go away once XLA implements its own memory space propagation pass, or once JAX adds memory_kind to the type system of jaxpr, i.e., on avals.
It is required so that the following code blocks (1) and (2) are treated as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later.
1. `jit(f, out_shardings=s_host)`
2. ```
@jax.jit
def f(x):
  return jax.device_put(x, s_host)
```
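For context, a hedged sketch of how `s_host` might be constructed; the single-device sharding and the `'pinned_host'` memory kind are assumptions, not part of this change:
```python
# A minimal sketch (assumed setup): a sharding whose memory_kind points at
# host memory, used both as out_shardings and inside jit via device_put.
import jax
import jax.numpy as jnp

s_host = jax.sharding.SingleDeviceSharding(
    jax.devices()[0], memory_kind='pinned_host')  # assumed memory kind

@jax.jit
def f(x):
  return jax.device_put(x, s_host)

g = jax.jit(lambda x: x, out_shardings=s_host)  # treated equivalently to f
```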
PiperOrigin-RevId: 632621025
GetDefaultLayout now has a fallback for the GPU backend, so it is no longer blocked by the PJRT C API's current lack of GetDefaultLayout support.
PiperOrigin-RevId: 632555239
Rely on XLA decomposition.
# JAX GPU microbenchmarks:
285us for cumsum over 1e8 elements
449us for cumsum over 1e8 elements
# JAX CPU microbenchmarks:
1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements
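A rough sketch of how such a microbenchmark could be measured; the array size, dtype, and timing loop are assumptions rather than the exact harness used:
```python
# A minimal timing sketch: cumsum over 1e8 float32 elements on the default
# backend, with an explicit block to account for async dispatch.
import time
import jax
import jax.numpy as jnp

x = jnp.ones((100_000_000,), dtype=jnp.float32)
cumsum = jax.jit(jnp.cumsum)
cumsum(x).block_until_ready()  # compile and warm up

t0 = time.perf_counter()
cumsum(x).block_until_ready()
print(f'cumsum over 1e8 elements: {(time.perf_counter() - t0) * 1e6:.0f}us')
```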
PiperOrigin-RevId: 632547166
The runfiles of the original targets were lost when the symlinked files were used.
This change is needed for the future Hermetic CUDA implementation. Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the contents of the runfiles are lost. With a `genrule`, the contents of the runfiles are preserved.
PiperOrigin-RevId: 632508121