20692 Commits

Author SHA1 Message Date
Jake VanderPlas
5150cfeeb0 Fix PRNGKey handling under jit-of-pmap 2024-05-13 19:04:22 -07:00
Junwhan Ahn
cd6e012326 Enable JAX memory tests for GPUs and CPUs
The PjRt GPU and CPU backends have recently gained memory space support, with just one memory space per device, so this enables the relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only `ShardingMemoriesTest` is enabled for now.
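
For reference, a small sketch of inspecting per-device memory spaces from JAX (assuming the `addressable_memories()` / `default_memory()` device methods; with this change, CPU and GPU devices expose a single memory space):

```
import jax

for d in jax.devices():
  # Expected: exactly one memory space per CPU/GPU device for now.
  kinds = [m.kind for m in d.addressable_memories()]
  print(d, kinds, d.default_memory().kind)
```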

PiperOrigin-RevId: 633335638
2024-05-13 14:37:37 -07:00
Peter Hawkins
72a81e58e6 Readd a default lowering rule for cumsum et al.
A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules

PiperOrigin-RevId: 633297839
2024-05-13 12:34:51 -07:00
Justin Fu
1e48adc698 [Pallas] Pad inputs/outputs in interpret mode to fix errors from OOB memory accesses.
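For context, a minimal Pallas kernel run in interpret mode (a generic sketch, not one of the kernels touched by this change):

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Out-of-bounds reads/writes here are what the new padding guards against.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # run with the interpreter instead of compiling for TPU/GPU
)(x)
```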
PiperOrigin-RevId: 633283991
2024-05-13 11:50:21 -07:00
jax authors
b8ed346665 Merge pull request #21119 from jakevdp:linalg-cond
PiperOrigin-RevId: 633281675
2024-05-13 11:43:24 -07:00
jax authors
6189a559cf Merge pull request #21048 from jakevdp:np-squeeze-doc
PiperOrigin-RevId: 633273019
2024-05-13 11:18:09 -07:00
Yue Sheng
9e7830df2d Async dispatch expensive computations on the JAX CPU backend.
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one can opt out of the change and recover the old behavior.
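
A minimal sketch of the opt-out (the flag name is taken from this change; everything else is illustrative):

```
import jax
import jax.numpy as jnp

# Opt out of async dispatch on the CPU backend and recover the old,
# synchronous behavior.
jax.config.update('jax_cpu_enable_async_dispatch', False)

x = jnp.ones((1000, 1000))
y = x @ x                 # now dispatched synchronously on CPU
y.block_until_ready()     # harmless either way
```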

PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Yash Katariya
54d4072730 Populate propagated_out_mem_kinds inside the branch where it's needed
PiperOrigin-RevId: 633262630
2024-05-13 10:48:52 -07:00
Jake VanderPlas
1f6d902174 jnp.linalg.cond: improve implementation & docs 2024-05-13 10:36:50 -07:00
jax authors
85e91c2be4 Merge pull request #21203 from gnecula:export_device_poly
PiperOrigin-RevId: 633253709
2024-05-13 10:23:28 -07:00
Jake VanderPlas
7ceae95dd7 Better documentation for several jax.numpy functions 2024-05-13 10:09:52 -07:00
George Necula
98aead70eb [export] Relax the check that exported modules are used with the same number of devices as when exported
Now we allow a module exported for 1 device and not using any sharding annotations
to be called from a computation that uses multiple devices. Such exported modules
can be parallelized trivially point-wise.
2024-05-13 20:09:43 +03:00
Justin Fu
e4f3b3ff8f Merge pull request #21169 from justinjfu/splash_precision_fix
Disable bfloat16 on long seq lengths for splash attention kernel test
2024-05-13 09:34:39 -07:00
Jake VanderPlas
35a512dadf CI: update NumPy build version to 2.0.0rc2
PiperOrigin-RevId: 633233231
2024-05-13 09:18:40 -07:00
George Karpenkov
de14e3b32e Reverts 49bd4d6f01d6cda00f9b1bdfbda156636baae928
PiperOrigin-RevId: 633221195
2024-05-13 08:35:40 -07:00
jax authors
e66a234be4 Merge pull request #21191 from gnecula:export_simplify
PiperOrigin-RevId: 633179742
2024-05-13 05:48:13 -07:00
jax authors
54ca3d46d3 Merge pull request #21202 from superbobry:pallas
PiperOrigin-RevId: 633176367
2024-05-13 05:30:57 -07:00
jax authors
1fed78499f Merge pull request #20940 from piotrfilipiuk:changelist/623910451
PiperOrigin-RevId: 633170419
2024-05-13 05:03:28 -07:00
George Necula
78d4d0a498 [export] Simplify export internals, prepare for integration with AOT APIs
In preparation for a better integration of jax.experimental.export with
the AOT APIs, we make several simplifications:

  * always turn on the generation of shape assertions in the presence of shape
  polymorphism. Previously, shape assertions were turned on unless the
  serialization version was less than 7 (possible only before March 27th, 2024,
  when the minimum serialization version was bumped to 9), or if the
  user explicitly specified that shape assertions should be turned off. It is
  not safe to turn off shape assertions, and I am not aware of an instance where
  somebody had to turn them off except for temporary debugging. We keep the
  `DisabledSafetyCheck.shape_assertions` API for now, for backwards compatibility,
  but it has no effect and emits a deprecation warning (see the sketch after
  this list).

  * remove the code that was conditional on the serialization version
  being less than 9, e.g., for lowering in the presence of effects.

  * remove a safety check that ensured that when `export` is used on JAX
  callables, i.e., not the result of `jax.jit`, the code does not
  contain non-replicated sharding annotations. This usage of `export` is
  rare and will be removed once `export` is integrated with the AOT
  APIs.

  * remove code that was needed only for older jaxlib versions to
  `replace_tokens_with_dummy`.
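
A minimal sketch of the now-deprecated opt-out path referenced in the first bullet (the `disabled_checks=` parameter name and the exact import path are assumptions about the `jax.experimental.export` API of this era):

```
import jax
import jax.numpy as jnp
from jax.experimental import export  # import path assumed for this era

def f(x):
  return jnp.sin(x)

# Passing DisabledSafetyCheck.shape_assertions() is still accepted for
# backwards compatibility, but after this change it has no effect and only
# emits a deprecation warning; shape assertions always stay on.
exp = export.export(
    jax.jit(f),
    disabled_checks=[export.DisabledSafetyCheck.shape_assertions()],  # assumed kwarg name
)(jax.ShapeDtypeStruct((4,), jnp.float32))
```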
2024-05-13 14:41:51 +03:00
Sergei Lebedev
8094d0d132 Guarded Pallas GPU import in tests/pallas/pallas_test.py
We do not build Triton IR bindings on Windows.

This should fix https://github.com/google/jax/actions/runs/9051189315/job/24867428634.
2024-05-13 12:23:18 +01:00
Sergei Lebedev
1c6855a492 Ensured that all Pallas GPU tests depend on :pallas_gpu
This dependency is added implicitly by Google-internal infra, but we need
it to be explicit for Bazel builds to avoid ImportErrors at lowering time.

PiperOrigin-RevId: 633147268
2024-05-13 03:07:22 -07:00
Jieying Luo
ba8480a212 Register TPU profiler plugin when get_topology_desc is called with the tpu platform.
This allows the TPU profiler to work with other plugin backends.

Tested on a GPU VM:
$ pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$ pip install -e .
$ TPU_SKIP_MDS_QUERY=1 python tests/cross_aot_test.py
Running tests under Python 3.10.12: /usr/bin/python
[ RUN      ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
NOT_FOUND: WARNING: could not determine TPU accelerator type. Set env var `TPU_ACCELERATOR_TYPE` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285

NOT_FOUND: WARNING: could not determine TPU worker number. Set env var `TPU_WORKER_ID` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285

NOT_FOUND: WARNING: could not determine TPU worker hostnames or internal IP addresses. Set env var `TPU_WORKER_HOSTNAMES` to set manually. TPU runtime may not be properly initialized.
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:285
learning/45eac/tfrc/runtime/common_lib.cc:341

I0510 00:32:03.063246 130900437979136 cross_aot_test.py:58] Expected to fail to get topology
I0510 00:32:03.079923 130900437979136 xla_bridge.py:884] Unable to initialize backend 'cuda':
I0510 00:32:03.080080 130900437979136 xla_bridge.py:884] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0510 00:32:03.089399 130900437979136 xla_bridge.py:884] Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No ba16c7433 device found.
W0510 00:32:03.089633 130900437979136 xla_bridge.py:931] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
/home/jieying/.local/lib/python3.10/site-packages/tensorflow/__init__.py:30: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
  import distutils as _distutils
2024-05-10 00:32:03.359597: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-10 00:32:03.359652: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-10 00:32:03.361368: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-10 00:32:04.562557: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[       OK ] JaxAotTest.test_tpu_profiler_registered_get_topology_from_devices
----------------------------------------------------------------------
Ran 1 test in 2.549s

OK

In `tests/cross_aot_test.py`:
class JaxAotTest(jtu.JaxTestCase):
  def test_tpu_profiler_registered_get_topology_from_devices(self):
    try:
      _ = topologies.get_topology_desc(
          topology_name='fake_topology',
          platform='tpu',
      )
    except xla_extension.XlaRuntimeError:
      logging.info('Expected to fail to get topology')

    with tempfile.TemporaryDirectory() as tmpdir:
      try:
        jax.profiler.start_trace(tmpdir)
        jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
            jnp.ones(jax.local_device_count())
        )
      finally:
        jax.profiler.stop_trace()

      proto_path = glob.glob(
          os.path.join(tmpdir, '**/*.xplane.pb'), recursive=True
      )
      self.assertLen(proto_path, 1)
      with open(proto_path[0], 'rb') as f:
        proto = f.read()
      # Sanity check that serialized proto contains host, and Python traces
      # without deserializing.
      self.assertIn(b'/host:metadata', proto)
      if jtu.test_device_matches(['tpu']):
        self.assertNotIn(b'/device:TPU', proto)
      self.assertIn(b'pxla.py', proto)

PiperOrigin-RevId: 633076007
2024-05-12 20:50:41 -07:00
jax authors
af4bddb74d [Mosaic GPU] Correct TileTransform.transform_index()
Previously, TileTransform.transform_index() would transform like so:

x, y, z -> x, y // t_x, z // t_y, 0, 0

But it actually should be:

x, y, z -> x, y // t_x, z // t_y, y % t_x, z % t_y
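
A small self-contained sketch of the corrected index math (the tile sizes `t_x`, `t_y` and the helper name are illustrative, not the actual Mosaic GPU API):

```
def transform_index(idx, tiling):
  # idx = (x, y, z); tiling = (t_x, t_y) tiles the last two dimensions.
  x, y, z = idx
  t_x, t_y = tiling
  # Corrected behavior: keep both the tile indices and the intra-tile
  # offsets instead of padding the trailing positions with zeros.
  return (x, y // t_x, z // t_y, y % t_x, z % t_y)

assert transform_index((1, 5, 7), (4, 4)) == (1, 1, 1, 1, 3)
```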

PiperOrigin-RevId: 633052067
2024-05-12 18:23:25 -07:00
jax authors
dac5b754ee [jax:mosaic-gpu] Test type conversion for tiled fragment array
PiperOrigin-RevId: 633052027
2024-05-12 18:23:07 -07:00
jax authors
c1f184d6ab [Mosaic GPU] Cleanup matmul example mainly by removing the m/n block loops (1 step).
PiperOrigin-RevId: 633051910
2024-05-12 18:20:06 -07:00
jax authors
3e3d916a14 Update XLA dependency to use revision
8dd9fc3539.

PiperOrigin-RevId: 633049132
2024-05-12 18:02:26 -07:00
jax authors
b4f2145337 Update XLA dependency to use revision
d3e881ad66.

PiperOrigin-RevId: 632860505
2024-05-11 18:00:43 -07:00
Adam Paszke
a527b71970 [Mosaic GPU] Prepare for writing warp-specialized kernels
PiperOrigin-RevId: 632854287
2024-05-11 17:09:08 -07:00
Peter Hawkins
49bd4d6f01 Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a
PiperOrigin-RevId: 632818524
2024-05-11 12:24:06 -07:00
piotrfilipiuk
93dfe05aec Implements Ragged Dot API 2024-05-11 06:40:18 -07:00
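A minimal usage sketch for the Ragged Dot API above, assuming it is exposed as `jax.lax.ragged_dot(lhs, rhs, group_sizes)` with a grouped right-hand side:

```
import jax.numpy as jnp
from jax import lax

m, k, n, g = 8, 4, 5, 2
lhs = jnp.ones((m, k))                            # rows grouped along m
rhs = jnp.ones((g, k, n))                         # one (k, n) matrix per group
group_sizes = jnp.array([3, 5], dtype=jnp.int32)  # sums to m

# Each contiguous group of lhs rows is contracted with its own rhs[i].
out = lax.ragged_dot(lhs, rhs, group_sizes)
print(out.shape)  # (8, 5)
```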
Yue Sheng
3b03e5497d Raise a runtime error when trying to convert the jax.Array wrapped by jax.core.Token to a numpy array, as it is an internal implementation detail and the underlying buffer has an XLA token shape.
PiperOrigin-RevId: 632682906
2024-05-10 21:08:06 -07:00
jax authors
20646eb07c Update XLA dependency to use revision
0b3dc68410.

PiperOrigin-RevId: 632655986
2024-05-10 18:19:21 -07:00
Jake VanderPlas
9ac1d38226 Finish jax and jaxlib 0.4.28 release
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
jax authors
979d9ca3e5 Merge pull request #21168 from 8bitmp3:upgrade-sharded--doc
PiperOrigin-RevId: 632648408
2024-05-10 17:44:15 -07:00
Yash Katariya
a4693db6cf Add a jaxpr interpreter for propagating memory kinds to output. It only triggers if we detect multiple memory kinds in the jaxpr.
This hopefully should go away when XLA implements its own memory space propagation pass, or when JAX adds memory_kind to the type system of jaxprs, i.e., on avals.

It's required to treat the following code blocks (1) and (2) as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later.

1. `jit(f, out_shardings=s_host)`

2. ```
   @jax.jit
   def f(x):
     return jax.device_put(x, s_host)
   ```
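
For context, a sketch of what `s_host` could look like; the `memory_kind='pinned_host'` value is an illustrative assumption, and the point is that both spellings should lower to the same StableHLO:

```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))
# A replicated sharding that places results in host memory
# (memory kind name assumed for illustration).
s_host = NamedSharding(mesh, P(), memory_kind='pinned_host')

# (1) and (2) should now lower to the same StableHLO:
f1 = jax.jit(lambda x: x, out_shardings=s_host)

@jax.jit
def f2(x):
  return jax.device_put(x, s_host)
```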

PiperOrigin-RevId: 632621025
2024-05-10 15:34:57 -07:00
Sergei Lebedev
27c932a3a9 Do not import from lowering in tests/pallas/pallas_test.py
This ensures that the test is importable even with a non-GPU jaxlib, which
does not have Triton dialect bindings.

PiperOrigin-RevId: 632603225
2024-05-10 14:25:17 -07:00
8bitmp3
9ea3fcbab3 Upgrade JAX Parallelism Sharded Computation 101 doc 2024-05-10 21:24:16 +00:00
jax authors
17444fc8fa Merge pull request #21174 from hawkinsp:spmm
PiperOrigin-RevId: 632589433
2024-05-10 13:35:04 -07:00
Peter Hawkins
dda428e74a Disable tests that trigger a warning if x64 mode isn't enabled. 2024-05-10 19:58:22 +00:00
jax authors
c3cab2e3d3 Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3
PiperOrigin-RevId: 632579101
2024-05-10 12:56:10 -07:00
jax authors
c231cd51eb Merge pull request #21173 from hawkinsp:precision
PiperOrigin-RevId: 632577567
2024-05-10 12:50:07 -07:00
Peter Hawkins
24b47318bd Force float32 matmuls in examples_test.
This test started failing when we changed our CI to use L4 GPUs. Using
the highest precision resolves the problem.
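
For reference, two standard ways to force full-precision matmuls (a generic sketch; which mechanism examples_test uses is not shown here):

```
import jax
import jax.numpy as jnp

a = jnp.ones((512, 512), jnp.float32)
b = jnp.ones((512, 512), jnp.float32)

# Per-call: request the highest-precision algorithm for this dot.
c = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

# Scoped: force float32 matmul precision for everything in the block.
with jax.default_matmul_precision('float32'):
  d = a @ b
```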
2024-05-10 19:30:02 +00:00
Jieying Luo
0a3e432745 [PJRT C API] Enable PJRT C API runtime in jax2tf dlpack.
GetDefaultLayout added a fallback for the GPU backend, so this is no longer blocked by the fact that the PJRT C API does not support GetDefaultLayout yet.

PiperOrigin-RevId: 632555239
2024-05-10 11:30:37 -07:00
Peter Hawkins
6c425338d2 Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306
PiperOrigin-RevId: 632548226
2024-05-10 11:06:38 -07:00
George Karpenkov
586568f4fe Simplify JAX lowering rules for cumulative sum
Rely on XLA decomposition.

# JAX GPU microbenchmarks

285us for cumsum over 1e8 elements

449us for cumsum over 1e8 elements.

# JAX CPU microbenchmarks:

1.8s vs. 0.7s for 50 iterations over cumsum over 1e7 elements

PiperOrigin-RevId: 632547166
2024-05-10 11:03:28 -07:00
jax authors
13a195589e Merge pull request #21167 from jakevdp:einsum-path-func
PiperOrigin-RevId: 632538144
2024-05-10 10:35:32 -07:00
Justin Fu
ebb918402c Disable bfloat16 on long seq lengths for splash attention kernel test 2024-05-10 10:31:29 -07:00
Yash Katariya
bac3a6fa8f Allow tokens to be passed to jit, flow through dispatch, and be returned from the jitted function.
Fixes https://github.com/google/jax/issues/21160

PiperOrigin-RevId: 632531105
2024-05-10 10:12:48 -07:00
jax authors
0267ed0ba9 Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for the future Hermetic CUDA implementation. Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the content of the runfiles is lost. With `genrule`, the content of the runfiles is preserved.

PiperOrigin-RevId: 632508121
2024-05-10 08:48:12 -07:00
Jake VanderPlas
d07951c592 jnp.einsum_path: improve docs & annotations 2024-05-10 08:39:32 -07:00