From 3a800803927ecef1506ffc614c34db614d0c755b Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Thu, 20 Feb 2025 00:41:04 +0000 Subject: [PATCH 001/100] fix unit tests to not use fmha rewriter --- jax/_src/cudnn/fused_attention_stablehlo.py | 2 +- tests/fused_attention_stablehlo_test.py | 112 +++++++++----------- 2 files changed, 54 insertions(+), 60 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 62f3dafc5..9e8f89e7a 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -333,7 +333,7 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training, is_packed, + query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 59b6458ea..8f417431c 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -16,9 +16,6 @@ from functools import partial from absl.testing import absltest import os -os.environ["XLA_FLAGS"] = \ - "--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true" - import numpy as np import jax import jax.numpy as jnp @@ -30,7 +27,6 @@ from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention, check_is_flash_attention, check_cudnn_version, - get_large_negative_number, MaskType, AttentionLayout, ) @@ -90,6 +86,9 @@ cast_to_representable = partial( quantize = partial(quantize_to_fp8, scale=1) +def get_large_negative_number(dtype): + return 0.7 * jnp.finfo(dtype).min + def sdpa_train(query: Array, key: Array, value: Array, @@ -168,7 +167,7 @@ def sdpa_ref(query: Array, B, T, qN, H = query.shape _, _, kN, _ = key.shape - logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) + logits = jnp.einsum("bqhd,bkhd->bhqk", query, key, preferred_element_type=jnp.float32) if scale != 1.0: logits = logits * scale if mask_type == MaskType.CAUSAL: @@ -182,28 +181,31 @@ def sdpa_ref(query: Array, bias = get_sliding_window_mask(logits, sliding_window_length) if mask is not None: large_negative_number = get_large_negative_number(logits.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) + mask = jnp.where(mask, 0, large_negative_number) + # combine bias and mask if bias is None: bias = mask elif mask is not None: + bias = bias.astype(logits.dtype) bias += mask + # apply bias to logits if bias is not None: if bias.shape != logits.shape: bias = jnp.broadcast_to(bias, logits.shape) logits = logits + bias.astype(logits.dtype) - probs = jax.nn.softmax(logits, axis=-1) + probs = jax.nn.softmax(logits, axis=-1).astype(query.dtype) if dropout_rate > 0.: keep_prob = 1.0 - dropout_rate dropout_rng = jax.random.key(0) keep = jax.random.bernoulli(dropout_rng, keep_prob, probs.shape) probs = jax.lax.select(keep, probs / keep_prob, jnp.zeros_like(probs)) - encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value) + encoded = jnp.einsum("bhqk,bkhd->bqhd", probs, value, preferred_element_type=jnp.float32) if mask_type == MaskType.PADDING: # cuDNN padding mask generation will mask out output accordingly # make sure the behavior is the same encoded_mask = get_encoded_padding_mask(encoded) encoded = encoded * encoded_mask - return encoded + return encoded.astype(query.dtype) def sdpa_train_ref(query: Array, key: Array, @@ -239,7 +241,7 @@ def sdpa_train_fp8( f_p = partial( dot_product_attention, scale=scale, mask_type=mask_type, use_fp8=True ) - return f_p(query, key, value, None, None, None, None, fp8_metas) + return f_p(query, key, value, fp8_params=fp8_metas) out, sdpa_vjp = jax.vjp( dot_product_attention_fp8, query, key, value, fp8_metas @@ -274,7 +276,7 @@ class DotProductAttentionTest(jtu.JaxTestCase): use_mask=[False, True], use_bias=[False, True], mask_type=[MaskType.NO_MASK], - dropout_rate=[0, 0.5], + dropout_rate=[0], scale=[0.5], dtype=[jnp.float16, jnp.bfloat16] ) @@ -351,18 +353,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): jitted_sdpa_train(query, key, value, grad, bias, mask) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ jitted_sdpa_train_ref(query, key, value, grad, bias, mask) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - if seq_len > 512: - # query_grad in flash attention is not deterministic - self.assertArraysAllClose( - query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - else: - self.assertArraysAllClose( - query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) self.assertArraysAllClose( - key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) self.assertArraysAllClose( - value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose( + value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_inference(self): @@ -381,9 +378,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): with Mesh(devices, ("dp", "tp")) as mesh: qkv_spec = PartitionSpec("dp", None, "tp", None) qkv_sharding = NamedSharding(mesh, qkv_spec) - replicated = NamedSharding(mesh, PartitionSpec()) in_shardings = ( - qkv_sharding, qkv_sharding, qkv_sharding, replicated, replicated) + qkv_sharding, qkv_sharding, qkv_sharding) out_shardings = qkv_sharding query = jax.device_put(query, qkv_sharding) key = jax.device_put(key, qkv_sharding) @@ -403,15 +399,14 @@ class DotProductAttentionTest(jtu.JaxTestCase): out_shardings=out_shardings ) - out = jitted_sdpa_inference(query, key, value, None, None) - out_ref = jitted_sdpa_inference_ref(query, key, value, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + out = jitted_sdpa_inference(query, key, value) + out_ref = jitted_sdpa_inference_ref(query, key, value) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) @jtu.run_on_devices("cuda") def test_sdpa_var_seq(self): if jax.device_count() < 4: self.skipTest("Requires more than 4 devices.") - self.skipTest("Skip before fixed.") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) @@ -432,13 +427,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad) = \ - jitted_sdpa_train(query, key, value, grad, None, None) + jitted_sdpa_train(query, key, value, grad) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_broadcast_bias_and_dbias(self): @@ -472,9 +467,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): qkv_sharding = NamedSharding(mesh, qkv_spec) bias_spec = PartitionSpec("tp", None, None) bias_sharding = NamedSharding(mesh, bias_spec) - replicated = NamedSharding(mesh, PartitionSpec()) in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, - qkv_sharding, bias_sharding, replicated) + qkv_sharding, bias_sharding) out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding)) query = jax.device_put(query, qkv_sharding) key = jax.device_put(key, qkv_sharding) @@ -496,14 +490,14 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad, bias_grad) = \ - jitted_sdpa_train(query, key, value, grad, bias, None) + jitted_sdpa_train(query, key, value, grad, bias) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, bias, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad, bias) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=2e-1, atol=2e-1) @jtu.sample_product( batch_size=[1, 16], @@ -573,13 +567,13 @@ class DotProductAttentionTest(jtu.JaxTestCase): ) out, (query_grad, key_grad, value_grad) = \ - jitted_sdpa_train(query, key, value, grad, None, None) + jitted_sdpa_train(query, key, value, grad) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ - jitted_sdpa_train_ref(query, key, value, grad, None, None) - self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) - self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + jitted_sdpa_train_ref(query, key, value, grad) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_large_head_size(self): @@ -607,12 +601,12 @@ class DotProductAttentionTest(jtu.JaxTestCase): sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0) ) - out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None) - out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None) + out_ans, grads_ans = sdpa_train_ans(query, key, value, grad) + out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad) self.assertArraysAllClose(out_ref, out_ans) - self.assertArraysAllClose(grads_ref[0], grads_ans[0]) - self.assertArraysAllClose(grads_ref[1], grads_ans[1]) - self.assertArraysAllClose(grads_ref[2], grads_ans[2]) + self.assertArraysAllClose(grads_ref[0], grads_ans[0], rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(grads_ref[1], grads_ans[1], rtol=2e-1, atol=2e-1) + self.assertArraysAllClose(grads_ref[2], grads_ans[2], rtol=2e-1, atol=2e-1) @jtu.run_on_devices("cuda") def test_sdpa_packed_layout(self): @@ -679,7 +673,7 @@ class DotProductAttentionTest(jtu.JaxTestCase): kv_seqlen = q_seqlen.copy() mask = generate_padding_mask(segment_ids, q_seqlen.shape[1], query.shape, query.dtype) - bias = generate_segment_mask(segment_ids, query.dtype) + bias = generate_segment_mask(segment_ids, jnp.float32) devices = np.array(jax.local_devices()[:4]) devices = devices.reshape((2, 2)) @@ -757,8 +751,8 @@ class DotProductAttentionTest(jtu.JaxTestCase): value = jax.random.normal(k2, (B, S, N, H), dtype=dtype) grad = jax.random.normal(k3, (B, T, N, H), dtype=dtype) - btnh_fn = jax.jit(partial(sdpa_train_ref, scale=.5, - mask_type=MaskType.CAUSAL, dropout_rate=0.0)) + btnh_fn = jax.jit(partial(sdpa_train, scale=.5, + mask_type=MaskType.CAUSAL, is_bnth=False, dropout_rate=0.0)) out_ref, (dq_ref, dk_ref, dv_ref) = btnh_fn(query, key, value, grad) def _cvt(x): @@ -811,7 +805,7 @@ class DotProductAttentionF8Test(jtu.JaxTestCase): except RuntimeError as e: self.skipTest(str(e)) return - if cudnn_version < 91000: + if cudnn_version < 90100: self.skipTest("Requires >= cuDNN 9.1.0") if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Requires at least Hopper arch") @@ -877,7 +871,7 @@ class DotProductAttentionF8Test(jtu.JaxTestCase): fp8_metas, ) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = ( - jitted_sdpa_train_ref(query, key, value, grad, None, None) + jitted_sdpa_train_ref(query, key, value, grad) ) self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-1, atol=5e-1) @@ -938,7 +932,7 @@ class DotProductAttentionF8Test(jtu.JaxTestCase): qkv_layout=qkv_layout, use_fp8=True, ) - return f_p(query, key, value, None, None, None, None, fp8_metas) + return f_p(query, key, value, fp8_params=fp8_metas) jitted_sdpa_inference = jax.jit( dot_product_attention_fp8, From 91cae595e427969251a79d1ab3d6d5392dd8e6a9 Mon Sep 17 00:00:00 2001 From: "H. Vetinari" Date: Sat, 22 Feb 2025 16:39:41 +1100 Subject: [PATCH 002/100] fix member access to packed CUDA struct --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 312605df8..2f415912f 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -182,8 +182,9 @@ void callback_complete(CUcontext context, uint32_t streamId, // Convert integer nanoseconds to floating point milliseconds to match // the interface of the events-based profiler. double duration_ms = (kernel->end - kernel->start) / 1e6; + const char* kernel_name = kernel->name; profiler_state.timings.push_back( - std::make_tuple(kernel->name, duration_ms)); + std::make_tuple(kernel_name, duration_ms)); } } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { // no more records available From 5acfc88a0034ea134547cc04ff9067c858558b4f Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Wed, 26 Feb 2025 14:25:15 +0100 Subject: [PATCH 003/100] fix Initializer protocol --- jax/_src/nn/initializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 660d18086..87f90a353 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -46,8 +46,8 @@ RealNumeric = Any # Scalar jnp array or float @export @typing.runtime_checkable class Initializer(Protocol): - @staticmethod - def __call__(key: Array, + def __call__(self, + key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: raise NotImplementedError From 56285aec6b7e6d41efd99544467acfd7033b6576 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Fri, 28 Feb 2025 06:17:14 +0000 Subject: [PATCH 004/100] Fixed printing order of results in jax.debug.print documentation. --- docs/debugging/print_breakpoint.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 73ac02628..85580120c 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -91,8 +91,8 @@ def f(x): jax.debug.print("x: {}", x) return x jax.pmap(f)(xs) -# Prints: x: 1.0 -# x: 0.0 +# Prints: x: 0.0 +# x: 1.0 # OR # Prints: x: 1.0 # x: 0.0 From 9c18e8dcc12300fdf64a49b9f4e31029be437322 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 28 Feb 2025 12:32:00 +0530 Subject: [PATCH 005/100] Remove duplicate JAX version 0.4.37 heading in changelog --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5a01b780..f05936644 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -169,8 +169,6 @@ to signify this. This is a patch release of jax 0.4.36. Only "jax" was released at this version. -## jax 0.4.37 - * Bug fixes * Fixed a bug where `jit` would error if an argument was named `f` (#25329). * Fix a bug that will throw `index out of range` error in From 8b1b039e0dab7d391b5024496880b08656caa299 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 4 Mar 2025 10:31:35 -0500 Subject: [PATCH 006/100] Improve error messages when input argument resolution fails in custom_* APIs. --- jax/_src/api_util.py | 5 +- jax/_src/custom_batching.py | 10 +++- jax/_src/custom_dce.py | 10 +++- jax/_src/custom_derivatives.py | 21 +++++++- tests/api_test.py | 96 ++++++++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+), 6 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1fd371034..b9fd505e3 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -564,8 +564,9 @@ def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]: passed_kwargs = [k for k in ba.kwargs if k in kwargs] if passed_kwargs: raise TypeError( - f"keyword arguments ({passed_kwargs}) could not be resolved to " - "positions") + "The following keyword arguments could not be resolved to positions: " + f"{', '.join(passed_kwargs)}" + ) return ba.args diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index a6ca6479c..338074837 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -143,7 +143,15 @@ class custom_vmap: def __call__(self, *args, **kwargs): debug_fun = api_util.debug_info("custom_vmap fun", self.fun, args, kwargs) - args = api_util.resolve_kwargs(self.fun, args, kwargs) + try: + args = api_util.resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_vmap-decorated function " + f"{debug_fun.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + if not self.vmap_rule: raise AttributeError( f"No batching rule defined for custom_vmap function {debug_fun.func_name} " diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 9166965b5..d336c969a 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -133,7 +133,15 @@ class custom_dce: debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule, args, {}, static_argnums=self.static_argnums) - args = api_util.resolve_kwargs(self.fun, args, kwargs) + try: + args = api_util.resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_dce-decorated function " + f"{debug.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + if self.static_argnums: static_argnums = set(self.static_argnums) for i in static_argnums: diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 1cea84110..2bcdb7b5c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -250,7 +250,15 @@ class custom_jvp(Generic[ReturnValue]): msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) - args = resolve_kwargs(self.fun, args, kwargs) + try: + args = resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_jvp-decorated function " + f"{primal_name} could not be resolved to positional-only arguments. " + f"Binding failed with the error:\n{e}" + ) from e + if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -634,7 +642,16 @@ class custom_vjp(Generic[ReturnValue]): if not self.fwd or not self.bwd: msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp." raise AttributeError(msg) - args = resolve_kwargs(self.fun, args, kwargs) + + try: + args = resolve_kwargs(self.fun, args, kwargs) + except TypeError as e: + raise TypeError( + "The input arguments to the custom_vjp-decorated function " + f"{debug_fun.func_name} could not be resolved to positional-only " + f"arguments. Binding failed with the error:\n{e}" + ) from e + debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs, static_argnums=self.nondiff_argnums) # TODO(necula): figure out how to construct the debug_bwd args diff --git a/tests/api_test.py b/tests/api_test.py index e8fef8011..543335529 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8104,6 +8104,29 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose( api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) + def test_resolve_kwargs_error_message(self): + @jax.custom_jvp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomVJPTest(jtu.JaxTestCase): @@ -9762,6 +9785,33 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose( api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) + def test_resolve_kwargs_error_message(self): + @jax.custom_vjp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + self.fail("should not be executed") + + def f_bwd(res, cts): + self.fail("should not be executed") + + f.defvjp(f_fwd, f_bwd) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + def transpose_unary(f, x_example): def transposed(y): @@ -10490,6 +10540,29 @@ class CustomDceTest(jtu.JaxTestCase): self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) + def test_resolve_kwargs_error_message(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, *, z=None): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomVmapTest(jtu.JaxTestCase): @@ -11115,6 +11188,29 @@ class CustomVmapTest(jtu.JaxTestCase): out, f_vjp = jax.vjp(f, xs, y) f_vjp(out) # Doesn't crash. + def test_resolve_kwargs_error_message(self): + @jax.custom_batching.custom_vmap + def f(x, y, *, z=None): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" From 52ab8c4cc2920a000e7a00a7ca56f0a479604c40 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Mon, 3 Mar 2025 18:37:45 -0500 Subject: [PATCH 007/100] Fix detection of epath Unfortunately, the old detection code doesn't guarantee that `epath` is installed: ``` [utM] In [7]: importlib.util.find_spec("etils.epath") Out[7]: ModuleSpec(name='etils.epath', loader=<_frozen_importlib_external.SourceFileLoader object at 0x73b8492a7230>, origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py', submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath']) [utM] In [8]: import etils.epath --------------------------------------------------------------------------- ModuleNotFoundError Traceback (most recent call last) Cell In[8], line 1 ----> 1 import etils.epath ... ModuleNotFoundError: No module named 'importlib_resources' ``` This happened every time I ran jax with a clean environment. --- jax/_src/compilation_cache_interface.py | 2 +- jax/_src/path.py | 41 ++++++++++++------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py index 480457871..e0241d54b 100644 --- a/jax/_src/compilation_cache_interface.py +++ b/jax/_src/compilation_cache_interface.py @@ -15,8 +15,8 @@ from __future__ import annotations import abc +import pathlib -from jax._src import path as pathlib from jax._src import util diff --git a/jax/_src/path.py b/jax/_src/path.py index 8c46c5560..03a15e42e 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -12,35 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Protocol import logging +import os import pathlib -import importlib.util __all__ = ["Path"] - logger = logging.getLogger(__name__) +epath_installed: bool + +class PathProtocol(Protocol): + """A factory that creates a PurePath.""" + def __call__(self, *pathsegments: str | os.PathLike) -> pathlib.Path: + ... + +Path: PathProtocol # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin # pathlib and can only read/write to the local filesystem. -epath_installed = bool( - importlib.util.find_spec("etils") and - importlib.util.find_spec("etils.epath") -) -if epath_installed: - logger.debug("etils.epath found. Using etils.epath for file I/O.") - - def __dir__(): - return ["Path"] - - def __getattr__(name): - if name != "Path": - raise AttributeError(f"module '{__name__}' has no attribute '{name}") - - global Path - from etils import epath - Path = epath.Path - return Path -else: +try: + from etils import epath # type: ignore +except ImportError: logger.debug("etils.epath was not found. Using pathlib for file I/O.") Path = pathlib.Path + epath_installed = False +else: + logger.debug("etils.epath found. Using etils.epath for file I/O.") + # Ultimately, epath.Path implements pathlib.Path. See: + # https://github.com/google/etils/blob/2083f3d932a88d8a135ef57112cd1f9aff5d559e/etils/epath/abstract_path.py#L47 + Path = epath.Path + epath_installed = True From f0bbd26d0347423030d1d370bf91ff9c0a30a03b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 4 Mar 2025 10:17:51 -0800 Subject: [PATCH 008/100] Update array-api-tests to latest commit --- .github/workflows/jax-array-api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 3600ad134..2b97c5a05 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'd982a6245400295477f5da5afa1c4a2a5e641ea4' # Latest commit as of 2025-01-30 + ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} From 8af6f70fe0b3f375f41323cc60aadd347f1d2209 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 4 Mar 2025 10:32:32 -0800 Subject: [PATCH 009/100] [JAX] Disable msan and asan for the profiler test running on nvidia gpu PiperOrigin-RevId: 733380848 --- tests/BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index d5d6c7be5..0ffa68ed8 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -838,6 +838,13 @@ jax_multiplatform_test( jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], + backend_tags = { + "gpu": [ + # disable suspicious leaking in cupti/cuda, + # TODO: remove this once b/372714955 is resolved. + "noasan", + ], + }, enable_backends = [ "cpu", "gpu", From 8cec6e636ad8de654830b52c123d9f6c97cc69b1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 20 Feb 2025 12:31:31 -0800 Subject: [PATCH 010/100] jax.numpy ndim/shape/size: deprecate non-array input --- CHANGELOG.md | 2 + jax/_src/numpy/util.py | 125 +++++++++++++++++++++++++++++++++++++++- jax/numpy/__init__.py | 9 ++- jax/numpy/__init__.pyi | 6 +- tests/lax_numpy_test.py | 36 ++++++++++++ 5 files changed, 171 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5a01b780..86d0cab0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. A downstream effect of this several other internal functions need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail. + * In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`, + non-arraylike inputs (such as lists, tuples, etc.) are now deprecated. * Bug fixes * TPU runtime startup and shutdown time should be significantly improved on diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 1db2e0bde..e281c63ae 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -26,7 +26,7 @@ from jax._src import dtypes from jax._src.lax import lax from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding -from jax._src.util import safe_zip, safe_map +from jax._src.util import safe_zip, safe_map, set_module from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax.sharding import Sharding @@ -35,6 +35,8 @@ import numpy as np zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map +export = set_module('jax.numpy') + _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: @@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin return SingleDeviceSharding(device) else: return device + + +@export +def ndim(a: ArrayLike) -> int: + """Return the number of dimensions of an array. + + JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object. + + Returns: + An integer specifying the number of dimensions of ``a``. + + Examples: + Number of dimensions for arrays: + + >>> x = jnp.arange(10) + >>> jnp.ndim(x) + 1 + >>> y = jnp.ones((2, 3)) + >>> jnp.ndim(y) + 2 + + This also works for scalars: + + >>> jnp.ndim(3.14) + 0 + + For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property: + + >>> x.ndim + 1 + """ + # Deprecation warning added 2025-2-20. + check_arraylike("ndim", a, emit_warning=True) + return np.ndim(a) # NumPy dispatches to a.ndim if available. + + +@export +def shape(a: ArrayLike) -> tuple[int, ...]: + """Return the shape an array. + + JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object. + + Returns: + An tuple of integers representing the shape of ``a``. + + Examples: + Shape for arrays: + + >>> x = jnp.arange(10) + >>> jnp.shape(x) + (10,) + >>> y = jnp.ones((2, 3)) + >>> jnp.shape(y) + (2, 3) + + This also works for scalars: + + >>> jnp.shape(3.14) + () + + For arrays, this can also be accessed via the :attr:`jax.Array.shape` property: + + >>> x.shape + (10,) + """ + # Deprecation warning added 2025-2-20. + check_arraylike("shape", a, emit_warning=True) + return np.shape(a) # NumPy dispatches to a.shape if available. + + +@export +def size(a: ArrayLike, axis: int | None = None) -> int: + """Return number of elements along a given axis. + + JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object + axis: optional integer along which to count elements. By default, return + the total number of elements. + + Returns: + An integer specifying the number of elements in ``a``. + + Examples: + Size for arrays: + + >>> x = jnp.arange(10) + >>> jnp.size(x) + 10 + >>> y = jnp.ones((2, 3)) + >>> jnp.size(y) + 6 + >>> jnp.size(y, axis=1) + 3 + + This also works for scalars: + + >>> jnp.size(3.14) + 1 + + For arrays, this can also be accessed via the :attr:`jax.Array.size` property: + + >>> y.size + 6 + """ + # Deprecation warning added 2025-2-20. + check_arraylike("size", a, emit_warning=True) + return np.size(a, axis=axis) # NumPy dispatches to a.size if available. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d563483a2..ad71b9f74 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import ( vdot as vdot, ) +from jax._src.numpy.util import ( + ndim as ndim, + shape as shape, + size as size, +) + from jax._src.numpy.window_functions import ( bartlett as bartlett, blackman as blackman, @@ -279,15 +285,12 @@ from numpy import ( integer as integer, iterable as iterable, nan as nan, - ndim as ndim, newaxis as newaxis, number as number, object_ as object_, pi as pi, save as save, savez as savez, - shape as shape, - size as size, signedinteger as signedinteger, unsignedinteger as unsignedinteger, ) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index dee61c145..b73a3b95b 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -ndim = _np.ndim +def ndim(a: ArrayLike) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -842,7 +842,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -shape = _np.shape +def shape(a: ArrayLike) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -size = _np.size +def size(a: ArrayLike, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 50773e23b..98f10d9c0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, check_dtypes=False) + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes, + op=['ndim', 'shape', 'size'], + ) + def testNdimShapeSize(self, shape, dtype, op): + rng = jtu.rand_default(self.rng()) + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + x = rng(shape, dtype) + expected = np_op(x) + self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input. + self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input. + self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input. + + @jtu.sample_product( + shape=nonzerodim_shapes, + dtype=default_dtypes, + ) + def testSizeAlongAxis(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + axis = self.rng().randint(-len(shape), len(shape)) + np_op = partial(np.size, axis=axis) + jnp_op = partial(jnp.size, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + op=[jnp.ndim, jnp.shape, jnp.size], + ) + def testNdimShapeSizeNonArrayInput(self, op): + msg = f"{op.__name__} requires ndarray or scalar arguments" + with self.assertWarnsRegex(DeprecationWarning, msg): + op([1, 2, 3]) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest. From d112c85e6d811c8158e34882e31c73f796709589 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Tue, 4 Mar 2025 11:17:06 -0800 Subject: [PATCH 011/100] Internal config change PiperOrigin-RevId: 733398579 --- .bazelrc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.bazelrc b/.bazelrc index 8f9f910c0..f86af3a9b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -196,6 +196,9 @@ build:public_cache_push --config=public_cache --remote_upload_local_results=true # "oct2023" in the URL is just the date when the bucket was created and can be # disregarded. It still contains the latest cache that is being used. build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false +# This flag is to address mac arm64 nightly build failures, which are believed +# to be caused by cache poisoning after the Bazel 7.4.0 toolchain upgrade. +build:macos_cache --remote_default_platform_properties='properties:{name:"cache-silo-key" value:"cache-poisoning-2025-03-03"}' # Cache pushes are limited to JAX's CI system. build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials From 43b6be0e81cc1add64f9c0ee0e7916ec1458a74f Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Tue, 4 Mar 2025 11:50:11 -0800 Subject: [PATCH 012/100] [Mosaic GPU] Add lowering for `log`, and a fast path using log2. PiperOrigin-RevId: 733411276 --- jax/_src/pallas/mosaic_gpu/lowering.py | 7 +++++++ .../mosaic/gpu/fragmented_array.py | 18 +++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 20 +++++++++++++------ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dcdfe62bb..e7331d2d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1386,6 +1386,13 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): return a.exp2(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) +def _log_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + a = _ensure_fa(x, x_aval.dtype) + return a.log(approx=ctx.module_ctx.approx_math) + + @register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2ade6e848..a6e325bfb 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -966,6 +966,24 @@ class FragmentedArray: return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) return self._pointwise(mlir_math.exp2) + def log(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx: + dtype = self.mlir_dtype + ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453)) + return self.log2(approx=True) * ln2 + return self._pointwise(mlir_math.log) + + def log2(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + if approx: + if not ir.F32Type.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32")) + return self._pointwise(mlir_math.log2) + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7dae608a4..74f9cc617 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -88,6 +88,7 @@ class PallasCallTest(PallasTest): ("square", lambda x: x ** 2), ("rsqrt", jax.lax.rsqrt), ("tanh", jax.lax.tanh, 1e-6), + ("log", jax.lax.log) ) def test_unary_op(self, unary, rtol=1e-7): @functools.partial( @@ -641,18 +642,25 @@ class PallasCallTest(PallasTest): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.parameters(False, True) - def test_rsqrt(self, approx_math): + @parameterized.named_parameters( + ("rsqrt", jax.lax.rsqrt, ), + ("log", jax.lax.log, 5e-7), + ("exp", jax.lax.exp, ), + ("exp2", jax.lax.exp2, 5e-7), + ("logistic", jax.lax.logistic, ), + ("tanh", jax.lax.tanh, 5e-7), + ) + def test_approx_math_unary_op(self, unary_op, rtol=1e-7): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), ) def kernel(x_ref, o_ref): - o_ref[...] = jax.lax.rsqrt(x_ref[...]) + o_ref[...] = unary_op(x_ref[...]) - x = jnp.arange(128).astype(jnp.float32) - np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x)) + x = jnp.arange(128).astype(jnp.float32) / 128 + np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): From bb80a56898f10fc7074fe9633d848c6f5decb4cb Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 27 Feb 2025 17:22:29 -0800 Subject: [PATCH 013/100] Update setup.py to automatically pick up libtpu patch releases --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fa57bfa8c..cd283e5e7 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ _current_jaxlib_version = '0.5.1' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.5.0' -_libtpu_version = '0.0.10' +_libtpu_version = '0.0.10.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From ce224293b1a7d9b39b5d9194d429b54f38faf6fe Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 4 Mar 2025 12:50:43 -0800 Subject: [PATCH 014/100] Prepare for JAX release 0.5.2 (patch release over 0.5.1) --- CHANGELOG.md | 7 +++++++ jax/version.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5c1903d1..07910c638 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,13 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> +## jax 0.5.2 (Mar 4, 2025) + +Patch release of 0.5.1 + +* Bug fixes + * Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1 + ## jax 0.5.1 (Feb 24, 2025) * New Features diff --git a/jax/version.py b/jax/version.py index 1aa049434..616950577 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import os import pathlib import subprocess -_version = "0.5.1" +_version = "0.5.2" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None From 721d1a32113f87454eb0df6e515926226fbd8691 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 4 Mar 2025 14:20:27 -0800 Subject: [PATCH 015/100] Add functionality to allow promoting RC wheels during release List of changes: 1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule. 2. Change the upload script to upload both rc and release tagged wheels (changes internal) PiperOrigin-RevId: 733464219 --- build/build.py | 75 +++++++++++++------ build/tools/utils.py | 10 ++- ci/build_artifacts.sh | 96 ++++++++++++++++--------- ci/envs/default.env | 9 +++ ci/utilities/setup_build_environment.sh | 3 + 5 files changed, 135 insertions(+), 58 deletions(-) diff --git a/build/build.py b/build/build.py index 0df7d646f..d38b911bb 100755 --- a/build/build.py +++ b/build/build.py @@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = { "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", } +# Dictionary with the new wheel build rule. Note that when JAX migrates to the +# new wheel build rule fully, the build CLI will switch to the new wheel build +# rule as the default. +WHEEL_BUILD_TARGET_DICT_NEW = { + "jax": "//:jax_wheel", + "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", +} def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" @@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--use_new_wheel_build_rule", + action="store_true", + help= + """ + Whether to use the new wheel build rule. Temporary flag and will be + removed once JAX migrates to the new wheel build rule fully. + """, + ) + parser.add_argument( "--editable", action="store_true", @@ -386,7 +407,10 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - bazel_command_base.append("run") + if not args.use_new_wheel_build_rule or args.command == "requirements_update": + bazel_command_base.append("run") + else: + bazel_command_base.append("build") if args.python_version: # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version @@ -592,13 +616,19 @@ async def main(): wheel_build_command_base.append("--config=cuda_libraries_from_stubs") with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list()) + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) if not jax_configure_options: logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") sys.exit(1) f.write(jax_configure_options) logging.info("Bazel options written to .jax_configure.bazelrc") + if args.use_new_wheel_build_rule: + logging.info("Using new wheel build rule") + wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW + else: + wheel_build_targets = WHEEL_BUILD_TARGET_DICT + if args.configure_only: logging.info("--configure_only is set so not running any Bazel commands.") else: @@ -611,7 +641,7 @@ async def main(): if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: wheel = "jax-" + wheel - if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + if wheel not in wheel_build_targets.keys(): logging.error( "Incorrect wheel name provided, valid choices are jaxlib," " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," @@ -629,32 +659,33 @@ async def main(): ) # Append the build target to the Bazel command. - build_target = WHEEL_BUILD_TARGET_DICT[wheel] + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) - wheel_build_command.append("--") + if not args.use_new_wheel_build_rule: + wheel_build_command.append("--") - if args.editable: - logger.info("Building an editable build") - output_path = os.path.join(output_path, wheel) - wheel_build_command.append("--editable") + if args.editable: + logger.info("Building an editable build") + output_path = os.path.join(output_path, wheel) + wheel_build_command.append("--editable") - wheel_build_command.append(f'--output_path="{output_path}"') - wheel_build_command.append(f"--cpu={target_cpu}") + wheel_build_command.append(f'--output_path="{output_path}"') + wheel_build_command.append(f"--cpu={target_cpu}") - if "cuda" in wheel: - wheel_build_command.append("--enable-cuda=True") - if args.cuda_version: - cuda_major_version = args.cuda_version.split(".")[0] - else: - cuda_major_version = args.cuda_major_version - wheel_build_command.append(f"--platform_version={cuda_major_version}") + if "cuda" in wheel: + wheel_build_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + wheel_build_command.append(f"--platform_version={cuda_major_version}") - if "rocm" in wheel: - wheel_build_command.append("--enable-rocm=True") - wheel_build_command.append(f"--platform_version={args.rocm_version}") + if "rocm" in wheel: + wheel_build_command.append("--enable-rocm=True") + wheel_build_command.append(f"--platform_version={args.rocm_version}") - wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") + wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) # Exit with error if any wheel build fails. diff --git a/build/tools/utils.py b/build/tools/utils.py index e91b2d424..7e3751698 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str): return major_version -def get_jax_configure_bazel_options(bazel_command: list[str]): +def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool): """Returns the bazel options to be written to .jax_configure.bazelrc.""" # Get the index of the "run" parameter. Build options will come after "run" so - # we find the index of "run" and filter everything after it. - start = bazel_command.index("run") + # we find the index of "run" and filter everything after it. If we are using + # the new wheel build rule, we will find the index of "build" instead. + if use_new_wheel_build_rule: + start = bazel_command.index("build") + else: + start = bazel_command.index("run") jax_configure_bazel_options = "" try: for i in range(start + 1, len(bazel_command)): diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 3cc1fa0c5..84b8d35a2 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then arch="amd64" fi +# Determine the artifact tag flags based on the artifact type. A release +# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is +# tagged with the release version and a nightly suffix that contains the +# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with +# the git commit hash of the HEAD of the current branch and the date of the +# commit (e.g. 0.5.1.dev20250128+3e75e20c7). +if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release" +elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then + current_date=$(date +%Y%m%d) + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly" +elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then + artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)" +else + echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default" + exit 1 +fi + if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then + # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" + # flags in the .bazelrc depending upon the platform we are building for. + bazelrc_config="${os}_${arch}" - # Build the jax artifact - if [[ "$artifact" == "jax" ]]; then - python -m build --outdir $JAXCI_OUTPUT_DIR + # On platforms with no RBE support, we can use the Bazel remote cache. Set + # it to be empty by default to avoid unbound variable errors. + bazel_remote_cache="" + + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then + bazelrc_config="rbe_${bazelrc_config}" else + bazelrc_config="ci_${bazelrc_config}" - # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" - # flags in the .bazelrc depending upon the platform we are building for. - bazelrc_config="${os}_${arch}" - - # On platforms with no RBE support, we can use the Bazel remote cache. Set - # it to be empty by default to avoid unbound variable errors. - bazel_remote_cache="" - - if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then - bazelrc_config="rbe_${bazelrc_config}" + # Set remote cache flags. Pushes to the cache bucket is limited to JAX's + # CI system. + if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then + bazel_remote_cache="--bazel_options=--config=public_cache_push" else - bazelrc_config="ci_${bazelrc_config}" - - # Set remote cache flags. Pushes to the cache bucket is limited to JAX's - # CI system. - if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then - bazel_remote_cache="--bazel_options=--config=public_cache_push" - else - bazel_remote_cache="--bazel_options=--config=public_cache" - fi + bazel_remote_cache="--bazel_options=--config=public_cache" fi + fi - # Use the "_cuda" configs when building the CUDA artifacts. - if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then - bazelrc_config="${bazelrc_config}_cuda" - fi + # Use the "_cuda" configs when building the CUDA artifacts. + if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then + bazelrc_config="${bazelrc_config}_cuda" + fi - # Build the artifact. + # Build the artifact. + python build/build.py build --wheels="$artifact" \ + --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ + --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $artifact_tag_flags + + # If building release artifacts, we also build a release candidate ("rc") + # tagged wheel. + if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ - --verbose --detailed_timestamped_log + --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" + fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we - # run `auditwheel show` to verify manylinux compliance. - if [[ "$os" == "linux" ]]; then - ./ci/utilities/run_auditwheel.sh - fi + # Move the built artifacts from the Bazel cache directory to the output + # directory. + if [[ "$artifact" == "jax" ]]; then + mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" + mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" + else + mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" + fi + # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we + # run `auditwheel show` to verify manylinux compliance. + if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then + ./ci/utilities/run_auditwheel.sh fi else diff --git a/ci/envs/default.env b/ci/envs/default.env index 72646113e..66578efac 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} # flag is enabled only for CI builds. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} +# Type of artifacts to build. Valid values are "default", "release", "nightly". +# This affects the wheel naming/tag. +export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} + +# When building release artifacts, we build a release candidate wheel ("rc" +# tagged wheel) in addition to the release wheel. This environment variable +# sets the version of the release candidate ("RC") artifact to build. +export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} + # ############################################################################# # Test script specific environment variables. # ############################################################################# diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 8e4039414..114acf247 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -98,3 +98,6 @@ function retry { # Retry "bazel --version" 3 times to avoid flakiness when downloading bazel. retry "bazel --version" + +# Create the output directory if it doesn't exist. +mkdir -p "$JAXCI_OUTPUT_DIR" \ No newline at end of file From 1a19d5594a906e3fed0357043ffdc9f526d2b021 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 4 Mar 2025 15:54:40 -0800 Subject: [PATCH 016/100] Update all uses of `@tsl//third_party` to `@xla//third_party` PiperOrigin-RevId: 733495240 --- BUILD.bazel | 2 +- WORKSPACE | 12 ++++++------ jaxlib/tools/BUILD.bazel | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 617e39e73..441f689e3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", "jax_wheel", diff --git a/WORKSPACE b/WORKSPACE index 8c4f49ecf..129488281 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ jax_python_wheel_repository( ) load( - "@tsl//third_party/py:python_wheel.bzl", + "@xla//third_party/py:python_wheel.bzl", "python_wheel_version_suffix_repository", ) python_wheel_version_suffix_repository( @@ -78,7 +78,7 @@ python_wheel_version_suffix_repository( ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) @@ -90,7 +90,7 @@ load( "CUDNN_REDISTRIBUTIONS", ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) @@ -104,21 +104,21 @@ cudnn_redist_init_repository( ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( - "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( - "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "@xla//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", ) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index b95483b22..6eab64823 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -19,7 +19,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( - "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "@xla//third_party/py:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test", ) load( From cdeeacabcf2e32d4ced4cf278b26d08f4d342bfd Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Tue, 4 Mar 2025 18:30:34 -0800 Subject: [PATCH 017/100] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 733536104 --- docs/about.md | 2 +- jax/_src/custom_derivatives.py | 2 +- jax/experimental/colocated_python/__init__.py | 2 +- tests/lax_control_flow_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/about.md b/docs/about.md index c4bc93140..58e170384 100644 --- a/docs/about.md +++ b/docs/about.md @@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly, and elsewhere. At the heart of the project is the [JAX -core](http://github.com/google/jax) library, which focuses on the +core](http://github.com/jax-ml/jax) library, which focuses on the fundamentals of machine learning and numerical computing, at scale. When [developing](#development) the core, we want to maintain agility diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2bcdb7b5c..32856106a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1255,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]: def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. return False diff --git a/jax/experimental/colocated_python/__init__.py b/jax/experimental/colocated_python/__init__.py index 5bd56e732..2d387b37c 100644 --- a/jax/experimental/colocated_python/__init__.py +++ b/jax/experimental/colocated_python/__init__.py @@ -14,7 +14,7 @@ """Colocated Python API.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 # pylint: disable=useless-import-alias from jax.experimental.colocated_python.api import ( diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 4c61426c6..15fc37805 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2120,7 +2120,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash def testIssue804(self): - # https://github.com/google/jax/issues/804 + # https://github.com/jax-ml/jax/issues/804 num_devices = jax.device_count() f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash From 766315f79150c008c1b71184ca04398e96a1acc9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 4 Mar 2025 18:34:34 -0800 Subject: [PATCH 018/100] Make sure concat + vmap of sharded input and replicated input works properly. In this case, the example boils down to: ``` inp1 = f32[16@x, 4] inp2 = f32[4] def f(x: f32[4], y: f32[4]) return jnp.concat([x, y], axis=-1) vmap(f, in_axes=(0, None))(inp1) ``` This example was breaking in concat batching rule because we didn't broadcast with the right sharding. PiperOrigin-RevId: 733536944 --- jax/_src/lax/lax.py | 6 +++++- tests/pjit_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index adc632a82..5dc5abdf9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5802,8 +5802,12 @@ def _concatenate_transpose_rule(t, *operands, dimension): def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) if bdim is not None) + spec = next(core.get_aval(op).sharding.spec[bdim] + for op, bdim in zip(batched_args, batch_dims) if bdim is not None) operands = [batching.moveaxis(op, bdim, 0) if bdim is not None - else broadcast(op, (size,)) + else broadcast( + op, (size,), out_sharding=core.get_aval(op).sharding.with_spec( + (spec, *core.get_aval(op).sharding.spec))) for op, bdim in zip(batched_args, batch_dims)] return concatenate(operands, dimension + 1), 0 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0cb3f9d28..7a7f6b7d6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6240,6 +6240,30 @@ class ShardingInTypesTest(jtu.JaxTestCase): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) + @jtu.with_user_mesh((4,), ('x',)) + def test_concat_vmap(self, mesh): + @jax.jit + def _f(sharded_array, replicated_array): + def _single_array(a, b): + return jnp.concatenate([a, b], axis=-1) + + _first_vmap = jax.vmap(_single_array, in_axes=(None, 0)) + _second_vmap = jax.vmap(_first_vmap, in_axes=(0, None)) + return jax.vmap(_second_vmap, in_axes=(0, None))(sharded_array, replicated_array) + + np_inp = np.ones((4 * 4, 10, 5, 4)) + arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + arr2 = jax.device_put( + jnp.ones((10, 5, 3)), NamedSharding(mesh, P())) + + out = _f(arr1, arr2) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, None, None, None))) + + out = _f(arr1, jnp.ones((10, 5, 3))) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, None, None, None))) + @jtu.with_user_mesh((2, 2), ('x', 'y'), axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')}) def test_full_user_mode(self, mesh): From cebedb9f1a3e0a26b2cc69f81c4202b721e7b067 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 4 Mar 2025 18:49:12 -0800 Subject: [PATCH 019/100] Update version number after 0.5.2 release --- jax/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/version.py b/jax/version.py index 616950577..be20aca06 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import os import pathlib import subprocess -_version = "0.5.2" +_version = "0.5.3" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None From 51719a1afe559d0e2b44120dcf21076831b0704f Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Tue, 4 Mar 2025 19:40:14 -0800 Subject: [PATCH 020/100] [mgpu] Non-vector untiled stores for tiling layouts. Useful for storing in memrefs where the minormost stride is >1. PiperOrigin-RevId: 733551038 --- .../mosaic/gpu/fragmented_array.py | 34 +++++++++++++++---- tests/mosaic/gpu_test.py | 14 ++++++-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index a6e325bfb..bba93b5de 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1441,19 +1441,27 @@ class FragmentedArray: if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) - def store_untiled(self, ref: ir.Value): + def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) + def vs_unsupported(): + if not vector_store: + raise NotImplementedError( + f"Can't use non-vector stores with layout {self.layout}" + ) + match self.layout: case WGMMAFragLayout(): self._store_untiled_wgmma(ref) case WGSplatFragLayout(): + vs_unsupported() self._store_untiled_splat(ref) case WGStridedFragLayout(): + vs_unsupported() self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref) + self._store_untiled_tiled(ref, vector_store=vector_store) case _: raise NotImplementedError(self.layout) @@ -1520,7 +1528,7 @@ class FragmentedArray: col = arith.addi(col_base, c(col_tile * 8 + col_idx)) memref.store(value, ref, [row, col]) - def _store_untiled_tiled(self, ref: ir.Value): + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") @@ -1528,7 +1536,7 @@ class FragmentedArray: layout = self.layout assert isinstance(layout, TiledLayout) ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if ref_strides[layout.vector_dim] != 1: + if vector_store and ref_strides[layout.vector_dim] != 1: raise NotImplementedError( "Can't use vector stores with non-unit minormost stride" ) @@ -1549,9 +1557,21 @@ class FragmentedArray: ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) # All warp tile offsets are static and can be fused into the store. for tile_idx, reg in np.ndenumerate(self.registers): - lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(reg, reg_ptr) + if vector_store: + elems = [reg] + else: + index = ir.IndexType.get() + elems = [ + vector.extractelement(reg, position=c(i, index)) + for i in range(ir.VectorType(reg.type).shape[0]) + ] + for i, e in enumerate(elems): + tile_idx_local = list(tile_idx) + tile_idx_local[layout.vector_dim] += i + tile_idx_local = list(tile_idx_local) + lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) + reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) + llvm.store(e, reg_ptr) def store_tiled(self, ref, swizzle: int | None): match self.layout: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f3d94d917..02cd26c15 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -485,12 +485,20 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16)) - def test_store_untiled(self, dtype): + @parameterized.product(dtype=[jnp.float16, jnp.float32], + tiled_layout=[False, True], + transposed_smem=[False, True]) + def test_store_untiled(self, dtype, tiled_layout, transposed_smem): def kernel(ctx, out, _): del ctx - iota_tensor(64, 64, dtype).store_untiled(out) + if transposed_smem: + out = memref_transpose(out, (1, 0)) + iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled( + out, vector_store=not transposed_smem + ) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) + if transposed_smem: + expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() From 06b760eea2e2d497328e6e1f505280758f9e979c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 01:43:06 -0800 Subject: [PATCH 021/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e0e56a1190d5be336f7d3e308457349fbfacebd2. PiperOrigin-RevId: 733636860 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e98a250db..569518a67 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2274501a951e52a4fa32d65136f467d35c8950b9" -XLA_SHA256 = "809ebf3ee4e6271d16d73ec2f37a7f61f2b8248767935ade327f60352c459d0b" +XLA_COMMIT = "e0e56a1190d5be336f7d3e308457349fbfacebd2" +XLA_SHA256 = "cf64ab04fa47dd86bb222e87dc573d083e869d211c4ee806786fc4da17c1dafc" def repo(): tf_http_archive( From 342cb7b99a09180472823a33c7cdad8a8db77875 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 5 Mar 2025 02:05:49 -0800 Subject: [PATCH 022/100] Attempt 2 at landing custom_vjp.optimize_remat using custom_dce. The original change was rolled back because there were real world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwds signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here. Reverts 1f3176636d304398b00a7d2cb0933859618affd8 PiperOrigin-RevId: 733643149 --- CHANGELOG.md | 3 + jax/_src/custom_derivatives.py | 253 ++++-------------------------- jax/custom_derivatives.py | 1 - jax/experimental/jax2tf/jax2tf.py | 12 +- tests/api_test.py | 78 ++++++--- 5 files changed, 97 insertions(+), 250 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86d0cab0f..65f1dafa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size. +* Breaking changes + * The ``jax.custom_derivatives.remat_opt_p`` helper primitive was removed. + ## jax 0.5.1 (Feb 24, 2025) * New Features diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 32856106a..579086f36 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -16,7 +16,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses -from functools import update_wrapper, reduce, partial, wraps +from functools import update_wrapper, reduce, partial from typing import Any, Generic, TypeVar from jax._src import config @@ -32,6 +32,7 @@ from jax._src.ad_util import ( from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, prepend_static_args, debug_info) +from jax._src.custom_dce import custom_dce from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -657,10 +658,12 @@ class custom_vjp(Generic[ReturnValue]): # TODO(necula): figure out how to construct the debug_bwd args debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {}) if self.optimize_remat: + if self.symbolic_zeros: + # TODO(dfm): This probably shouldn't be too hard to support. + raise NotImplementedError( + "remat optimization for custom_vjp does not support symbolic zeros") fwd = optimize_remat_of_custom_vjp_fwd( - self.fun, debug_fun, self.fwd, debug_fwd, - nondiff_argnums=self.nondiff_argnums, - symbolic_zeros=self.symbolic_zeros) + self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums) else: fwd = self.fwd if config.enable_custom_vjp_by_custom_transpose.value: @@ -1571,229 +1574,31 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") # simpler, but it would be worth revisiting this. def optimize_remat_of_custom_vjp_fwd( fun: Callable[..., ReturnValue], - debug_fun: core.DebugInfo, fwd: Callable[..., tuple[ReturnValue, Any]], - debug_fwd: core.DebugInfo, nondiff_argnums: Sequence[int] = (), - symbolic_zeros: bool = False, ) -> Callable[..., tuple[ReturnValue, Any]]: - if symbolic_zeros: - # TODO(dfm): This probably shouldn't be too hard to support. - raise NotImplementedError( - "remat optimization for custom_vjp does not support symbolic zeros") + wrapped_fwd = custom_dce( + # It might seem like we don't need this lambda, but there are some real + # world use cases where the signature of `fwd` is wrong, and we shouldn't + # error out when resolving the arguments in those cases. This is fine, + # because the arguments have already been resolved in custom_vjp. + lambda *args: fwd(*args), # pylint: disable=unnecessary-lambda + static_argnums=nondiff_argnums, + ) - @wraps(fwd) - def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: - # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__ - # above and it would be good to consolidate it. - fwd_name = debug_fwd.func_name if debug_fwd else str(fwd) - # Note: we use `fun` instead of `fwd` here for consistency with - # custom_vjp.__call__ above. - args = resolve_kwargs(fun, args, kwargs) - if nondiff_argnums: - for i in nondiff_argnums: _check_for_tracers(args[i]) - nondiff_argnums_ = set(nondiff_argnums) - dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_] - f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun), - dyn_argnums, - args, require_static_args_hashable=False) - fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd), - dyn_argnums, args, - require_static_args_hashable=False) - else: - f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args - fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd) - args_flat, in_tree = tree_flatten(dyn_args) - flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) - flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False, - debug_fun, debug_fwd, in_tree, out_type) - flat_fwd = _fix_fwd_args(flat_fwd) - - in_avals = [core.get_aval(x) for x in args_flat] - fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) - fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) - prim_tree, res_tree = out_trees() - num_res = res_tree.num_leaves - - if fwd_jaxpr.effects: - raise NotImplementedError( - "remat optimization for custom_vjp does not support forward " - f"functions with side effects, but {fwd_name} has the following " - f"effects: {fwd_jaxpr.effects}") - - @pe._memoize - def fun_jaxpr_thunk(): - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) - return jaxpr, consts - - out_flat = remat_opt_p.bind(*consts, *args_flat, - num_consts=len(consts), - num_res=num_res, - fwd_jaxpr=fwd_jaxpr, - fun_jaxpr_thunk=fun_jaxpr_thunk) - res, out_flat = split_list(out_flat, [num_res]) - out_tree = treedef_tuple((prim_tree, res_tree)) - return tree_unflatten(out_tree, (*out_flat, *res)) + @wrapped_fwd.def_dce + def _(*args): + static_args, used_outs, args = split_list(args, [len(nondiff_argnums), 1]) + static_args_iter = iter(static_args) + args_iter = iter(args) + nondiff_argnums_ = set(nondiff_argnums) + fun_args = [ + next(static_args_iter) if i in nondiff_argnums_ else next(args_iter) + for i in range(len(static_args) + len(args))] + used_outs, = used_outs + _, used_res = used_outs + if any(tree_leaves(used_res)): + return fwd(*fun_args) + return fun(*fun_args), None return wrapped_fwd - -@lu.transformation2 -def _fix_fwd_args(f, *args): - args = [(x, True) for x in args] - args = [x for pair in args for x in pair] - return f(*args) - -def _remat_opt_impl( - *args, - num_consts: int, - num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], -): - del num_consts, num_res, fun_jaxpr_thunk # unused - return core.jaxpr_as_fun(fwd_jaxpr)(*args) - -def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): - del args - return fwd_jaxpr.out_avals, fwd_jaxpr.effects - -def _remat_opt_vmap( - axis_data, args, in_dims, - *, - num_consts: int, - num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], -): - args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 - else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] - batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_data, in_batched, False) - extra_consts = batched_fwd_jaxpr.consts - batched_fwd_jaxpr = pe.close_jaxpr( - pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) - out_dims = [0 if b else not_mapped for b in out_batched] - - _, prim_batched = split_list(in_batched, [num_consts]) - - @pe._memoize - def batched_fun_jaxpr_thunk(): - fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) - batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_data, prim_batched, False) - return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts - - batched_outs = remat_opt_p.bind(*extra_consts, *args, - num_consts=num_consts + len(extra_consts), - num_res=num_res, - fwd_jaxpr=batched_fwd_jaxpr, - fun_jaxpr_thunk=batched_fun_jaxpr_thunk) - - return batched_outs, out_dims - -def _remat_opt_jvp( - primals, - tangents, - *, - num_consts: int, - num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], -): - consts, primals = split_list(primals, [num_consts]) - consts_dot, tangents = split_list(tangents, [num_consts]) - # Tangents must be instantated in case we end up DCEing later. - tangents = map(ad.instantiate_zeros, tangents) - consts_nz = [not isinstance(t, Zero) for t in consts_dot] - consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz] - in_nz = consts_nz + [True] * len(tangents) - fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True) - num_out = len(out_nz) - num_res - fwd_jaxpr_jvp_ = ad.rearrange_binders( - fwd_jaxpr_jvp_, [num_consts, len(primals)], - [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) - fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) - - # @pe._memoize - def fun_jvp_jaxpr_thunk(): - fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) - in_nz = [True] * len(primals) - fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True) - return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts - - new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot) - outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot, - *primals, *tangents, num_consts=new_num_consts, - num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp, - fun_jaxpr_thunk=fun_jvp_jaxpr_thunk) - res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out]) - return (*res, *outs), (*res_dot, *outs_dot) - -def _remat_opt_transpose( - cts, *args, - num_consts: int, - num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], -): - # TODO(dfm): It shouldn't be too hard to implement this as needed in the - # future. - raise NotImplementedError( - "remat optimization for custom_vjp does not support higher-order AD") - -def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): - if not any(used_outs) and not pe.has_effects(eqn): - return [False] * len(eqn.invars), None - used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) - outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] - if any(used_res): - # If any of the residuals are used, we still need to run fwd at this point, - # but we may end up DCEing again in the future, so we must instantiate all - # the input primals. - instantiate = [False] * eqn.params["num_consts"] - instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"]) - new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs, - instantiate=instantiate) - assert not new_jaxpr.constvars - closed_jaxpr = pe.close_jaxpr(new_jaxpr) - invars = [v for used, v in zip(used_ins, eqn.invars) if used] - new_params = dict(eqn.params) - new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0]) - new_params["num_consts"] = new_num_consts - new_params["fwd_jaxpr"] = closed_jaxpr - new_params["num_res"] = sum(used_res) - new_eqn = pe.new_jaxpr_eqn( - invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects, - eqn.source_info, eqn.ctx) - return used_ins, new_eqn - else: - # If none of the residuals are used, we run the primal computation instead. - # At this point we drop this custom DCE behavior, but since the primal might - # have different consts than fwd, we build a new JaxprEqn with a closed_call - # primitive. - fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]() - new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims) - consts = [c for used, c in zip(used_consts, consts) if used] - closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) - _, invars = split_list(eqn.invars, [eqn.params["num_consts"]]) - invars = [v for used, v in zip(used_ins, invars) if used] - new_eqn = pe.new_jaxpr_eqn( - invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr), - closed_jaxpr.effects, eqn.source_info, eqn.ctx) - used_ins = [False] * eqn.params["num_consts"] + used_ins - return used_ins, new_eqn - -remat_opt_p = core.Primitive("remat_opt") -remat_opt_p.multiple_results = True -remat_opt_p.def_impl(_remat_opt_impl) -remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval) -xla.register_initial_style_primitive(remat_opt_p) -mlir.register_lowering(remat_opt_p, mlir.lower_fun( - _remat_opt_impl, multiple_results=True)) - - -batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap -ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp -ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose -pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 3628ae4aa..0b0c8621e 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -30,7 +30,6 @@ from jax._src.custom_derivatives import ( custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, - remat_opt_p as remat_opt_p, ) from jax._src.ad_util import ( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index d58a1bb0d..8f39f53ea 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -45,6 +45,7 @@ from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core +from jax._src import custom_dce from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu @@ -3473,15 +3474,14 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: tf_impl[ad.custom_lin_p] = _custom_lin -def _remat_opt(*args: TfVal, num_consts: int, num_res: int, - fwd_jaxpr: core.ClosedJaxpr, - fun_jaxpr_thunk: Callable) -> Sequence[TfVal]: - del num_consts, num_res, fun_jaxpr_thunk - return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt", +def _custom_dce(*args: TfVal, num_consts: int, fun_jaxpr: core.ClosedJaxpr, + dce_jaxpr_thunk: Callable) -> Sequence[TfVal]: + del num_consts, dce_jaxpr_thunk + return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_dce_call", fresh_constant_cache=False) -tf_impl[custom_derivatives.remat_opt_p] = _remat_opt +tf_impl[custom_dce.custom_dce_p] = _custom_dce PartitionsOrReplicated = Union[tuple[int, ...], None] diff --git a/tests/api_test.py b/tests/api_test.py index 543335529..a5e441259 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9599,10 +9599,7 @@ class CustomVJPTest(jtu.JaxTestCase): return np.array([2.0])*x*x/np.array([1.0]), (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed @@ -9612,9 +9609,7 @@ class CustomVJPTest(jtu.JaxTestCase): def fwd(x): return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) @@ -9625,9 +9620,7 @@ class CustomVJPTest(jtu.JaxTestCase): return x*x, (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) def g(x): return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) @@ -9641,9 +9634,7 @@ class CustomVJPTest(jtu.JaxTestCase): def fwd_(x): return x*x, (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), - fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_) calc = jax.jvp(fwd, (3.2,), (1.0,)) expected = jax.jvp(fwd_, (3.2,), (1.0,)) self.assertAllClose(calc, expected) @@ -9740,6 +9731,55 @@ class CustomVJPTest(jtu.JaxTestCase): x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + def test_optimize_remat_nondiff_argnums(self): + @partial(jax.custom_vjp, nondiff_argnums=(2,)) + def f(x, y, fun): + return fun(x, y) + + def f_fwd(x, y, fun): + del fun + return jnp.cos(x) * y, (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(fun, res, g): + del fun + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + def fun(x, y): + return jnp.sin(x) * y + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 0.5, 0.1 + res = jax.value_and_grad(lambda *args: f(*args, fun))(x, y)[0] + self.assertAllClose(res, f_fwd(x, y, fun)[0]) + res = jax.jit(lambda *args: jax.value_and_grad( + lambda *args: f(*args, fun))(*args)[0])(x, y) + self.assertAllClose(res, fun(x, y)) + + def test_optimize_remat_incorrect_signature(self): + def f_(x, y): + return jnp.sin(x) * y + + @jax.custom_vjp + def f(x, y): + return f_(x, y) + + def wrong_signature(x, y, z): + self.fail("wrong_signature should not be called") + + @functools.wraps(wrong_signature) + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + + def test_dce(self): @jax.custom_vjp def f(x, y): @@ -10468,20 +10508,20 @@ class CustomDceTest(jtu.JaxTestCase): self.assertAllClose(v, jnp.tan(3.2)**2) def test_static_argnums(self): - @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) - def g(f, x): + @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(1,)) + def g(x, f): return f(x), 10 * f(x) @g.def_dce def g_dce(f, used_outs, x): # note: static_argnums are always passes first self.assertTrue(callable(f)) - return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] + return [2 * v if used else None for used, v in zip(used_outs, g(x, f))] x = 1.1234 f = lambda x: jnp.exp(x) - expected = g(f, x) - self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) + expected = g(x, f) + self.assertAllClose(jax.jit(lambda x: g(x, f)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: g(x, f)[1])(x), 2 * expected[1]) def test_shape_mismatch_error(self): @jax.experimental.custom_dce.custom_dce From 40e1a2a56132eb4ad95d2fccff6809ce1471b5a7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 5 Mar 2025 08:39:58 -0500 Subject: [PATCH 023/100] Remove a TSAN suppression. https://github.com/python/cpython/issues/130547 has been marked as fixed and backported to 3.13, so this suppression should no longer be necessary. --- .github/workflows/tsan-suppressions.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 71542ea5d..7b713b2da 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -26,8 +26,6 @@ race_top:PyMember_GetOne # https://github.com/python/cpython/issues/129547 race:type_get_annotations -# https://github.com/python/cpython/issues/130547 -race:split_keys_entry_added # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx @@ -64,3 +62,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130571 # race:_PyObject_GetMethod + +# https://github.com/python/cpython/issues/130547 +# race:split_keys_entry_added From 6230ef1d5198f023ff99198a480f2ee78c252692 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 5 Mar 2025 15:18:43 +0000 Subject: [PATCH 024/100] Removed unused import --- tests/fused_attention_stablehlo_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 83e2fd350..af0b18b02 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -14,7 +14,6 @@ from functools import partial from absl.testing import absltest -import os import numpy as np import jax From d119138766cc58444f814daaf86f1398bba62ec1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 5 Mar 2025 08:05:25 -0800 Subject: [PATCH 025/100] [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation This makes the code path uniform for LHS/RHS and greatly clarifies the magical computation of LBO/SBO. This change should make it significantly easier for us to enable small tile support for the LHS. PiperOrigin-RevId: 733737302 --- jax/experimental/mosaic/gpu/mma_utils.py | 221 +++++++++++ jax/experimental/mosaic/gpu/tcgen05.py | 151 +++++--- jax/experimental/mosaic/gpu/wgmma.py | 460 ++++++----------------- tests/mosaic/gpu_test.py | 63 ++-- 4 files changed, 467 insertions(+), 428 deletions(-) create mode 100644 jax/experimental/mosaic/gpu/mma_utils.py diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py new file mode 100644 index 000000000..9e8fee49b --- /dev/null +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -0,0 +1,221 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +import math + +from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import llvm + +from . import utils + +# mypy: ignore-errors + +def tiled_memref_shape(ref: ir.Value): + """Returns the 2D untiled shape and element type of a tiled 4D memref.""" + ref_ty = ir.MemRefType(ref.type) + if ref_ty.rank != 4: + raise ValueError(f"Expected a 4D memref, got: {ref_ty}") + logical_shape = ( + ref_ty.shape[0] * ref_ty.shape[2], ref_ty.shape[1] * ref_ty.shape[3] + ) + return logical_shape, ref_ty.element_type + + +class Dim(enum.Enum): + K = enum.auto() + MN = enum.auto() + + +def create_descriptor( + ref: ir.Value, + swizzle: int, + large_tile: tuple[int, int], # Soft deprecated. Use small tiling instead. + group_size: tuple[int, int], # Instruction group size on each operand dim. + logical_k_major: bool, # False for LHS, True for RHS. + supports_small_tile: bool = False, # TODO(apaszke): This is a temporary. +): + ref_ty = ir.MemRefType(ref.type) + element_bytewidth = utils.bytewidth(ref_ty.element_type) + swizzle_elems = swizzle // element_bytewidth + ref_strides, _ = ref_ty.get_strides_and_offset() + ref_byte_strides = [s * element_bytewidth for s in ref_strides] + if logical_k_major: + _, mn_tiles, k_tiling, mn_tiling = ref_ty.shape + k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = ( + ref_byte_strides + ) + k_large_tile, mn_large_tile = large_tile + k_group_size, mn_group_size = group_size + else: + mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape + mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = ( + ref_byte_strides + ) + mn_large_tile, k_large_tile = large_tile + mn_group_size, k_group_size = group_size + + IGNORED = 0 + MMA_ATOM_ROWS = 8 + MMA_BYTEWIDTH_K = 32 + mma_width_k = MMA_BYTEWIDTH_K // element_bytewidth + # As far as I can tell (which does not seem to fully align with the way MMA is + # documented in PTX docs), MMA expects the data to be tiled into matrices + # of shape 8 x swizzle_elems, with swizzle_elems dim being the fastest + # changing. I call this submatrix an MMA atom. + # + # The role of the SMEM descriptor is to specify the striding pattern between + # those atoms. The fastest changing dimension is called the "leading" + # dimension and it specifies the stride between consecutive atoms that share + # the same coordinate along that dim. The slower dimension is called a + # "stride" dimension. + if ( + k_large_tile == k_tiling + and (mn_large_tile == mn_tiling or mn_tiles == 1 and mn_tiling < mn_large_tile) + # There are configurations where large tiles are same size as small ones. + # We use the small path since it has fewer restrictions. + and set(large_tile) != {MMA_ATOM_ROWS, swizzle_elems} + ): # Large tiles. + if ( + k_tiling_stride == element_bytewidth + and mn_tiling_stride == k_tiling * element_bytewidth + ): + fastest_dim = Dim.K + leading_byte_offset = IGNORED # TC assumes K to be contiguous here. + # MMA atoms in a group are contiguous, so we increment by the MMA atom + # size. However, we only have one level of striding, and so if the group + # size exceeds a single large tile (and there is more than one tile) then + # that tiled dimension must be contiguous after tiles or else we would + # need another striding level. + if ( + mn_tiles > 1 + and mn_group_size > mn_tiling + and mn_tile_stride != math.prod(large_tile) * element_bytewidth + ): + raise ValueError( + "MMA layout with large tiles that is K-fastest only supports" + " multiple MN tiles when the tiled MN dimension is a contiguous" + " stack of tiles " + f"({mn_tiles}, {mn_tile_stride} != {math.prod(large_tile)} * {element_bytewidth})" + ) + stride_byte_offset = MMA_ATOM_ROWS * swizzle + desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. + elif ( + k_tiling_stride == k_tiling * element_bytewidth + and mn_tiling_stride == element_bytewidth + ): + if k_large_tile != mn_large_tile: + raise ValueError( + "MMA layout with large tiles that is MN-fastest is only supported" + " when the tiling is square" + ) + fastest_dim = Dim.MN + # Next swizzle atom with the same K coordinate is in the next MN tile. + leading_byte_offset = mn_tile_stride + # MMA atoms in a group are contiguous and a group does not exceed a tile. + assert k_large_tile == k_group_size + stride_byte_offset = MMA_ATOM_ROWS * swizzle + # Each row is swizzle bytes wide, and we read mma_width_k rows at a time. + assert mn_large_tile == swizzle // element_bytewidth + desc_k_stride = mma_width_k * swizzle + else: + raise ValueError("MMA tiles must be contiguous") + else: # Small tiles. + if not supports_small_tile: + raise NotImplementedError("Small tiles are not supported yet") + if k_tiling_stride > mn_tiling_stride: + slower_tiling, faster_tiling = k_tiling, mn_tiling + else: + faster_tiling, slower_tiling = k_tiling, mn_tiling + if slower_tiling != MMA_ATOM_ROWS or faster_tiling != swizzle_elems: + raise ValueError( + f"Tiling should be ({MMA_ATOM_ROWS}, swizzle_elems) where" + f" swizzle_elems = swizzle // bytewidth(dtype) (= {swizzle} //" + f" {element_bytewidth} = {swizzle_elems}), but got ({slower_tiling}," + f" {faster_tiling})" + ) + if k_tiling_stride == element_bytewidth and mn_tiling_stride == swizzle: + fastest_dim = Dim.K + leading_byte_offset = IGNORED # TC assumes K to be contiguous here. + stride_byte_offset = mn_tile_stride + desc_k_stride = MMA_BYTEWIDTH_K # K is contiguous. + elif k_tiling_stride == swizzle and mn_tiling_stride == element_bytewidth: + fastest_dim = Dim.MN + leading_byte_offset = mn_tile_stride + stride_byte_offset = k_tile_stride + k_tiles_per_mma = mma_width_k // MMA_ATOM_ROWS + desc_k_stride = k_tile_stride * k_tiles_per_mma + else: + raise ValueError("MMA tiles must be contiguous") + desc_base = encode_descriptor( + ref, + leading_byte_offset=leading_byte_offset, + stride_byte_offset=stride_byte_offset, + swizzle=swizzle, + ) + + mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling) + assert not rem + mn_group_stride = mn_tile_stride * mn_tiles_per_group + k_tiles_per_group, rem = divmod(k_group_size, k_tiling) + assert not rem + k_group_stride = k_tile_stride * k_tiles_per_group + + return ( + (desc_base, desc_k_stride), + (mn_group_stride, k_group_stride), + fastest_dim, + ) + + +def encode_addr(x: int): + result = (x & 0x3FFFF) >> 4 + if result << 4 != x: + raise ValueError(f"Cannot encode value in an MMA descriptor: {x}") + return result + + +def encode_descriptor( + memref_arg, + leading_byte_offset: int, + stride_byte_offset: int, + swizzle: int | mgpu_dialect.SwizzlingMode | None, + const_init: int = 0, +): + i64 = ir.IntegerType.get_signless(64) + ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, 3)) + c = lambda x: arith.constant(i64, x) + if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: + swizzle_encoding = 0 + elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: + swizzle_encoding = 1 + elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle: + swizzle_encoding = 2 + elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle: + swizzle_encoding = 3 + else: + raise NotImplementedError(swizzle) + encoded_base_addr = llvm.lshr(llvm.and_(ptr_val, c(0x3FFFF)), c(4)) + # We ignore the offset + desc_const = ( + const_init + | (encode_addr(leading_byte_offset) << 16) + | (encode_addr(stride_byte_offset) << 32) + ) + desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const)) + desc = llvm.or_(encoded_base_addr, desc) + return desc diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 1933610b6..7f4047007 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -18,7 +18,6 @@ from __future__ import annotations import dataclasses import math -from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm @@ -27,7 +26,7 @@ import numpy as np from . import utils from . import fragmented_array as fa -from . import _wgmma +from . import mma_utils from .launch_context import LaunchContext # MyPy does a terrible job with the MLIR API. @@ -37,21 +36,6 @@ from .launch_context import LaunchContext TMEM_ROWS = 128 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46 -def create_smem_descriptor( - memref_arg, - leading_byte_offset: int, - stride_byte_offset: int, - swizzle: int | mgpu_dialect.SwizzlingMode | None, -): - return _wgmma.create_descriptor( - memref_arg, - leading_byte_offset, - stride_byte_offset, - swizzle, - memory_space=3, - const_init=TCGEN05_SMEM_DESCRIPTOR_BIT, - ) - def create_instr_descriptor( m: int, n: int, @@ -100,70 +84,129 @@ def mma( collective: bool = False, ): i64 = ir.IntegerType.get_signless(64) + if isinstance(accumulate, bool): + accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) + if a_swizzle != b_swizzle: + raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}") + swizzle = a_swizzle + num_cta = 2 if collective else 1 + + # Step 1. Establish the shape and element type of the operation. if not ir.MemRefType.isinstance(a.type): raise ValueError(f"A must be a memref, got {a.type}") if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") - if a_swizzle != b_swizzle: - raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}") - if isinstance(accumulate, bool): - accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) + (k, n), element_type = mma_utils.tiled_memref_shape(b) + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if k != k2: + raise ValueError( + "MMA requires A and B to have the same contraction dimension (K)," + f" got: {k2} and {k}" + ) + if element_type != element_type2: + raise ValueError( + "MMA requires A and B to have the same element type, got:" + f" {element_type2} and {element_type}" + ) + if d.shape != (m, n * num_cta): + raise ValueError( + f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" + ) + f32 = ir.F32Type.get() + if element_type == f32 or element_type == ir.BF16Type.get(): + if d.dtype != f32: + raise ValueError( + f"MMA with element type {element_type} only supports accumulators" + f" of type f32, but got: {d.dtype}" + ) + elif element_type == ir.F16Type.get(): + if d.dtype != element_type and d.dtype != f32: + raise ValueError( + "MMA with element type f16 only supports accumulators of type f32" + f" or f16, but got: {d.dtype}" + ) - m_group_size = d.layout.elements_in_tile[0] - if m_group_size != 128: + # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, + # instructions must be issued in groups of the same width as the swizzle. + m_group_elems = d.layout.elements_in_tile[0] + if m_group_elems != 128: raise NotImplementedError("Only 128-row accumulators supported for now") - - ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride), - mma_params, - ) = _wgmma._validate_mma( - a, - b, - a_swizzle, - m_group_size=m_group_size, - descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT, - ) - n_group_size = n // n_groups - if n > 512: - raise ValueError(f"N is too big: at most 512 is supported, but got {n}") - num_cta = 2 if collective else 1 + k_group_elems = swizzle // utils.bytewidth(element_type) + if n % 8: + raise ValueError(f"N must be a multiple of 8, got: {n}") + elif n > 256 and n != 512: + raise ValueError("Only N below 256 or N=512 are supported") if num_cta == 2 and n > 256: raise NotImplementedError( "N is too big for collective MMA. Only up to 256 is supported." ) + n_group_elems = min(n, 256) + if m % m_group_elems: + raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") + if k % k_group_elems: + raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}") + if n % n_group_elems: + raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}") + m_groups = m // m_group_elems + k_groups = k // k_group_elems + n_groups = n // n_group_elems + # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. + wgmma_element_type = ( + ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type + ) - # TODO(apaszke): Verify that the cluster shape matches the expectation of - # collective MMA. - expected_acc_shape = (m, n * num_cta) - if d.shape != expected_acc_shape: - raise ValueError( - f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}" - ) + # Step 3. Compute the operand descriptors. + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + large_tile=(m_group_elems, k_group_elems), + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + ( + (b_desc_base, b_k_instr_stride), + (b_n_group_stride, b_k_group_stride), + b_fastest, + ) = mma_utils.create_descriptor( + b, + swizzle=swizzle, + large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. + group_size=(k_group_elems, n_group_elems), + logical_k_major=True, + supports_small_tile=True, + ) + # Step 4. Issue the instructions. true = arith.constant(ir.IntegerType.get_signless(1), 1) for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64)) + a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride - b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64)) + b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) if m_groups != 1: raise NotImplementedError("D needs to be sliced") acc = accumulate if ki == 0 else true _do_mma( d.slice( - slice(None), utils.ds(ni * n_group_size, n_group_size) + slice(None), utils.ds(ni * n_group_elems, n_group_elems) ).address, a_mk, b_nk, d_type=ir.F32Type.get(), - m=m_group_size, + m=m_group_elems, + n=n_group_elems, collective=collective, - **mma_params, + a_transpose=a_fastest != mma_utils.Dim.K, + b_transpose=b_fastest != mma_utils.Dim.K, + a_k_stride=a_k_instr_stride, + b_k_stride=b_k_instr_stride, accumulate=acc, + swizzle=swizzle, + element_type=wgmma_element_type, ) diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 33e1ce92e..f6b876087 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -14,13 +14,10 @@ # ============================================================================== import dataclasses -import enum import functools import itertools -from typing import Any import jax -from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm @@ -29,6 +26,7 @@ from jaxlib.mlir.dialects import vector import numpy as np from . import fragmented_array as fa +from . import mma_utils from . import utils # mypy: ignore-errors @@ -84,60 +82,6 @@ class WGMMAAccumulator: return cls(_value=value[0], _sync=False) -def wgmma_encode(x: int): - result = (x & 0x3FFFF) >> 4 - if result << 4 != x: - raise ValueError(f"Cannot encode value in a WGMMA descriptor: {x}") - return result - - -def llvm_add(x, y): - return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none) - - -def create_descriptor( - memref_arg, - leading_byte_offset: int, - stride_byte_offset: int, - swizzle: int | mgpu_dialect.SwizzlingMode | None, - memory_space: int | None = None, - const_init: int = 0, -): - i64 = ir.IntegerType.get_signless(64) - ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space)) - if swizzle is None or swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: - swizzle_encoding = 0 - elif swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: - swizzle_encoding = 1 - elif swizzle == mgpu_dialect.SwizzlingMode.k64ByteSwizzle: - swizzle_encoding = 2 - elif swizzle == mgpu_dialect.SwizzlingMode.k32ByteSwizzle: - swizzle_encoding = 3 - else: - raise NotImplementedError(swizzle) - encoded_base_addr = llvm.LShrOp( - llvm.AndOp(ptr_val, c(0x3FFFF, i64)).result, c(4, i64) - ) - # We ignore the offset - desc_const = ( - const_init - | (wgmma_encode(leading_byte_offset) << 16) - | (wgmma_encode(stride_byte_offset) << 32) - ) - desc = llvm.or_( - arith.shli(c(swizzle_encoding, i64), c(62, i64)), c(desc_const, i64) - ) - desc = llvm.or_(encoded_base_addr.result, desc) - return desc - - -def _unpack_i32(vec_ty, r): - i32 = ir.IntegerType.get_signless(32) - return vector.bitcast( - vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) - ) - - def _supported_wgmma_types(dtype, abtype) -> bool: input_types_are = lambda ty: ty.isinstance(abtype) if ir.F32Type.isinstance(dtype): @@ -271,14 +215,14 @@ def wgmma_m64( a_args = [_as_i32_reg(v) for v in a_slice.registers.flat] else: if i > 0: - a = llvm_add( + a = _llvm_add( a, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)), ) a_args = [a] # Advance the B descriptor. if i > 0: - b_descriptor = llvm_add( + b_descriptor = _llvm_add( b_descriptor, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)), ) @@ -297,262 +241,6 @@ def wgmma_m64( return to_acc_vec_regs(acc_regs) -class WGMMALayout(enum.Enum): - ROW_MAJOR = enum.auto() - COL_MAJOR = enum.auto() - - -def _validate_mma( - a: Any, - b: ir.Value, - swizzle: int, - m_group_size: int, # The M used by a single instruction. - descriptor_const_init: int = 0, -): - # We need swizzle >= 32 to ensure that our K tiling is larger than the MMA - # instruction's K width. - if swizzle < 32: - raise ValueError(f"Unsupported swizzle: {swizzle}") - - # Get A type. - if a_in_smem := isinstance(a, ir.Value): - if not ir.MemRefType.isinstance(a.type): - raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}") - a_ty = ir.MemRefType(a.type) - a_element_type = a_ty.element_type - a_shape = tuple(a_ty.shape) - if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("A must be in workgroup memory when it's a reference") - if len(a_shape) != 4: - raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}") - elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"): - a_element_type = a.mlir_dtype - a_shape = a.shape - else: - raise NotImplementedError(f"Unsupported A type: {type(a)}") - - # Get B type (always a reference). - b_ty = ir.MemRefType(b.type) - if b_ty.rank != 4: - raise ValueError(f"B must be 4D, got rank {b_ty.rank}") - - # Veirfy element types and compute the tiling. - if (element_type := a_element_type) != b_ty.element_type: - raise ValueError( - f"A and B must have the same element type, got: {a_element_type} and" - f" {b_ty.element_type}" - ) - supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()} - if element_type not in supported_types: - raise ValueError(a_element_type) - element_bytewidth = bytewidth(element_type) - swizzle_elems = swizzle // element_bytewidth - - # Verify the shape and strides of B are as expected. - b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape - k = b_k_tiles * b_k_tiling - n = n_tiles * n_tiling - - b_strides, _ = b_ty.get_strides_and_offset() - b_byte_strides = [s * element_bytewidth for s in b_strides] - b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides - if ( - b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth - and n_tiles != 1 # When there's only one tile, we never jump between them - ): - raise ValueError("B tiles must be contiguous along the N dimension") - if b_tile_byte_strides == [swizzle, element_bytewidth]: # N-fastest - b_order = WGMMALayout.ROW_MAJOR - # This first case (n_tiles == 1) is to allow the somewhat weird case of - # loading a small amount of N-fastest data, that needs to be padded to a - # larger tile due to swizzle. In this case we allow slicing the big tile - # before WGMMA to avoid unnecessary compute on padding. - if n_tiles == 1: - if n_tiling % 8: - raise ValueError("N tile size must be a multiple of 8") - elif n_tiling != swizzle_elems: - raise ValueError( - "Row major RHS (N-fastest) requires the N tile size to be equal to" - f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}" - ) - if b_k_tiling not in {8, swizzle_elems}: - raise ValueError( - "Row major RHS (N-fastest) requires the K tile size to be either" - f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}" - ) - elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest - b_order = WGMMALayout.COL_MAJOR - if b_k_tiling != swizzle_elems: - raise ValueError( - "Column major RHS (K-fastest) requires the K tile size to be equal" - f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}" - ) - # See the explanation in the N-fastest case when n_tiles == 1. - if n_tiles == 1: - if n_tiling % 8: - raise ValueError("N tile size must be a multiple of 8") - elif n_tiling not in {8, swizzle_elems}: - raise ValueError( - "Column major RHS (K-fastest) requires the N tile size to be either" - f" to the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}" - ) - else: - raise ValueError(b_byte_strides) - - if n > 256 and n % 256: - raise ValueError( - f"N group size must be a multiple of 256 when larger than 256, got: {n}" - ) - k_group_size = swizzle_elems - n_group_size = min(n, 256) - b_k_tiles_per_group = k_group_size // b_k_tiling - b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group - n_tiles_per_group = n_group_size // n_tiling - b_n_group_stride = b_n_byte_stride * n_tiles_per_group - - # Verify the shape and strides of A are as expected. - if not a_in_smem: - m = a_shape[0] - a_order = a_m_group_stride = a_k_group_stride = None - else: - a_ty = ir.MemRefType(a.type) - m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape - m = m_tiles * m_tiling - # TODO(apaszke): I'm not actually convinced that we need this check. - if m_tiling != m_group_size: - raise ValueError( - f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}" - ) - if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k: - raise ValueError(a_ty.shape) - a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset() - a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [ - s * element_bytewidth for s in a_strides - ] - if a_tile_byte_strides == [swizzle, element_bytewidth]: - a_order = WGMMALayout.ROW_MAJOR - elif a_tile_byte_strides == [element_bytewidth, swizzle]: - a_order = WGMMALayout.COL_MAJOR - else: - raise ValueError(a_strides) - if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems: - # Not sure what the layout is like, since the tiles aren't square. - raise NotImplementedError - a_m_tiles_per_group = m_group_size // m_tiling - a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group - a_k_tiles_per_group = k_group_size // a_k_tiling - a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group - - b_k_fastest = b_order == WGMMALayout.COL_MAJOR - a_k_fastest = a_order == WGMMALayout.ROW_MAJOR - # This is the number of rows until consecutive repeats of the swizzle pattern. - swizzle_pattern_rows = swizzle // 16 - # A swizzle atom is a 2D matrix with the dimensions below. - swizzle_atom_bytes = swizzle_pattern_rows * 128 - - # Here "leading" refers to the fastest changing dimension. There are two - # strides we have to define per value: - # Leading byte offset (LBO) - # K-fastest: ignored - # MN-fastest: stride between consecutive swizzle atoms that share the same - # K coordinate. - # Stride byte offset (SBO) - # As far as I can tell this is just the offset between two consecutive - # swizzle atoms along the non-leading dimension. - IGNORED = 0 - a_desc_fields = dict( - # I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it - # is documented in the PTX docs, and my best explanation for the other - # case is that the instruction has a fixed shape and so it does not care - # about strides. It's possible that it's an artifact of the fact that we - # use tiling of 64. - leading_byte_offset=IGNORED, - stride_byte_offset=swizzle_atom_bytes, - swizzle=swizzle, - memory_space=3, - ) - # If B is N-fastest, all swizzle atoms within a tile share the same N - # coordinate, so we simply take the stride between consecutive N tiles. - # If B is K-fastest, all swizzle atoms within a tile share the same K - # coordinate, which forces us to lay out the tiles in N-fastest order or else - # they would have uneven strides. - b_desc_fields = dict( - leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride, - # N tiles are contiguous, so the next N swizzle atom follows immediately. - # K tiles are not contiguous, so we take the stride between them. - stride_byte_offset=swizzle_atom_bytes - if b_k_fastest or b_k_tiling == swizzle_elems - else b_k_byte_stride, - swizzle=swizzle, - memory_space=3, - ) - # The K strides indicate the stride between the consecutive places where all - # coordinates are 0 except for K being incremented by the instruction width. - # If an input is K-fastest, we increment the descriptor by 32 bytes, since - # that is the K-width of all MMA instructions. - if b_k_fastest: - b_k_wgmma_stride = 32 - elif b_k_tiling == swizzle_elems: - # When B is N-fastest and we use the large square tiling, the relevant - # slices all fall within the first tile. A single MMA instruction for 16-bit - # types reads a subtile of shape 16x(swizzle bytes), giving us the necessary - # expression. - assert n_tiling == swizzle_elems or n_tiles == 1 - b_k_wgmma_stride = swizzle * 16 - else: - # If we use the small non-square tiling and N-fastest layout, each tile only - # contains a single swizzle atom with the K coordinate. But, each tile has - # 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles. - b_k_wgmma_stride = b_k_byte_stride * 2 - wgmma_params = dict( - a_transpose=not a_k_fastest, - b_transpose=not b_k_fastest, - # TODO(apaszke): This explanation is quite bad. We should better figure - # out how to do LHS transposes. - # We only support swizzle=128 for M-fastest A. In this case the tile is - # swizzle x 64 (= swizzle elems) and so we just take a quarter of its size. - a_k_stride=32 if a_k_fastest else swizzle * 16, - b_k_stride=b_k_wgmma_stride, - swizzle=swizzle, - n=n_group_size, - element_type=ir.FloatTF32Type.get() - if ir.F32Type.isinstance(element_type) - else element_type, - ) - if not a_in_smem: - wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None - a_desc_base = None - else: - a_desc_base = create_descriptor( - a, **a_desc_fields, const_init=descriptor_const_init - ) - b_desc_base = create_descriptor( - b, **b_desc_fields, const_init=descriptor_const_init - ) - - if m % m_group_size: - raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}") - m_groups = m // m_group_size - if k % k_group_size: - raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}") - k_groups = k // k_group_size - if n % n_group_size: - raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}") - n_groups = n // n_group_size - - return ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - # Group strides are always in bytes! - (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride), - wgmma_params, - ) - - -# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer -# transpositions from memref strides. def wgmma( acc: WGMMAAccumulator, a: fa.FragmentedArray | ir.Value, @@ -570,61 +258,130 @@ def wgmma( The refs must be contiguous or be contiguous except for having their two minor dimensions swapped. """ - a_in_regs = isinstance(a, fa.FragmentedArray) - if not a_in_regs and not ir.MemRefType.isinstance(a.type): - raise ValueError(f"Unsupported A type: {type(a)}") + # Step 1. Establish the shape and element type of the operation. if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") - - m_group_size = 64 # Hopper has a fixed M instruction shape. - - ( - a_desc_base, - b_desc_base, - (m, k, n), - (m_groups, k_groups, n_groups), - (a_m_group_stride, a_k_group_stride, b_k_group_stride, _), - wgmma_params, - ) = _validate_mma(a, b, swizzle, m_group_size=m_group_size) - - if n_groups > 1: - raise ValueError("N is too big for WGMMA. Only up to 256 is supported.") - - if a_in_regs: + (k, n), element_type = mma_utils.tiled_memref_shape(b) + if a_in_regs := isinstance(a, fa.FragmentedArray): + m, k2 = a.shape + element_type2 = a.mlir_dtype if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get(): raise ValueError( f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}" ) - if a.shape[0] % m_group_size: - raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}") - a_m_group_stride = a_k_group_stride = None - + elif ir.MemRefType.isinstance(a.type): + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + else: + raise ValueError(f"Unsupported A type: {type(a)}") + if k != k2: + raise ValueError( + "WGMMA requires A and B to have the same contraction dimension (K)," + f" got: {k2} and {k}" + ) + if element_type != element_type2: + raise ValueError( + "WGMMA requires A and B to have the same element type, got:" + f" {element_type2} and {element_type}" + ) if acc.value.shape != (m, n): raise ValueError( f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}" ) + f32 = ir.F32Type.get() + if element_type == f32 or element_type == ir.BF16Type.get(): + if acc.value.mlir_dtype != f32: + raise ValueError( + f"WGMMA with element type {element_type} only supports accumulators" + f" of type f32, but got: {acc.value.mlir_dtype}" + ) + elif element_type == ir.F16Type.get(): + if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32: + raise ValueError( + "WGMMA with element type f16 only supports accumulators of type f32" + f" or f16, but got: {acc.value.mlir_dtype}" + ) + # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, + # instructions must be issued in groups of the same width as the swizzle. + m_group_elems = 64 # Hopper has a fixed M instruction shape. + k_group_elems = swizzle // utils.bytewidth(element_type) + if n > 256 or n % 8: + raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}") + n_group_elems = n # We assume only one N group below. + if m % m_group_elems: + raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") + if k % k_group_elems: + raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}") + m_groups = m // m_group_elems + k_groups = k // k_group_elems + # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. + wgmma_element_type = ( + ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type + ) + + # Step 3. Compute the operand descriptors. + if a_in_regs: + a_desc_base = a_m_group_stride = a_k_group_stride = None + a_instr_params = dict(a_transpose=None, a_k_stride=None) + else: + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + large_tile=(m_group_elems, k_group_elems), + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K, + a_k_stride=a_k_instr_stride) + ( + (b_desc_base, b_k_instr_stride), + (b_n_group_stride, b_k_group_stride), + b_fastest, + ) = mma_utils.create_descriptor( + b, + swizzle=swizzle, + large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. + group_size=(k_group_elems, n_group_elems), + logical_k_major=True, + supports_small_tile=True, + ) + del b_n_group_stride # We only support one N group. + + # Step 4. Issue the instructions. if a_in_regs: a = wgmma_fence(a) # Make sure the registers are ready. i64 = ir.IntegerType.get_signless(64) new_acc_regs = acc.value.registers.copy() - k_group_size = k // k_groups for mi in range(m_groups): for ki in range(k_groups): if a_in_regs: a_mk = a[ - mi * m_group_size : (mi + 1) * m_group_size, - ki * k_group_size : (ki + 1) * k_group_size, + mi * m_group_elems : (mi + 1) * m_group_elems, + ki * k_group_elems : (ki + 1) * k_group_elems, ] else: - a_mk = llvm_add( - a_desc_base, - c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64), + a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride + a_mk = _llvm_add( + a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64), ) - b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64)) + b_k = _llvm_add( + b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64) + ) new_acc_regs[mi : mi + 1] = wgmma_m64( - new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params + new_acc_regs[mi : mi + 1], + a_mk, + b_k, + swizzle=swizzle, + n=n_group_elems, + element_type=wgmma_element_type, + b_transpose=b_fastest != mma_utils.Dim.K, + b_k_stride=b_k_instr_stride, + **a_instr_params, ) return WGMMAAccumulator( _value=fa.FragmentedArray( @@ -668,3 +425,14 @@ def _as_i32_reg(v): def _lc(x): i32 = ir.IntegerType.get_signless(32) return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result + + +def _llvm_add(x, y): + return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none) + + +def _unpack_i32(vec_ty, r): + i32 = ir.IntegerType.get_signless(32) + return vector.bitcast( + vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) + ) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 02cd26c15..32d911116 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -660,7 +660,7 @@ class WGMMATest(TestCase): k_steps=(1, 2), swizzle=(32, 64, 128), jax_out_dtype=(jnp.float16, jnp.float32), - small_rhs_tile=(False, True,), + rhs_tiling_kind=("large", "small", "small+no_transpose"), ) def test_wgmma_basic( self, @@ -672,12 +672,14 @@ class WGMMATest(TestCase): rhs_transpose, swizzle, jax_out_dtype, - small_rhs_tile, + rhs_tiling_kind, ): if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: - raise self.skipTest("Only f16 input is supported for f16 output.") + self.skipTest("Only f16 input is supported for f16 output.") if swizzle != 128 and lhs_transpose: - raise self.skipTest("Transpose only supported in 128B swizzled WGMMA") + self.skipTest("Transpose only supported in 128B swizzled WGMMA") + if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: + self.skipTest("No transpose happening anyway") in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) @@ -704,6 +706,8 @@ class WGMMATest(TestCase): assert m % 64 == 0 and n % nk_tile == 0 small_nk_tile = 8 + small_rhs_tile = rhs_tiling_kind != "large" + transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" rhs_tiling = ( (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) ) @@ -715,7 +719,7 @@ class WGMMATest(TestCase): assert nk_tile == 64 # Make sure we didn't have to transpose tiling. lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: + if rhs_transpose and transpose_rhs_tiles: rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, @@ -737,7 +741,8 @@ class WGMMATest(TestCase): if lhs_transpose: lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) + rhs_smem = memref_transpose(rhs_smem, perm) acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) @@ -752,15 +757,14 @@ class WGMMATest(TestCase): y_shape = (n, k) if rhs_transpose else (k, n) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + if transpose_rhs_tiles: + rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) + else: + rhs_smem_shape = tile_shape(y_shape, rhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct( - (m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype), + jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(2), ] z = mgpu.as_gpu_kernel( @@ -906,7 +910,7 @@ class TCGen05Test(TestCase): n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), - small_rhs_tile=(False, True), + rhs_tiling_kind=("large", "small", "small+no_transpose"), ) def test_mma_basic( self, @@ -918,10 +922,10 @@ class TCGen05Test(TestCase): rhs_transpose, in_jax_dtype, out_jax_dtype, - small_rhs_tile, + rhs_tiling_kind, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: - raise self.skipTest("Only f16 input is supported for f16 output.") + self.skipTest("Only f16 input is supported for f16 output.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) m_tile = 128 @@ -930,6 +934,9 @@ class TCGen05Test(TestCase): assert m % m_tile == 0 and n % nk_tile == 0 small_nk_tile = 8 + + small_rhs_tile = rhs_tiling_kind != "large" + transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) def kernel(ctx, lhs, rhs, out, scratch): @@ -939,7 +946,7 @@ class TCGen05Test(TestCase): assert nk_tile == m_tile # Make sure we didn't have to transpose tiling lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: + if rhs_transpose and transpose_rhs_tiles: rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, @@ -961,7 +968,8 @@ class TCGen05Test(TestCase): if lhs_transpose: lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) + rhs_smem = memref_transpose(rhs_smem, perm) tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, ) @@ -980,15 +988,14 @@ class TCGen05Test(TestCase): y_shape = (n, k) if rhs_transpose else (k, n) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + if transpose_rhs_tiles: + rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling + rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) + else: + rhs_smem_shape = tile_shape(y_shape, rhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct( - tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype), + jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(3), mgpu.TMEM((128, n), out_jax_dtype), ] @@ -2917,4 +2924,4 @@ class SerializationTest(absltest.TestCase): if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) From 4493889cda3378c6fde6136e20b4ea45f1038407 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 5 Mar 2025 09:00:19 -0800 Subject: [PATCH 026/100] [Mosaic GPU] Add support for small tiles for (WG)MMA LHS Thanks to the previous refactor the change is quite trivial and mostly focuses on adding tests. PiperOrigin-RevId: 733754797 --- jax/experimental/mosaic/gpu/mma_utils.py | 3 -- jax/experimental/mosaic/gpu/tcgen05.py | 1 - jax/experimental/mosaic/gpu/wgmma.py | 1 - tests/mosaic/gpu_test.py | 53 ++++++++++++++++-------- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py index 9e8fee49b..4359ab8b4 100644 --- a/jax/experimental/mosaic/gpu/mma_utils.py +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -47,7 +47,6 @@ def create_descriptor( large_tile: tuple[int, int], # Soft deprecated. Use small tiling instead. group_size: tuple[int, int], # Instruction group size on each operand dim. logical_k_major: bool, # False for LHS, True for RHS. - supports_small_tile: bool = False, # TODO(apaszke): This is a temporary. ): ref_ty = ir.MemRefType(ref.type) element_bytewidth = utils.bytewidth(ref_ty.element_type) @@ -135,8 +134,6 @@ def create_descriptor( else: raise ValueError("MMA tiles must be contiguous") else: # Small tiles. - if not supports_small_tile: - raise NotImplementedError("Small tiles are not supported yet") if k_tiling_stride > mn_tiling_stride: slower_tiling, faster_tiling = k_tiling, mn_tiling else: diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 7f4047007..94969c630 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -177,7 +177,6 @@ def mma( large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. group_size=(k_group_elems, n_group_elems), logical_k_major=True, - supports_small_tile=True, ) # Step 4. Issue the instructions. diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index f6b876087..f3edbe639 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -347,7 +347,6 @@ def wgmma( large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. group_size=(k_group_elems, n_group_elems), logical_k_major=True, - supports_small_tile=True, ) del b_n_group_stride # We only support one N group. diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 32d911116..d28927967 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -661,6 +661,7 @@ class WGMMATest(TestCase): swizzle=(32, 64, 128), jax_out_dtype=(jnp.float16, jnp.float32), rhs_tiling_kind=("large", "small", "small+no_transpose"), + lhs_tiling_kind=("large", "small", "small+no_transpose"), ) def test_wgmma_basic( self, @@ -673,13 +674,16 @@ class WGMMATest(TestCase): swizzle, jax_out_dtype, rhs_tiling_kind, + lhs_tiling_kind, ): if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: self.skipTest("Only f16 input is supported for f16 output.") - if swizzle != 128 and lhs_transpose: + if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large": self.skipTest("Transpose only supported in 128B swizzled WGMMA") if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: self.skipTest("No transpose happening anyway") + if lhs_tiling_kind == "small+no_transpose" and not lhs_transpose: + self.skipTest("No transpose happening anyway") in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) @@ -705,18 +709,17 @@ class WGMMATest(TestCase): k = nk_tile * k_steps assert m % 64 == 0 and n % nk_tile == 0 - small_nk_tile = 8 small_rhs_tile = rhs_tiling_kind != "large" transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" - rhs_tiling = ( - (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) - ) + rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) + small_lhs_tile = lhs_tiling_kind != "large" + transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose" + lhs_tiling = (8, nk_tile) if small_lhs_tile else (64, nk_tile) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers = scratch - lhs_transform = (mgpu.TileTransform((64, nk_tile)),) - if lhs_transpose: - assert nk_tile == 64 # Make sure we didn't have to transpose tiling. + lhs_transform = (mgpu.TileTransform(lhs_tiling),) + if lhs_transpose and transpose_lhs_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) if rhs_transpose and transpose_rhs_tiles: @@ -739,7 +742,8 @@ class WGMMATest(TestCase): barriers[i].wait() init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) + lhs_smem = memref_transpose(lhs_smem, perm) if rhs_transpose: perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) rhs_smem = memref_transpose(rhs_smem, perm) @@ -762,8 +766,13 @@ class WGMMATest(TestCase): rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) else: rhs_smem_shape = tile_shape(y_shape, rhs_tiling) + if transpose_lhs_tiles: + lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling + lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling) + else: + lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype), + jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype), jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(2), ] @@ -911,6 +920,7 @@ class TCGen05Test(TestCase): k_steps=(1, 2), swizzle=(32, 64, 128,), rhs_tiling_kind=("large", "small", "small+no_transpose"), + lhs_tiling_kind=("large", "small", "small+no_transpose"), ) def test_mma_basic( self, @@ -923,6 +933,7 @@ class TCGen05Test(TestCase): in_jax_dtype, out_jax_dtype, rhs_tiling_kind, + lhs_tiling_kind, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: self.skipTest("Only f16 input is supported for f16 output.") @@ -933,17 +944,17 @@ class TCGen05Test(TestCase): k = nk_tile * k_steps assert m % m_tile == 0 and n % nk_tile == 0 - small_nk_tile = 8 - small_rhs_tile = rhs_tiling_kind != "large" transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" - rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) + rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) + small_lhs_tile = lhs_tiling_kind != "large" + transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose" + lhs_tiling = (8, nk_tile) if small_lhs_tile else (128, nk_tile) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers, acc = scratch - lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),) - if lhs_transpose: - assert nk_tile == m_tile # Make sure we didn't have to transpose tiling + lhs_transform = (mgpu.TileTransform(lhs_tiling),) + if lhs_transpose and transpose_lhs_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) if rhs_transpose and transpose_rhs_tiles: @@ -966,7 +977,8 @@ class TCGen05Test(TestCase): barriers[1].wait() with mgpu.single_thread(): if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) + lhs_smem = memref_transpose(lhs_smem, perm) if rhs_transpose: perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) rhs_smem = memref_transpose(rhs_smem, perm) @@ -993,8 +1005,13 @@ class TCGen05Test(TestCase): rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) else: rhs_smem_shape = tile_shape(y_shape, rhs_tiling) + if transpose_lhs_tiles: + lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling + lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling) + else: + lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ - jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype), + jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype), jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype), mgpu.TMABarrier(3), mgpu.TMEM((128, n), out_jax_dtype), From c099e8081dd2561eef55f883f00d0ccb0ea18115 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 13 Feb 2025 20:53:26 +0000 Subject: [PATCH 027/100] support e2m1fn --- jax/_src/dtypes.py | 14 ++++++++++ jax/_src/export/serialization.fbs | 1 + jax/_src/export/serialization.py | 2 ++ jax/_src/export/serialization_generated.py | 1 + jax/_src/interpreters/mlir.py | 3 +++ jax/_src/numpy/scalar_types.py | 2 ++ jax/_src/public_test_util.py | 5 ++++ jax/_src/test_util.py | 2 ++ jax/numpy/__init__.py | 1 + tests/dtypes_test.py | 31 +++++++++++++++++++++- tests/export_test.py | 2 ++ 11 files changed, 63 insertions(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 808d129ba..853fb5d1c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -109,6 +109,12 @@ _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz) +# fp4 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float4_e2m1fn: type[np.generic] | None = None + +_float4_e2m1fn_dtype: np.dtype | None = None + def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" typ = np.dtype(dtype).type @@ -144,6 +150,8 @@ _float8_dtypes = [ _float8_e5m2fnuz_dtype, ] +_float4_dtypes: list[np.dtype] = [] + # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 if hasattr(ml_dtypes, "float8_e4m3"): float8_e4m3 = ml_dtypes.float8_e4m3 @@ -163,6 +171,12 @@ if hasattr(ml_dtypes, "float8_e8m0fnu"): _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) +if hasattr(ml_dtypes, "float4_e2m1fn"): + float4_e2m1fn = ml_dtypes.float4_e2m1fn + _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) + _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) + _float4_dtypes.insert(0, _float4_e2m1fn_dtype) # 2-bit integer support int2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index dd0ae3edc..7d3e342f1 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -75,6 +75,7 @@ enum DType: byte { f8_e5m2 = 20, f8_e5m2fnuz = 21, f8_e8m0fnu = 25, + f4_e2m1fn = 26, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 7707670f1..ac97c11d1 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -365,6 +365,8 @@ if dtypes._float8_e4m3_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 if dtypes._float8_e8m0fnu_dtype is not None: _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu +if dtypes._float4_e2m1fn_dtype is not None: + _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 69092cd7e..b1fc13333 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -62,6 +62,7 @@ class DType(object): f8_e5m2fnuz = 21 f0 = 22 f8_e8m0fnu = 25 + f4_e2m1fn = 26 class ShardingKind(object): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7c10c7b8d..c20fa34d4 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -199,6 +199,9 @@ if dtypes.float8_e4m3 is not None: if dtypes.float8_e8m0fnu is not None: _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get +if dtypes.float4_e2m1fn is not None: + _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get + def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 5d20b73af..585a5484a 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -93,6 +93,8 @@ float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) +if dtypes.float4_e2m1fn is not None: + float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 220342ce5..455a3b98c 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -100,6 +100,9 @@ if _dtypes.float8_e4m3 is not None: if _dtypes.float8_e8m0fnu is not None: _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 +if _dtypes.float4_e2m1fn is not None: + _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 + default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): custom_float_dtypes.insert(0, _dtypes.float8_e3m4) if _dtypes.float8_e8m0fnu is not None: custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) + if _dtypes.float4_e2m1fn is not None: + custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) def maybe_upcast(x): if x.dtype in custom_float_dtypes: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 9f2bab2b4..18f7efa16 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1640,6 +1640,8 @@ class _LazyDtypes: float_dtypes += [_dtypes.float8_e4m3] if _dtypes.float8_e8m0fnu is not None: float_dtypes += [_dtypes.float8_e8m0fnu] + if _dtypes.float4_e2m1fn is not None: + float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index ad71b9f74..cb291bdca 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -310,6 +310,7 @@ try: float8_e3m4 as float8_e3m4, float8_e4m3 as float8_e4m3, float8_e8m0fnu as float8_e8m0fnu, + float4_e2m1fn as float4_e2m1fn, ) except ImportError: pass diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 8127aed7a..fca3f4320 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -73,6 +73,12 @@ if dtypes.float8_e8m0fnu is not None: float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes +fp4_dtypes = [] +if dtypes.float4_e2m1fn is not None: + fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)] +float_dtypes += fp4_dtypes +custom_float_dtypes += fp4_dtypes + complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')] @@ -238,6 +244,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t1 in intn_dtypes: continue + if t1 in fp4_dtypes: + continue self.assertEqual(np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128)) @@ -247,6 +255,8 @@ class DtypesTest(jtu.JaxTestCase): continue if t2 in intn_dtypes: continue + if t2 in fp4_dtypes: + continue # Symmetry self.assertEqual(dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1)) @@ -261,6 +271,8 @@ class DtypesTest(jtu.JaxTestCase): # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. if t in fp8_dtypes: continue + if t in fp4_dtypes: + continue if t in intn_dtypes or i in intn_dtypes: continue self.assertEqual(t, dtypes.promote_types(t, i)) @@ -951,10 +963,12 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest("XLA support for int2 and int4 is incomplete.") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) if weak_type: expected = dtypes.canonicalize_dtype( - dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes] else x.dtype.kind]) + dtypes._default_types['f' if x.dtype in ["bfloat16", *fp8_dtypes, *fp4_dtypes] else x.dtype.kind]) else: expected = x.dtype self.assertEqual(dtypes.result_type(x), expected) @@ -971,6 +985,17 @@ class TestPromotionTables(jtu.JaxTestCase): ".*8-bit floats do not support implicit promotion"): x + y + @jax.numpy_dtype_promotion('standard') + def testFloat4PromotionError(self): + for dtype in fp4_dtypes: + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") + x = jnp.array(1, dtype=dtype) + y = jnp.array(1, dtype='float32') + with self.assertRaisesRegex(dtypes.TypePromotionError, + ".*4-bit floats do not support implicit promotion"): + x + y + @jax.numpy_dtype_promotion('standard') @jtu.run_on_devices('tpu') def testInt2PromotionError(self): @@ -995,6 +1020,8 @@ class TestPromotionTables(jtu.JaxTestCase): def testBinaryNonPromotion(self, dtype, weak_type, promotion): if dtype in fp8_dtypes: self.skipTest("XLA support for float8 is incomplete.") + if dtype in fp4_dtypes: + self.skipTest("XLA support for float4 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") # Regression test for https://github.com/jax-ml/jax/issues/6051 @@ -1027,6 +1054,8 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest('XLA support for int2 is incomplete.') if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest('TPU does not support float8_e8m0fnu.') + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest('TPU does not support float4_e2m1fn.') val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') diff --git a/tests/export_test.py b/tests/export_test.py index 60c96fca4..6baecebe1 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1014,6 +1014,8 @@ class JaxExportTest(jtu.JaxTestCase): self.skipTest(f"TODO: serialization not supported for {str(dtype)}") if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest("TPU does not support float8_e8m0fnu.") + if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): + self.skipTest("TPU does not support float4_e2m1fn.") @jax.jit def f_jax(x): return x + x From 4a93c8b30c4dafd1c53ab788aba1122b5e56e458 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 5 Mar 2025 10:21:52 -0800 Subject: [PATCH 028/100] Reverts 342cb7b99a09180472823a33c7cdad8a8db77875 PiperOrigin-RevId: 733782497 --- CHANGELOG.md | 3 - jax/_src/custom_derivatives.py | 253 ++++++++++++++++++++++++++---- jax/custom_derivatives.py | 1 + jax/experimental/jax2tf/jax2tf.py | 12 +- tests/api_test.py | 78 +++------ 5 files changed, 250 insertions(+), 97 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65f1dafa8..86d0cab0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,9 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size. -* Breaking changes - * The ``jax.custom_derivatives.remat_opt_p`` helper primitive was removed. - ## jax 0.5.1 (Feb 24, 2025) * New Features diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 579086f36..32856106a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -16,7 +16,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence import dataclasses -from functools import update_wrapper, reduce, partial +from functools import update_wrapper, reduce, partial, wraps from typing import Any, Generic, TypeVar from jax._src import config @@ -32,7 +32,6 @@ from jax._src.ad_util import ( from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, prepend_static_args, debug_info) -from jax._src.custom_dce import custom_dce from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -658,12 +657,10 @@ class custom_vjp(Generic[ReturnValue]): # TODO(necula): figure out how to construct the debug_bwd args debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {}) if self.optimize_remat: - if self.symbolic_zeros: - # TODO(dfm): This probably shouldn't be too hard to support. - raise NotImplementedError( - "remat optimization for custom_vjp does not support symbolic zeros") fwd = optimize_remat_of_custom_vjp_fwd( - self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums) + self.fun, debug_fun, self.fwd, debug_fwd, + nondiff_argnums=self.nondiff_argnums, + symbolic_zeros=self.symbolic_zeros) else: fwd = self.fwd if config.enable_custom_vjp_by_custom_transpose.value: @@ -1574,31 +1571,229 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") # simpler, but it would be worth revisiting this. def optimize_remat_of_custom_vjp_fwd( fun: Callable[..., ReturnValue], + debug_fun: core.DebugInfo, fwd: Callable[..., tuple[ReturnValue, Any]], + debug_fwd: core.DebugInfo, nondiff_argnums: Sequence[int] = (), + symbolic_zeros: bool = False, ) -> Callable[..., tuple[ReturnValue, Any]]: - wrapped_fwd = custom_dce( - # It might seem like we don't need this lambda, but there are some real - # world use cases where the signature of `fwd` is wrong, and we shouldn't - # error out when resolving the arguments in those cases. This is fine, - # because the arguments have already been resolved in custom_vjp. - lambda *args: fwd(*args), # pylint: disable=unnecessary-lambda - static_argnums=nondiff_argnums, - ) + if symbolic_zeros: + # TODO(dfm): This probably shouldn't be too hard to support. + raise NotImplementedError( + "remat optimization for custom_vjp does not support symbolic zeros") - @wrapped_fwd.def_dce - def _(*args): - static_args, used_outs, args = split_list(args, [len(nondiff_argnums), 1]) - static_args_iter = iter(static_args) - args_iter = iter(args) - nondiff_argnums_ = set(nondiff_argnums) - fun_args = [ - next(static_args_iter) if i in nondiff_argnums_ else next(args_iter) - for i in range(len(static_args) + len(args))] - used_outs, = used_outs - _, used_res = used_outs - if any(tree_leaves(used_res)): - return fwd(*fun_args) - return fun(*fun_args), None + @wraps(fwd) + def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: + # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__ + # above and it would be good to consolidate it. + fwd_name = debug_fwd.func_name if debug_fwd else str(fwd) + # Note: we use `fun` instead of `fwd` here for consistency with + # custom_vjp.__call__ above. + args = resolve_kwargs(fun, args, kwargs) + if nondiff_argnums: + for i in nondiff_argnums: _check_for_tracers(args[i]) + nondiff_argnums_ = set(nondiff_argnums) + dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_] + f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun), + dyn_argnums, + args, require_static_args_hashable=False) + fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd), + dyn_argnums, args, + require_static_args_hashable=False) + else: + f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args + fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd) + args_flat, in_tree = tree_flatten(dyn_args) + flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) + flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False, + debug_fun, debug_fwd, in_tree, out_type) + flat_fwd = _fix_fwd_args(flat_fwd) + + in_avals = [core.get_aval(x) for x in args_flat] + fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) + fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) + prim_tree, res_tree = out_trees() + num_res = res_tree.num_leaves + + if fwd_jaxpr.effects: + raise NotImplementedError( + "remat optimization for custom_vjp does not support forward " + f"functions with side effects, but {fwd_name} has the following " + f"effects: {fwd_jaxpr.effects}") + + @pe._memoize + def fun_jaxpr_thunk(): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + return jaxpr, consts + + out_flat = remat_opt_p.bind(*consts, *args_flat, + num_consts=len(consts), + num_res=num_res, + fwd_jaxpr=fwd_jaxpr, + fun_jaxpr_thunk=fun_jaxpr_thunk) + res, out_flat = split_list(out_flat, [num_res]) + out_tree = treedef_tuple((prim_tree, res_tree)) + return tree_unflatten(out_tree, (*out_flat, *res)) return wrapped_fwd + +@lu.transformation2 +def _fix_fwd_args(f, *args): + args = [(x, True) for x in args] + args = [x for pair in args for x in pair] + return f(*args) + +def _remat_opt_impl( + *args, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + del num_consts, num_res, fun_jaxpr_thunk # unused + return core.jaxpr_as_fun(fwd_jaxpr)(*args) + +def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): + del args + return fwd_jaxpr.out_avals, fwd_jaxpr.effects + +def _remat_opt_vmap( + axis_data, args, in_dims, + *, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 + else x for x, d in zip(args, in_dims)] + in_batched = [d is not not_mapped for d in in_dims] + batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( + fwd_jaxpr, axis_data, in_batched, False) + extra_consts = batched_fwd_jaxpr.consts + batched_fwd_jaxpr = pe.close_jaxpr( + pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) + out_dims = [0 if b else not_mapped for b in out_batched] + + _, prim_batched = split_list(in_batched, [num_consts]) + + @pe._memoize + def batched_fun_jaxpr_thunk(): + fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) + batched_fun_jaxpr, out_batched = batching.batch_jaxpr( + fun_jaxpr, axis_data, prim_batched, False) + return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts + + batched_outs = remat_opt_p.bind(*extra_consts, *args, + num_consts=num_consts + len(extra_consts), + num_res=num_res, + fwd_jaxpr=batched_fwd_jaxpr, + fun_jaxpr_thunk=batched_fun_jaxpr_thunk) + + return batched_outs, out_dims + +def _remat_opt_jvp( + primals, + tangents, + *, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + consts, primals = split_list(primals, [num_consts]) + consts_dot, tangents = split_list(tangents, [num_consts]) + # Tangents must be instantated in case we end up DCEing later. + tangents = map(ad.instantiate_zeros, tangents) + consts_nz = [not isinstance(t, Zero) for t in consts_dot] + consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz] + in_nz = consts_nz + [True] * len(tangents) + fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True) + num_out = len(out_nz) - num_res + fwd_jaxpr_jvp_ = ad.rearrange_binders( + fwd_jaxpr_jvp_, [num_consts, len(primals)], + [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) + fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) + + # @pe._memoize + def fun_jvp_jaxpr_thunk(): + fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) + in_nz = [True] * len(primals) + fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True) + return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts + + new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot) + outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot, + *primals, *tangents, num_consts=new_num_consts, + num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp, + fun_jaxpr_thunk=fun_jvp_jaxpr_thunk) + res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out]) + return (*res, *outs), (*res_dot, *outs_dot) + +def _remat_opt_transpose( + cts, *args, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + # TODO(dfm): It shouldn't be too hard to implement this as needed in the + # future. + raise NotImplementedError( + "remat optimization for custom_vjp does not support higher-order AD") + +def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + if not any(used_outs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) + outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] + if any(used_res): + # If any of the residuals are used, we still need to run fwd at this point, + # but we may end up DCEing again in the future, so we must instantiate all + # the input primals. + instantiate = [False] * eqn.params["num_consts"] + instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"]) + new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs, + instantiate=instantiate) + assert not new_jaxpr.constvars + closed_jaxpr = pe.close_jaxpr(new_jaxpr) + invars = [v for used, v in zip(used_ins, eqn.invars) if used] + new_params = dict(eqn.params) + new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0]) + new_params["num_consts"] = new_num_consts + new_params["fwd_jaxpr"] = closed_jaxpr + new_params["num_res"] = sum(used_res) + new_eqn = pe.new_jaxpr_eqn( + invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects, + eqn.source_info, eqn.ctx) + return used_ins, new_eqn + else: + # If none of the residuals are used, we run the primal computation instead. + # At this point we drop this custom DCE behavior, but since the primal might + # have different consts than fwd, we build a new JaxprEqn with a closed_call + # primitive. + fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]() + new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims) + consts = [c for used, c in zip(used_consts, consts) if used] + closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) + _, invars = split_list(eqn.invars, [eqn.params["num_consts"]]) + invars = [v for used, v in zip(used_ins, invars) if used] + new_eqn = pe.new_jaxpr_eqn( + invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr), + closed_jaxpr.effects, eqn.source_info, eqn.ctx) + used_ins = [False] * eqn.params["num_consts"] + used_ins + return used_ins, new_eqn + +remat_opt_p = core.Primitive("remat_opt") +remat_opt_p.multiple_results = True +remat_opt_p.def_impl(_remat_opt_impl) +remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval) +xla.register_initial_style_primitive(remat_opt_p) +mlir.register_lowering(remat_opt_p, mlir.lower_fun( + _remat_opt_impl, multiple_results=True)) + + +batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap +ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp +ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose +pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 0b0c8621e..3628ae4aa 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -30,6 +30,7 @@ from jax._src.custom_derivatives import ( custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, + remat_opt_p as remat_opt_p, ) from jax._src.ad_util import ( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8f39f53ea..d58a1bb0d 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -45,7 +45,6 @@ from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core -from jax._src import custom_dce from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu @@ -3474,14 +3473,15 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: tf_impl[ad.custom_lin_p] = _custom_lin -def _custom_dce(*args: TfVal, num_consts: int, fun_jaxpr: core.ClosedJaxpr, - dce_jaxpr_thunk: Callable) -> Sequence[TfVal]: - del num_consts, dce_jaxpr_thunk - return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_dce_call", +def _remat_opt(*args: TfVal, num_consts: int, num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable) -> Sequence[TfVal]: + del num_consts, num_res, fun_jaxpr_thunk + return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt", fresh_constant_cache=False) -tf_impl[custom_dce.custom_dce_p] = _custom_dce +tf_impl[custom_derivatives.remat_opt_p] = _remat_opt PartitionsOrReplicated = Union[tuple[int, ...], None] diff --git a/tests/api_test.py b/tests/api_test.py index a5e441259..543335529 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9599,7 +9599,10 @@ class CustomVJPTest(jtu.JaxTestCase): return np.array([2.0])*x*x/np.array([1.0]), (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed @@ -9609,7 +9612,9 @@ class CustomVJPTest(jtu.JaxTestCase): def fwd(x): return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) @@ -9620,7 +9625,9 @@ class CustomVJPTest(jtu.JaxTestCase): return x*x, (x,) x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) def g(x): return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) @@ -9634,7 +9641,9 @@ class CustomVJPTest(jtu.JaxTestCase): def fwd_(x): return x*x, (x,) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), + fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) calc = jax.jvp(fwd, (3.2,), (1.0,)) expected = jax.jvp(fwd_, (3.2,), (1.0,)) self.assertAllClose(calc, expected) @@ -9731,55 +9740,6 @@ class CustomVJPTest(jtu.JaxTestCase): x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error - def test_optimize_remat_nondiff_argnums(self): - @partial(jax.custom_vjp, nondiff_argnums=(2,)) - def f(x, y, fun): - return fun(x, y) - - def f_fwd(x, y, fun): - del fun - return jnp.cos(x) * y, (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(fun, res, g): - del fun - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - def fun(x, y): - return jnp.sin(x) * y - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 0.5, 0.1 - res = jax.value_and_grad(lambda *args: f(*args, fun))(x, y)[0] - self.assertAllClose(res, f_fwd(x, y, fun)[0]) - res = jax.jit(lambda *args: jax.value_and_grad( - lambda *args: f(*args, fun))(*args)[0])(x, y) - self.assertAllClose(res, fun(x, y)) - - def test_optimize_remat_incorrect_signature(self): - def f_(x, y): - return jnp.sin(x) * y - - @jax.custom_vjp - def f(x, y): - return f_(x, y) - - def wrong_signature(x, y, z): - self.fail("wrong_signature should not be called") - - @functools.wraps(wrong_signature) - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 3.2, 1.0 - self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) - - def test_dce(self): @jax.custom_vjp def f(x, y): @@ -10508,20 +10468,20 @@ class CustomDceTest(jtu.JaxTestCase): self.assertAllClose(v, jnp.tan(3.2)**2) def test_static_argnums(self): - @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(1,)) - def g(x, f): + @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) + def g(f, x): return f(x), 10 * f(x) @g.def_dce def g_dce(f, used_outs, x): # note: static_argnums are always passes first self.assertTrue(callable(f)) - return [2 * v if used else None for used, v in zip(used_outs, g(x, f))] + return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] x = 1.1234 f = lambda x: jnp.exp(x) - expected = g(x, f) - self.assertAllClose(jax.jit(lambda x: g(x, f)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: g(x, f)[1])(x), 2 * expected[1]) + expected = g(f, x) + self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) def test_shape_mismatch_error(self): @jax.experimental.custom_dce.custom_dce From 1ae7dd7f76001de98303db6e2c563df136c93c3d Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 10:28:18 -0800 Subject: [PATCH 029/100] Update `.bazelrc` with Apple CC toolchain changes. PiperOrigin-RevId: 733784816 --- .bazelrc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index f86af3a9b..04bfcf2d7 100644 --- a/.bazelrc +++ b/.bazelrc @@ -54,6 +54,12 @@ build:macos --apple_platform_type=macos build:macos --linkopt=-Wl,-undefined,dynamic_lookup build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup +# Use cc toolchains from apple_support for Apple builds. +# https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup +build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain +build:macos --crosstool_top=@local_config_apple_cc//:toolchain +build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain + # Windows has a relatively short command line limit, which JAX has begun to hit. # See https://docs.bazel.build/versions/main/windows.html build:windows --features=compiler_param_file @@ -196,9 +202,6 @@ build:public_cache_push --config=public_cache --remote_upload_local_results=true # "oct2023" in the URL is just the date when the bucket was created and can be # disregarded. It still contains the latest cache that is being used. build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false -# This flag is to address mac arm64 nightly build failures, which are believed -# to be caused by cache poisoning after the Bazel 7.4.0 toolchain upgrade. -build:macos_cache --remote_default_platform_properties='properties:{name:"cache-silo-key" value:"cache-poisoning-2025-03-03"}' # Cache pushes are limited to JAX's CI system. build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials From 8df00e2666cdc2df9eedc1cc2214e746cf380ecf Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 5 Mar 2025 10:34:05 -0800 Subject: [PATCH 030/100] [Mosaic GPU] Remove support for large tiles on Blackwell We don't have many Blackwell kernels yet, so let's begin the deprecation there! Small tiles have clearer semantics when it comes to transposes too, which allows us to enable more test cases. PiperOrigin-RevId: 733786884 --- .../mosaic/gpu/examples/matmul_blackwell.py | 25 ++--- jax/experimental/mosaic/gpu/mma_utils.py | 13 ++- jax/experimental/mosaic/gpu/tcgen05.py | 2 - tests/mosaic/gpu_test.py | 99 +++++++------------ 4 files changed, 55 insertions(+), 84 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index a6772a575..d15cecbdc 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -53,13 +53,12 @@ def build_kernel( index = ir.IndexType.get() swizzle = 128 - tile_k = swizzle // 2 + swizzle_elems = tile_k = swizzle // 2 + tiling = (8, swizzle_elems) in_dtype = jnp.float16 k_loop_iter = k // tile_k max_concurrent_steps = min(max_concurrent_steps, k_loop_iter) - tma_tile_m = 128 - tma_tile_kn = 64 block_tile_m = tile_m block_tile_n = tile_n @@ -123,17 +122,14 @@ def build_kernel( src_ref=a, dst_ref=mgpu.memref_slice(a_smem, slot), gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)), + gmem_transform=mgpu.TileTransform(tiling), **common_args, ) ctx.async_copy( src_ref=b, dst_ref=mgpu.memref_slice(b_smem, slot), gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), - gmem_transform=( - mgpu.TileTransform((tma_tile_kn, tma_tile_kn)), - mgpu.TransposeTransform((1, 0, 2, 3)), - ), + gmem_transform=mgpu.TileTransform(tiling), **common_args, ) @@ -145,7 +141,7 @@ def build_kernel( tcgen05.mma( acc, mgpu.memref_slice(a_smem, slot), - mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)), + mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), a_swizzle=swizzle, b_swizzle=swizzle, accumulate=accumulate, @@ -172,26 +168,23 @@ def build_kernel( src_ref=d_smem, dst_ref=d, gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, 64)), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), swizzle=swizzle, ) ctx.await_async_copy(0) compute_buffers = ( jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), - (tma_tile_m, tma_tile_kn)), + mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling), jnp.float16), jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n), - (tma_tile_kn, tma_tile_kn)), + mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), jnp.float16), ) epilogue_buffer = jax.ShapeDtypeStruct( - mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), + mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), jnp.float16) smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) - assert block_tile_m == 128 smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py index 4359ab8b4..81f6af1a9 100644 --- a/jax/experimental/mosaic/gpu/mma_utils.py +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -44,29 +44,33 @@ class Dim(enum.Enum): def create_descriptor( ref: ir.Value, swizzle: int, - large_tile: tuple[int, int], # Soft deprecated. Use small tiling instead. group_size: tuple[int, int], # Instruction group size on each operand dim. logical_k_major: bool, # False for LHS, True for RHS. + # Soft deprecated. Use small tiling instead. + large_tile: tuple[int, int] | None = None, ): ref_ty = ir.MemRefType(ref.type) element_bytewidth = utils.bytewidth(ref_ty.element_type) swizzle_elems = swizzle // element_bytewidth ref_strides, _ = ref_ty.get_strides_and_offset() ref_byte_strides = [s * element_bytewidth for s in ref_strides] + mn_large_tile = k_large_tile = None if logical_k_major: _, mn_tiles, k_tiling, mn_tiling = ref_ty.shape k_tile_stride, mn_tile_stride, k_tiling_stride, mn_tiling_stride = ( ref_byte_strides ) - k_large_tile, mn_large_tile = large_tile k_group_size, mn_group_size = group_size + if large_tile is not None: + k_large_tile, mn_large_tile = large_tile else: mn_tiles, _, mn_tiling, k_tiling = ref_ty.shape mn_tile_stride, k_tile_stride, mn_tiling_stride, k_tiling_stride = ( ref_byte_strides ) - mn_large_tile, k_large_tile = large_tile mn_group_size, k_group_size = group_size + if large_tile is not None: + mn_large_tile, k_large_tile = large_tile IGNORED = 0 MMA_ATOM_ROWS = 8 @@ -83,7 +87,8 @@ def create_descriptor( # the same coordinate along that dim. The slower dimension is called a # "stride" dimension. if ( - k_large_tile == k_tiling + large_tile is not None + and k_large_tile == k_tiling and (mn_large_tile == mn_tiling or mn_tiles == 1 and mn_tiling < mn_large_tile) # There are configurations where large tiles are same size as small ones. # We use the small path since it has fewer restrictions. diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 94969c630..7a349f50c 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -163,7 +163,6 @@ def mma( ) = mma_utils.create_descriptor( a, swizzle=swizzle, - large_tile=(m_group_elems, k_group_elems), group_size=(m_group_elems, k_group_elems), logical_k_major=False, ) @@ -174,7 +173,6 @@ def mma( ) = mma_utils.create_descriptor( b, swizzle=swizzle, - large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n. group_size=(k_group_elems, n_group_elems), logical_k_major=True, ) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d28927967..3b759a3b7 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -911,7 +911,7 @@ class TCGen05Test(TestCase): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") @parameterized.product( - lhs_transpose=(False,), # TODO(apaszke): True + lhs_transpose=(False, True), rhs_transpose=(False, True), in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation @@ -919,8 +919,8 @@ class TCGen05Test(TestCase): n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), - rhs_tiling_kind=("large", "small", "small+no_transpose"), - lhs_tiling_kind=("large", "small", "small+no_transpose"), + rhs_transpose_tiles=(False, True), + lhs_transpose_tiles=(False, True), ) def test_mma_basic( self, @@ -932,32 +932,24 @@ class TCGen05Test(TestCase): rhs_transpose, in_jax_dtype, out_jax_dtype, - rhs_tiling_kind, - lhs_tiling_kind, + rhs_transpose_tiles, + lhs_transpose_tiles, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: self.skipTest("Only f16 input is supported for f16 output.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) - m_tile = 128 - nk_tile = swizzle // bytewidth(in_mlir_dtype) - k = nk_tile * k_steps - assert m % m_tile == 0 and n % nk_tile == 0 - - small_rhs_tile = rhs_tiling_kind != "large" - transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose" - rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile) - small_lhs_tile = lhs_tiling_kind != "large" - transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose" - lhs_tiling = (8, nk_tile) if small_lhs_tile else (128, nk_tile) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers, acc = scratch lhs_transform = (mgpu.TileTransform(lhs_tiling),) - if lhs_transpose and transpose_lhs_tiles: + if lhs_transpose_tiles: lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose and transpose_rhs_tiles: + if rhs_transpose_tiles: rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, @@ -976,12 +968,14 @@ class TCGen05Test(TestCase): barriers[0].wait() barriers[1].wait() with mgpu.single_thread(): + if lhs_transpose_tiles: + lhs_smem = memref_transpose(lhs_smem, (1, 0, 2, 3)) if lhs_transpose: - perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) - lhs_smem = memref_transpose(lhs_smem, perm) + lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) + if rhs_transpose_tiles: + rhs_smem = memref_transpose(rhs_smem, (1, 0, 2, 3)) if rhs_transpose: - perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2) - rhs_smem = memref_transpose(rhs_smem, perm) + rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, ) @@ -1000,14 +994,16 @@ class TCGen05Test(TestCase): y_shape = (n, k) if rhs_transpose else (k, n) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) - if transpose_rhs_tiles: - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling - rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling) + if rhs_transpose_tiles: + rhs_smem_shape = ( + y_shape[1] // rhs_tiling[1], y_shape[0] // rhs_tiling[0], *rhs_tiling, + ) else: rhs_smem_shape = tile_shape(y_shape, rhs_tiling) - if transpose_lhs_tiles: - lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling - lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling) + if lhs_transpose_tiles: + lhs_smem_shape = ( + x_shape[1] // lhs_tiling[1], x_shape[0] // lhs_tiling[0], *lhs_tiling, + ) else: lhs_smem_shape = tile_shape(x_shape, lhs_tiling) scratch_shape = [ @@ -1025,15 +1021,14 @@ class TCGen05Test(TestCase): np.testing.assert_allclose(z, ref, atol=atol) @parameterized.product( - lhs_transpose=(False,), # TODO(apaszke): True - rhs_transpose=(True,), + lhs_transpose=(False, True), + rhs_transpose=(False, True), in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), - small_rhs_tile=(False, True), ) def test_mma_collective( self, @@ -1045,42 +1040,27 @@ class TCGen05Test(TestCase): rhs_transpose, in_jax_dtype, out_jax_dtype, - small_rhs_tile, ): if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: raise self.skipTest("Only f16 input is supported for f16 output.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) m_block_tile = m // 2 - m_tma_tile = 128 n_block_tile = n // 2 - nk_tma_tile = swizzle // bytewidth(in_mlir_dtype) - k = nk_tma_tile * k_steps - assert m % m_tma_tile == 0 and n % nk_tma_tile == 0 + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps index = ir.IndexType.get() - small_nk_tile = 8 if rhs_transpose else 16 - rhs_tiling = ( - (small_nk_tile, nk_tma_tile) - if small_rhs_tile - else (nk_tma_tile, nk_tma_tile) - ) + tiling = (8, swizzle_elems) def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers, acc = scratch - lhs_transform = (mgpu.TileTransform((m_tma_tile, nk_tma_tile)),) - if lhs_transpose: - assert nk_tma_tile == m_tma_tile # Make sure we didn't have to transpose tiling - lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) - rhs_transform = (mgpu.TileTransform(rhs_tiling),) - if rhs_transpose: - rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) block_id = gpu.cluster_block_id(gpu.Dimension.x) ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, swizzle=swizzle, - gmem_transform=lhs_transform, + gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[0], collective=gpu.Dimension.x, partitioned=1 if lhs_transpose else 0, # Split non-contracting dim. @@ -1089,7 +1069,7 @@ class TCGen05Test(TestCase): src_ref=rhs, dst_ref=rhs_smem, swizzle=swizzle, - gmem_transform=rhs_transform, + gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[1], collective=gpu.Dimension.x, partitioned=0 if rhs_transpose else 1, # Split non-contracting dim. @@ -1100,9 +1080,9 @@ class TCGen05Test(TestCase): barriers[0].wait() barriers[1].wait() if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2)) + lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2)) + rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True ) @@ -1118,20 +1098,15 @@ class TCGen05Test(TestCase): return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) + x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k) x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) + y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) - rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling scratch_shape = [ - jax.ShapeDtypeStruct( - tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), - in_jax_dtype, - ), - jax.ShapeDtypeStruct( - (k // rhs_tiling_t[0], n_block_tile // rhs_tiling_t[1], *rhs_tiling), - in_jax_dtype, - ), + jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), mgpu.TMABarrier(3), mgpu.TMEM((128, n), out_jax_dtype, collective=True), ] From 3e4dc0d4909f0ebe3f85513e320eda3bf8b5d805 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:14:24 -0800 Subject: [PATCH 031/100] add pmap axes hints --- jax/_src/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index baf2af6e9..68c11a08d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1135,8 +1135,8 @@ def pmap( fun: Callable, axis_name: AxisName | None = None, *, - in_axes=0, - out_axes=0, + in_axes: int | None | Sequence[Any] = 0, + out_axes: Any = 0, static_broadcasted_argnums: int | Iterable[int] = (), devices: Sequence[xc.Device] | None = None, # noqa: F811 backend: str | None = None, From 3edc068f8c2e8375556e4fa0e5f518a02fb1de65 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 12:58:37 -0800 Subject: [PATCH 032/100] Fix ambiguous cpu definition for JAX wheels. Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926. PiperOrigin-RevId: 733838895 --- jaxlib/jax.bzl | 5 +++-- jaxlib/tools/BUILD.bazel | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 633cd07ab..a5f02937c 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -522,9 +522,10 @@ def jax_wheel( # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. cpu = select({ "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:macos_x86_64": "x86_64", "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", - "@platforms//cpu:x86_64": "x86_64", + "//jaxlib/tools:linux_aarch64": "aarch64", + "//jaxlib/tools:linux_x86_64": "x86_64", }), source_files = source_files, ) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 6eab64823..baf996d50 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -186,6 +186,14 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "macos_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + ":macos", + ], +) + selects.config_setting_group( name = "win_amd64", match_all = [ @@ -194,6 +202,22 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "linux_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], +) + +selects.config_setting_group( + name = "linux_aarch64", + match_all = [ + ":arm64", + "@platforms//os:linux", + ], +) + string_flag( name = "jaxlib_git_hash", build_setting_default = "", From 69d66f66df8a117ce3e7fa7babe7254ad9d7bbfe Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 5 Mar 2025 12:55:00 -0800 Subject: [PATCH 033/100] vmap mismatch size error message: handle *args Fixes: https://github.com/jax-ml/jax/issues/26908 --- jax/_src/api.py | 3 +++ tests/api_test.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/jax/_src/api.py b/jax/_src/api.py index baf2af6e9..8e7e4f868 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1094,6 +1094,9 @@ def _mapped_axis_size(fn, tree, vals, dims, name): return f"args{keystr(key_path)}" # args is a tuple, so key_path[0].idx is the index into args. i = key_path[0].idx + # This can happen with star arguments (*args) + if i >= len(signature_parameters): + return f"args{keystr(key_path)}" res = f"argument {signature_parameters[i]}" if len(key_path) > 1: res += keystr(key_path[1:]) diff --git a/tests/api_test.py b/tests/api_test.py index 543335529..571a33e24 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1974,6 +1974,17 @@ class APITest(jtu.JaxTestCase): x2=jnp.ones(2, dtype=jnp.float32) ) + def test_vmap_inconsistent_sizes_constructs_proper_error_message_starargs(self): + # regression test for https://github.com/jax-ml/jax/issues/26908 + def f(x, *args): + return x - functools.reduce(jnp.add, args) + + with self.assertRaisesRegex( + ValueError, + "vmap got inconsistent sizes for array axes to be mapped:" + ): + jax.vmap(f)(jnp.ones(4), jnp.ones(2), jnp.ones(2)) + def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) From 0913cd7583ca927b1df22589aa7fd2e169b1245a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 13:53:44 -0800 Subject: [PATCH 034/100] Fix build rule for free-threaded python builds. PiperOrigin-RevId: 733857126 --- jaxlib/jax.bzl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index a5f02937c..58a83d9b0 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -325,11 +325,18 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) -def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version): +def _get_full_wheel_name( + package_name, + no_abi, + platform_independent, + platform_name, + cpu_name, + wheel_version, + py_freethreaded): if no_abi or platform_independent: wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" else: - wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl" python_version = HERMETIC_PYTHON_VERSION.replace(".", "") return wheel_name_template.format( package_name = package_name, @@ -339,6 +346,7 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na wheel_platform_tag = "any" if platform_independent else "_".join( PLATFORM_TAGS_DICT[platform_name, cpu_name], ), + free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) def _get_source_distribution_name(package_name, wheel_version): @@ -352,6 +360,7 @@ def _jax_wheel_impl(ctx): override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value output_path = ctx.attr.output_path[BuildSettingInfo].value git_hash = ctx.attr.git_hash[BuildSettingInfo].value + py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value executable = ctx.executable.wheel_binary if include_cuda_libs and not override_include_cuda_libs: @@ -387,6 +396,7 @@ def _jax_wheel_impl(ctx): platform_name = platform_name, cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) wheel_file = ctx.actions.declare_file(output_path + "/" + wheel_name) @@ -463,6 +473,7 @@ _jax_wheel = rule( "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), }, implementation = _jax_wheel_impl, executable = False, From 016b351f004fe25c9e3d614a6992bf91d592f61b Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Wed, 5 Mar 2025 15:14:37 -0800 Subject: [PATCH 035/100] [Pallas] Adds a simple dynamic race detector for TPU interpret mode. PiperOrigin-RevId: 733885890 --- jax/_src/pallas/mosaic/interpret.py | 948 ++++++++++++------ .../tpu_pallas_interpret_distributed_test.py | 187 +++- tests/pallas/tpu_pallas_interpret_test.py | 60 +- 3 files changed, 842 insertions(+), 353 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index b0ada86cd..4f2d0b4f3 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -17,6 +17,7 @@ from collections.abc import Iterable, Sequence import dataclasses import enum import functools +import itertools import math import threading from typing import Any, Literal @@ -72,29 +73,130 @@ class TPUInterpretParams: device is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: "on_wait". + detect_races: If True, a dynamic, happens-before race detector will be + used to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set + to True. + Default: False. """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" + detect_races: bool = False + + +VectorClock = np.ndarray + +# Conceptually, each DMA runs on its own, independent device. Representing +# this precisely would require vector clocks to have sizes linear in the number +# of DMAs. +# +# Instead, we use approximate vector clocks of fixed size. We assign each DMA +# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] -- +# and each operation of a DMA increments the corresponding coordinate in its +# vector clock. (So the "virtual" part of a vector clock is effectively +# counting, for each virtual device, the number of DMAs that happened-before +# the vector clock and were assigned to that virtual device.) +# +# If two approximate clocks are unordered, then their corresponding events are +# not ordered by the happens-before relation. So this approximation will not +# introduce any false positives in detecting data races. But we may fail to +# detect some true data races because there can be cases where two approximate +# clocks are ordered, and we will treat the corresponding events as ordered +# by the happens-before relation, but the corresponding events are not +# actually ordered. +NUM_VIRTUAL_DEVICES = 32 + +def make_vector_clock(num_devices: int) -> VectorClock: + del num_devices + return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32) + +def copy_vector_clock(x: VectorClock) -> VectorClock: + if x is None: + return None + return x.copy() + +def update_vector_clock(x: VectorClock, y: VectorClock): + x[:] = np.maximum(x, y) + +def lt(x: VectorClock, y: VectorClock) -> bool: + return bool((x <= y).all() & (x < y).any()) + +def ordered(x: VectorClock, y: VectorClock) -> bool: + return lt(x, y) | lt(y, x) + +def inc_vector_clock(x: VectorClock, device_id: int): + if device_id >= len(x): + raise ValueError(f'device_id={device_id} is out of range for x={x}') + assert device_id < len(x) + x[device_id] += 1 + class Semaphore: def __init__(self, semaphore_id=None): + shared_memory = _get_shared_memory() + self.id = semaphore_id + + # TODO(jburnim): Use one Condition variable per device. (Which will be + # easier to do when we're using single integer device IDs.) self.cv = threading.Condition() - # TODO(jburnim): Make this an array. - self.counts = collections.defaultdict(int) + self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32) - def signal(self, inc, device_id): + self.interpret_params = shared_memory.interpret_params + if self.interpret_params.detect_races: + # We associate a vector clock with each count in self.counts. Whenever + # self.counts[i] is signaled, self.clocks[i] is updated with the vector + # clock of the signaling device. Whenever device i successfully waits on + # self.counts[i], the vector clock of device i is updated with + # self.clocks[i]. + # + # TODO(jburnim): Model happens-before more precisely for the case where + # semaphores are over-signaled. + self.clocks = [None] * shared_memory.num_devices + + def signal(self, inc, device_id, clock): + """Signal the semaphore on `device_id` by `inc`. + + Args: + inc: A positive integer. The amount by which to increment the semaphore + on the target device. + device_id: The ID of the target device. + clock: The vector clock of the signaling device at the time of the signal. + """ + device_id = int(device_id) with self.cv: self.counts[device_id] += inc + if self.interpret_params.detect_races: + if self.clocks[device_id] is None: + self.clocks[device_id] = copy_vector_clock(clock) + else: + update_vector_clock(self.clocks[device_id], clock) self.cv.notify_all() - def wait(self, value, device_id, *, is_dma=False, interpret_params=None): + def read(self, device_id): + with self.cv: + return self.counts[device_id] + + def wait(self, value, device_id, *, is_dma=False): + device_id = int(device_id) + shared_memory = _get_shared_memory() + + # TODO(jburnim): + # - If the count is larger than value, raise an error? + # - If the count is equal to value, but there DMAs waiting to signal us, + # raise an error? + # Simple implementation for non-DMA semaphores. - if not is_dma or (interpret_params.dma_execution_mode == "eager"): + if not is_dma or (self.interpret_params.dma_execution_mode == "eager"): with self.cv: while self.counts[device_id] < value: self.cv.wait() self.counts[device_id] -= value + if self.interpret_params.detect_races: + clock = copy_vector_clock(self.clocks[device_id]) + if self.interpret_params.detect_races: + with shared_memory.lock: + update_vector_clock(shared_memory.clocks[device_id], clock) return # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count @@ -106,10 +208,18 @@ class Semaphore: # up separate threads to handle executing DMAs. shared_memory = _get_shared_memory() while True: + clock = None with self.cv: if self.counts[device_id] >= value: self.counts[device_id] -= value - return + if self.interpret_params.detect_races: + clock = copy_vector_clock(self.clocks[device_id]) + else: + return + if clock is not None: + with shared_memory.lock: + update_vector_clock(shared_memory.clocks[device_id], clock) + return with shared_memory.lock: dma_queue = shared_memory.dmas_by_sem[self.id] @@ -121,14 +231,27 @@ class Semaphore: # Only execute the DMA as far as necessary to signal us. assert (dma.src_sem is self) or (dma.dst_sem is self) with dma.lock: + if dma.virtual_device_id is None: + dma.virtual_device_id = np.random.randint( + shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + if dma.state == DmaState.STARTED: # Do the read. - dma.data = get(dma.src_device_id, dma.src_memory_space, - dma.src_buffer_id, dma.src_transforms) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + dma.data = get(dma.src_device_id, + dma.src_memory_space, + dma.src_buffer_id, + dma.src_transforms, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: data_size = dma.data.itemsize * dma.data.size dma.src_sem.signal( - data_size, device_id=dma.src_device_id) + data_size, device_id=dma.src_device_id, clock=dma.clock) dma.state = DmaState.READ if dma.src_sem is self: @@ -138,11 +261,22 @@ class Semaphore: assert dma.state == DmaState.READ # Do the write. - store(dma.dst_device_id, dma.dst_memory_space, dma.dst_buffer_id, - dma.dst_transforms, dma.data) assert dma.dst_sem is self + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + store(dma.dst_device_id, + dma.dst_memory_space, + dma.dst_buffer_id, + dma.dst_transforms, + dma.data, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + if self.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) data_size = dma.data.itemsize * dma.data.size - dma.dst_sem.signal(data_size, device_id=dma.dst_device_id) + dma.dst_sem.signal( + data_size, device_id=dma.dst_device_id, clock=dma.clock) dma.data = None dma.state = DmaState.COMPLETED @@ -168,17 +302,146 @@ class DMA: src_sem: Semaphore dst_sem: Semaphore + clock: VectorClock + source_info: source_info_util.SourceInfo | None = None state: DmaState = DmaState.STARTED data: np.ndarray | None = None + virtual_device_id: int | None = None lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) @dataclasses.dataclass -class SharedMemory: +class RaceDetectionState: num_devices: int + # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + reads: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list)) + + # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + writes: dict = dataclasses.field( + default_factory=lambda: collections.defaultdict(list)) + + lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + races_found: bool = False + +def _is_empty_slice(slice_or_idx: slice | int): + if isinstance(slice_or_idx, int) or (slice_or_idx == slice(None)): + return False + + # NOTE: All slices here will have known size. + start = int(slice_or_idx.start) if slice_or_idx.start is not None else 0 + stop = int(slice_or_idx.stop) + return (start < stop) + +def slices_overlap(slice_or_idx1: slice | int, slice_or_idx2: slice | int): + if isinstance(slice_or_idx1, int): + slice_or_idx1 = slice(slice_or_idx1, slice_or_idx1 + 1) + if isinstance(slice_or_idx2, int): + slice_or_idx2 = slice(slice_or_idx2, slice_or_idx2 + 1) + + if slice_or_idx1 == slice(None): + return _is_empty_slice(slice_or_idx2) + if slice_or_idx2 == slice(None): + return _is_empty_slice(slice_or_idx1) + + # TODO(jburnim): Handle non-zero steps. + assert (slice_or_idx1.step == 1) or (slice_or_idx1.step is None) + assert (slice_or_idx2.step == 1) or (slice_or_idx2.step is None) + + # NOTE: We are only comparing slices with known stops (and sizes). + # Do we need to handle zero-length slices? + return ((slice_or_idx1.start <= slice_or_idx2.start < slice_or_idx1.stop) + | (slice_or_idx2.start <= slice_or_idx1.start < slice_or_idx2.stop)) + +def ranges_overlap(range1: tuple[slice | int, ...], + range2: tuple[slice | int, ...]) -> bool: + return all(slices_overlap(r1, r2) for r1, r2 + in itertools.zip_longest(range1, range2, fillvalue=slice(None))) + +def check_read(device_id, clock, buffer_key, rnge, source_info=None): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with races.lock: + writes = races.writes[buffer_key] + num_writes = len(writes) + races.reads[buffer_key].append((device_id, clock, rnge, user_frame)) + + for i in range(num_writes): + write_device_id, write_clock, write_range, write_frame = writes[i] + if ordered(write_clock, clock): + continue + if not ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + with races.lock: + races.races_found = True + return + +def check_write(device_id, clock, buffer_key, rnge, source_info=None): + if source_info is not None: + user_frame = source_info_util.summarize(source_info) + else: + user_frame = 'pallas_call' + + with races.lock: + writes = races.writes[buffer_key] + reads = races.reads[buffer_key] + num_writes = len(writes) + num_reads = len(reads) + races.writes[buffer_key].append((device_id, clock, rnge, user_frame)) + + # TODO(jburnim): For performance, we should also probably remove any + # conflicting reads and writes that happened-before the current write. + + for i in range(num_writes): + write_device_id, write_clock, write_range, write_frame = writes[i] + if ordered(write_clock, clock): + continue + if not ranges_overlap(rnge, write_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + with races.lock: + races.races_found = True + break + + for i in range(num_reads): + read_device_id, read_clock, read_range, read_frame = reads[i] + if ordered(read_clock, clock): + continue + if not ranges_overlap(rnge, read_range): + continue + # TODO(jburnim): When printing device IDs for reads/writes, distinguish + # between real device IDs vs. DMA IDs. + print('RACE DETECTED\n' + f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' + f' read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}') + with races.lock: + races.races_found = True + return + + +@dataclasses.dataclass +class SharedMemory: + interpret_params: TPUInterpretParams + num_devices: int + clocks: list[VectorClock] + barrier: threading.Barrier + # (memory_space, buffer_id, device_id) -> NumPy array # TODO(jburnim): Handle Megacore. mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field( @@ -208,6 +471,7 @@ class SharedMemory: # Maybe for running multiple distinct interpreted computations in parallel? _shared_memory : SharedMemory | None = None _shared_memory_init_lock = threading.Lock() +races : RaceDetectionState | None = None def _get_shared_memory() -> SharedMemory: assert _shared_memory is not None @@ -218,15 +482,29 @@ def _clear_shared_memory(): with _shared_memory_init_lock: _shared_memory = None -def _initialize_shared_memory(device_id, num_devices): +def _initialize_shared_memory(device_id, num_devices, *, interpret_params): global _shared_memory del device_id num_devices = int(num_devices) with _shared_memory_init_lock: if _shared_memory is None: - _shared_memory = SharedMemory(num_devices=num_devices) + _shared_memory = SharedMemory( + interpret_params=interpret_params, + num_devices=num_devices, + clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], + barrier=threading.Barrier(num_devices)) assert _shared_memory.num_devices == num_devices + global races + races = RaceDetectionState(num_devices=num_devices) + +def _clean_up_shared_memory(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + shared_memory.barrier.wait() + if device_id == 0: + _clear_shared_memory() + def _validate(device_id): device_id = int(device_id) @@ -235,7 +513,9 @@ def _validate(device_id): for sem in shared_memory.sem.values(): with sem.cv: if sem.counts[device_id] != 0: - raise ValueError( + # TODO(jburnim): Make this raise an error, but in a way that doesn't + # cause other devices to hang later in `_clean_up_shared_memory`. + print( f'Semaphore {sem.id} has non-zero count for {device_id} at ' f'kernel exit: {sem.counts[device_id]}') @@ -248,6 +528,8 @@ def _allocate_buffer(device_id, memory_space, val): with shared_memory.lock: buffer_id = shared_memory.next_buffer_id[device_id] shared_memory.next_buffer_id[device_id] = buffer_id + 1 + # TODO(jburnim): Add options for initializing memory (e.g., with NaNs, + # with zeros, or with the buffer ID). shared_memory.mem[(memory_space, buffer_id, device_id)] = val # TODO(jburnim): Raise an error if buffer_id is too big for int16. @@ -273,7 +555,7 @@ def _allocate_semaphores(device_id, shape): semaphore_id = shared_memory.next_semaphore_id[device_id] shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores for i in range(semaphore_id, semaphore_id + num_semaphores): - if not i in shared_memory.sem: + if i not in shared_memory.sem: shared_memory.sem[i] = Semaphore(i) # NOTE: For now, we use a relatively uncommon datatype (int16) for @@ -305,7 +587,7 @@ def get_barrier_semaphore(device_id, collective_id): shared_memory = _get_shared_memory() with shared_memory.lock: semaphore_id = collective_id - if not semaphore_id in shared_memory.sem: + if semaphore_id not in shared_memory.sem: shared_memory.sem[semaphore_id] = Semaphore() return np.int16(semaphore_id) @@ -314,10 +596,9 @@ def _transform_slice_or_index(slice_or_idx): if isinstance(slice_or_idx, int): return slice_or_idx else: - start, size, stride = ( - int(slice_or_idx.start), - int(slice_or_idx.size), - int(slice_or_idx.stride)) + start = int(slice_or_idx.start) + size = int(slice_or_idx.size) + stride = int(slice_or_idx.stride) return slice(start, start + size * stride, stride) def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): @@ -355,7 +636,8 @@ def _to_range(transforms) -> tuple[slice | int, ...]: ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) return ret -def get(device_id, memory_space, buffer_id, transforms): +def get(device_id, memory_space, buffer_id, transforms, *, + src_device_id=None, clock=None, source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -367,6 +649,10 @@ def get(device_id, memory_space, buffer_id, transforms): shared_memory = _get_shared_memory() with shared_memory.lock: read_range = _to_range(transforms) + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + if clock is None: + clock = copy_vector_clock(shared_memory.clocks[device_id]) buffer = shared_memory.mem[(memory_space, buffer_id, device_id)] ret = buffer[read_range].copy() if transforms: @@ -377,9 +663,17 @@ def get(device_id, memory_space, buffer_id, transforms): raise ValueError( f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): ' f'reading [{read_range}] but bufer has shape {buffer.shape} .') - return ret -def store(device_id, memory_space, buffer_id, transforms, val): + if shared_memory.interpret_params.detect_races: + if src_device_id is None: + src_device_id = device_id + check_read(src_device_id, clock, (memory_space, buffer_id, device_id), + read_range, source_info=source_info) + + return ret + +def store(device_id, memory_space, buffer_id, transforms, val, *, + src_device_id=None, clock=None, source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -391,6 +685,11 @@ def store(device_id, memory_space, buffer_id, transforms, val): shared_memory = _get_shared_memory() with shared_memory.lock: + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + if clock is None: + clock = copy_vector_clock(shared_memory.clocks[device_id]) + buff = shared_memory.mem[(memory_space, buffer_id, device_id)] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. write_range = _to_range(transforms) @@ -402,7 +701,14 @@ def store(device_id, memory_space, buffer_id, transforms, val): f'writing [{write_range}] but buffer has shape {buff.shape} .') buff[write_range] = val -def swap(device_id, memory_space, buffer_id, transforms, val, mask): + if shared_memory.interpret_params.detect_races: + if src_device_id is None: + src_device_id = device_id + check_write(src_device_id, clock, (memory_space, buffer_id, device_id), + write_range, source_info=source_info) + +def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, + source_info=None): device_id = int(device_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) @@ -417,6 +723,9 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask): shared_memory = _get_shared_memory() with shared_memory.lock: + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) buff = shared_memory.mem[(memory_space, buffer_id, device_id)] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. read_write_range = _to_range(transforms) @@ -447,31 +756,64 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask): mask[in_bounds_idx], raw_result, val[in_bounds_idx]) buff[read_write_range] = np.where( mask[in_bounds_idx], val[in_bounds_idx], raw_result) - return result + + if shared_memory.interpret_params.detect_races: + check_write(device_id, clock, (memory_space, buffer_id, device_id), + read_write_range, source_info=source_info) + return result def execute_dma(dma): + # TODO(jburnim) Eliminate duplicate code here and in Semaphore.wait. + shared_memory = _get_shared_memory() with dma.lock: assert dma.state == DmaState.STARTED - # Do the read. - dma.data = get(dma.src_device_id, dma.src_memory_space, - dma.src_buffer_id, dma.src_transforms) - data_size = dma.data.itemsize * dma.data.size + if dma.virtual_device_id is None: + # See comment in Semaphore.wait . + dma.virtual_device_id = np.random.randint( + shared_memory.num_devices, NUM_VIRTUAL_DEVICES) - # Signal the send semaphore. - if dma.src_sem is not None: - dma.src_sem.signal(data_size, device_id=dma.src_device_id) + # Do the read. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + dma.data = get(dma.src_device_id, + dma.src_memory_space, + dma.src_buffer_id, + dma.src_transforms, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) + data_size = dma.data.itemsize * dma.data.size - # Do the write. - store(dma.dst_device_id, dma.dst_memory_space, dma.dst_buffer_id, - dma.dst_transforms, dma.data) + # Signal the send semaphore. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + if dma.src_sem is not None: + dma.src_sem.signal( + data_size, device_id=dma.src_device_id, clock=dma.clock) + dma.state = DmaState.READ - # Signal the receive semaphore. - if dma.dst_sem is not None: - dma.dst_sem.signal(data_size, device_id=dma.dst_device_id) + # Do the write. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + store(dma.dst_device_id, + dma.dst_memory_space, + dma.dst_buffer_id, + dma.dst_transforms, + dma.data, + clock=copy_vector_clock(dma.clock), + src_device_id=dma.id, + source_info=dma.source_info) - dma.data = None - dma.state = DmaState.COMPLETED + # Signal the receive semaphore. + if shared_memory.interpret_params.detect_races: + inc_vector_clock(dma.clock, dma.virtual_device_id) + if dma.dst_sem is not None: + dma.dst_sem.signal( + data_size, device_id=dma.dst_device_id, clock=dma.clock) + + dma.data = None + dma.state = DmaState.COMPLETED def print_memory(device_id): device_id = int(device_id) @@ -483,7 +825,7 @@ def print_memory(device_id): def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_memory_space, dst_id, dst_transforms, dst_sem_id, src_sem_id, dst_device_id, - *, interpret_params, source_info=None): + source_info=None): device_id = int(device_id) src_memory_space, src_id = int(src_memory_space), int(src_id) src_transforms = jax.tree.map(int, src_transforms) @@ -501,6 +843,10 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_sem = shared_memory.sem[dst_sem_id] src_sem = shared_memory.sem[src_sem_id] if src_sem_id is not None else None + clock = None + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) dma_id = shared_memory.next_dma_id shared_memory.next_dma_id += 1 @@ -510,31 +856,35 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, dst_device_id, dst_memory_space, dst_id, dst_transforms, src_sem, dst_sem, + clock=clock, source_info=source_info, ) - if interpret_params.dma_execution_mode == 'on_wait': + if shared_memory.interpret_params.dma_execution_mode == 'on_wait': shared_memory.dmas_by_sem[dst_sem_id].append(dma) if src_sem_id is not None: shared_memory.dmas_by_sem[src_sem_id].append(dma) return - assert interpret_params.dma_execution_mode == 'eager' + assert shared_memory.interpret_params.dma_execution_mode == 'eager' execute_dma(dma) -def dma_wait(device_id, sem, size, *, interpret_params): +def dma_wait(device_id, sem_id, size): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) size = int(size) shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] - sem.wait(size, device_id, is_dma=True, interpret_params=interpret_params) + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + sem = shared_memory.sem[sem_id] + sem.wait(size, device_id, is_dma=True) -def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index): +def semaphore_signal(device_id, sem_id, inc, target_device_id, + target_core_index): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) inc = int(inc) if target_device_id is None: target_device_id = device_id @@ -542,21 +892,28 @@ def semaphore_signal(device_id, sem, inc, target_device_id, target_core_index): target_device_id = int(target_device_id) if target_core_index is not None: - raise NotImplementedError('semaphore_signal with target_core_index') + if int(target_core_index) != 0: + raise NotImplementedError('semaphore_signal with target_core_index != 0') shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] - sem.signal(inc, target_device_id) + clock = None + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + clock = copy_vector_clock(shared_memory.clocks[device_id]) + sem = shared_memory.sem[sem_id] + sem.signal(inc, target_device_id, clock) -def semaphore_wait(device_id, sem, value): +def semaphore_wait(device_id, sem_id, value): device_id = int(device_id) - sem = int(sem) + sem_id = int(sem_id) value = int(value) shared_memory = _get_shared_memory() with shared_memory.lock: - sem = shared_memory.sem[sem] + if shared_memory.interpret_params.detect_races: + inc_vector_clock(shared_memory.clocks[device_id], device_id) + sem = shared_memory.sem[sem_id] sem.wait(value, device_id) def _compute_transformed_shape_and_dtype(shape, dtype, transforms): @@ -627,236 +984,238 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): _interpret_jaxpr, compiler_params=compiler_params, interpret_params=interpret_params) for eqn in jaxpr.eqns: - prim = eqn.primitive - invals = jax.util.safe_map(read, eqn.invars) + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): + prim = eqn.primitive + invals = jax.util.safe_map(read, eqn.invars) - if prim is primitives.load_p: - (ref, transforms, mask, _) = jax.tree.unflatten( - eqn.params['args_tree'], invals) - if mask is not None: - raise NotImplementedError('masked load_p') - out = callback.io_callback( - get, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - ordered=True) + if prim is primitives.load_p: + (ref, transforms, mask, _) = jax.tree.unflatten( + eqn.params['args_tree'], invals) + if mask is not None: + raise NotImplementedError('masked load_p') + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + ref, + transforms, + ordered=True) - elif prim is primitives.swap_p: - (ref, transforms, val, mask) = jax.tree.unflatten( - eqn.params['args_tree'], invals) - out = callback.io_callback( - swap, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - ref, - transforms, - val, - mask, - ordered=True) + elif prim is primitives.swap_p: + (ref, transforms, val, mask) = jax.tree.unflatten( + eqn.params['args_tree'], invals) + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + ref, + transforms, + val, + mask, + ordered=True) - elif prim is lax.cond_p: - def _make_branch(jaxpr): - return lambda *args: _interpret(jaxpr, *args) - out = lax.switch( - invals[0], - [_make_branch(branch_jaxpr.jaxpr) - for branch_jaxpr in eqn.params['branches']], - *invals[1:]) + elif prim is mosaic_primitives.delay_p: + out = [] - elif prim is lax.scan_p: - consts, init_carry, xs = split_list( - invals, [eqn.params['num_consts'], eqn.params['num_carry']]) - def _scan_body(c, a): - return split_list( - _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), - [eqn.params['num_carry']]) - carry, out = lax.scan(_scan_body, init_carry, xs=xs, - length=eqn.params.get('length', None)) - out = carry + out + elif prim is lax.cond_p: + def _make_branch(jaxpr): + return lambda *args: _interpret(jaxpr, *args) + out = lax.switch( + invals[0], + [_make_branch(branch_jaxpr.jaxpr) + for branch_jaxpr in eqn.params['branches']], + *invals[1:]) - elif prim is lax.while_p: - cond_consts, body_consts, init_vals = split_list( - invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']]) - out = lax.while_loop( - lambda args: _interpret( - eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], - lambda args: _interpret( - eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), - init_vals) + elif prim is lax.scan_p: + consts, init_carry, xs = split_list( + invals, [eqn.params['num_consts'], eqn.params['num_carry']]) + def _scan_body(c, a): + return split_list( + _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), + [eqn.params['num_carry']]) + carry, out = lax.scan(_scan_body, init_carry, xs=xs, + length=eqn.params.get('length', None)) + out = carry + out - elif prim is for_loop.for_p: - raise NotImplementedError('for_p') + elif prim is lax.while_p: + cond_consts, body_consts, init_vals = split_list( + invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']]) + out = lax.while_loop( + lambda args: _interpret( + eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], + lambda args: _interpret( + eqn.params['body_jaxpr'].jaxpr, *body_consts, *args), + init_vals) - elif prim is pjit.pjit_p: - def f(*args, jaxpr): - return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) - in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) - new_jaxpr = _to_jaxpr( - lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), - debug_info=eqn.params['jaxpr'].jaxpr.debug_info), - in_avals) - out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) + elif prim is for_loop.for_p: + raise NotImplementedError('for_p') - elif prim is primitives.run_scoped_p: - # Allocate a buffer or semaphore for each element of - # eqn.params['jaxpr'].invars . - allocs = [] - for v in eqn.params['jaxpr'].invars: - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - allocs.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), - device_id, - v.aval.shape, - ordered=True)) - else: - allocs.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - primitives.uninitialized_value(v.aval.shape, v.aval.dtype), - ordered=True)) + elif prim is pjit.pjit_p: + def f(*args, jaxpr): + return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) + in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) + new_jaxpr = _to_jaxpr( + lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), + debug_info=eqn.params['jaxpr'].jaxpr.debug_info), + in_avals) + out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) - out = _interpret(eqn.params['jaxpr'], *invals, *allocs) + elif prim is primitives.run_scoped_p: + # Allocate a buffer or semaphore for each element of + # eqn.params['jaxpr'].invars . + allocs = [] + for v in eqn.params['jaxpr'].invars: + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + allocs.append(callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), + device_id, + v.aval.shape, + ordered=True)) + else: + allocs.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + primitives.uninitialized_value(v.aval.shape, v.aval.dtype), + ordered=True)) - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: - # TODO(jburnim): Delete semaphores. - # callback.io_callback( - # _deallocate_semaphores, - # None, - # device_id, - # a, - # ordered=True) - pass + out = _interpret(eqn.params['jaxpr'], *invals, *allocs) - elif prim is state_primitives.get_p: - out = callback.io_callback( - get, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[1:]), - ordered=True) + for a in allocs: + if isinstance(a, tuple): + callback.io_callback( + _deallocate_buffer, + None, + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) + else: + # TODO(jburnim): De-allocate semaphores. + # callback.io_callback( + # _deallocate_semaphores, + # None, + # device_id, + # a, + # ordered=True) + pass - elif prim is state_primitives.swap_p: - out = callback.io_callback( - swap, - eqn.outvars[0].aval, - device_id, - TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], - invals[0], - jax.tree.unflatten(eqn.params['tree'], invals[2:]), - invals[1], - None, - ordered=True) + elif prim is state_primitives.get_p: + out = callback.io_callback( + functools.partial(get, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[1:]), + ordered=True) - elif prim is mosaic_primitives.dma_start_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - (orig_src_ref, _, orig_dst_ref, *_ - ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) - callback.io_callback( - functools.partial(dma_start, interpret_params=interpret_params, - source_info=eqn.source_info), - (), - device_id, - TPU_MEMORY_SPACE_IDXS[orig_src_ref.aval.memory_space], - src, src_transforms, - TPU_MEMORY_SPACE_IDXS[orig_dst_ref.aval.memory_space], - dst, dst_transforms, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - state_discharge.transform_array(src_sem, src_sem_transforms), - target_device_id, - ordered=True) - out = [] + elif prim is state_primitives.swap_p: + out = callback.io_callback( + functools.partial(swap, source_info=eqn.source_info), + eqn.outvars[0].aval, + device_id, + TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], + invals[0], + jax.tree.unflatten(eqn.params['tree'], invals[2:]), + invals[1], + None, + ordered=True) - elif prim is mosaic_primitives.dma_wait_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) - read_shape, read_dtype = _compute_transformed_shape_and_dtype( - eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) - callback.io_callback( - functools.partial(dma_wait, interpret_params=interpret_params), - (), - device_id, - state_discharge.transform_array(dst_sem, dst_sem_transforms), - math.prod(read_shape) * read_dtype.itemsize, - ordered=True) - out = [] + elif prim is mosaic_primitives.dma_start_p: + (src, src_transforms, + dst, dst_transforms, + dst_sem, dst_sem_transforms, + src_sem, src_sem_transforms, + target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + target_device_id = _device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes) + (orig_src_ref, _, orig_dst_ref, *_ + ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) + callback.io_callback( + functools.partial(dma_start, source_info=eqn.source_info), + (), + device_id, + TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + src, src_transforms, + TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + dst, dst_transforms, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + state_discharge.transform_array(src_sem, src_sem_transforms), + target_device_id, + ordered=True) + out = [] - elif prim is mosaic_primitives.get_barrier_semaphore_p: - out = callback.io_callback( - get_barrier_semaphore, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - compiler_params['mosaic']['collective_id'], - ordered=True) + elif prim is mosaic_primitives.dma_wait_p: + (src, src_transforms, + dst, dst_transforms, + dst_sem, dst_sem_transforms, + src_sem, src_sem_transforms, + target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + read_shape, read_dtype = _compute_transformed_shape_and_dtype( + eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) + callback.io_callback( + dma_wait, + (), + device_id, + state_discharge.transform_array(dst_sem, dst_sem_transforms), + math.prod(read_shape) * read_dtype.itemsize, + ordered=True) + out = [] - elif prim is mosaic_primitives.semaphore_signal_p: - sem, sem_transforms, inc, target_device_id, core_index = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) - target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) - callback.io_callback( - semaphore_signal, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - inc, - target_device_id, - core_index, - ordered=True) - out = [] + elif prim is mosaic_primitives.get_barrier_semaphore_p: + out = callback.io_callback( + get_barrier_semaphore, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + compiler_params['mosaic']['collective_id'], + ordered=True) - elif prim is mosaic_primitives.semaphore_wait_p: - sem, sem_transforms, value = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) - callback.io_callback( - semaphore_wait, - (), - device_id, - state_discharge.transform_array(sem, sem_transforms), - value, - ordered=True) - out = [] + elif prim is mosaic_primitives.semaphore_signal_p: + sem, sem_transforms, inc, target_device_id, core_index = ( + jax.tree.unflatten(eqn.params['args_tree'], invals)) + target_device_id = _device_id_to_logical( + target_device_id, eqn.params['device_id_type'], axis_sizes) + callback.io_callback( + semaphore_signal, + (), + device_id, + state_discharge.transform_array(sem, sem_transforms), + inc, + target_device_id, + core_index, + ordered=True) + out = [] - elif prim is primitives.atomic_rmw_p: - raise NotImplementedError('atomic_rmw_p') + elif prim is mosaic_primitives.semaphore_wait_p: + sem, sem_transforms, value = ( + jax.tree.unflatten(eqn.params['args_tree'], invals)) + callback.io_callback( + semaphore_wait, + (), + device_id, + state_discharge.transform_array(sem, sem_transforms), + value, + ordered=True) + out = [] - elif prim is primitives.atomic_cas_p: - raise NotImplementedError('atomic_cas_p') + elif prim is primitives.atomic_rmw_p: + raise NotImplementedError('atomic_rmw_p') - else: - # TODO(jburnim): Add special handling for nested pallas_call_p. - # (For example, so that buffers can be shared with nested Pallas calls.) - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - out = prim.bind(*subfuns, *invals, **bind_params) + elif prim is primitives.atomic_cas_p: + raise NotImplementedError('atomic_cas_p') - out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) + else: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + out = prim.bind(*subfuns, *invals, **bind_params) + + out = out if prim.multiple_results else [out] + jax.util.safe_map(write, eqn.outvars, out) return jax.util.safe_map(read, jaxpr.outvars) @@ -961,7 +1320,8 @@ def interpret_pallas_call( tuple(lax.axis_index(s) for s in axis_sizes.keys()), axis_sizes) callback.io_callback( - _initialize_shared_memory, + functools.partial( + _initialize_shared_memory, interpret_params=interpret_params), (), device_id, num_devices, @@ -999,13 +1359,12 @@ def interpret_pallas_call( TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], padded_val, ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, # outputs, scratch). io_alias_map = dict(input_output_aliases) oi_alias_map = {v: k for k, v in input_output_aliases} kernel_buffer_ids = [] - for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): kernel_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1107,6 +1466,8 @@ def interpret_pallas_call( input_args[j], is_indexing_dim[j]) assert(sliced_val.shape == var.aval.shape) callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. store, (), device_id, @@ -1129,6 +1490,8 @@ def interpret_pallas_call( if _is_any(var.aval.memory_space): continue kernel_output_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. get, var.aval, device_id, @@ -1144,6 +1507,8 @@ def interpret_pallas_call( shape=output_vals[j].shape, int_indexer_shape=()) callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. store, (), device_id, @@ -1165,6 +1530,8 @@ def interpret_pallas_call( # Read the output from the allocated output buffers. ret = [ callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. get, val, device_id, @@ -1178,33 +1545,22 @@ def interpret_pallas_call( output_vals, output_buffer_ids, output_buffer_shapes) ] - for buffer_id in output_buffer_ids: - callback.io_callback( - _deallocate_buffer, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - ordered=True) - for buffer_id, var in zip(kernel_buffer_ids, jaxpr.invars): - if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - pass - else: - callback.io_callback( - _deallocate_buffer, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - buffer_id, - ordered=True) + callback.io_callback( + _validate, + (), + device_id, + ordered=True) - # TODO(jburnim): Either validate just the semaphores allocated for this - # pallas_call, or only do validation if we are exiting a top-level - # (i.e., not nested) pallas_call. - # callback.io_callback( - # _validate, - # (), - # device_id, - # ordered=True) + # For now, when we're done with a pallas_call, we delete the shared memory. + # We use a barrier to ensure that all devices are done running the kernel. + # + # TODO(jburnim): Get rid of this barrier. And figure out how this should + # work if we want to invoke successive pallas_calls that use the same + # shared memory. + callback.io_callback( + _clean_up_shared_memory, + (), + device_id, + ordered=True) return ret diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index e5fbbdd4b..518c16ed2 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,6 +18,8 @@ To work around https://github.com/jax-ml/jax/issues/25671 , this file contains only tests that use shard_map. """ +import functools + from absl.testing import absltest from absl.testing import parameterized @@ -39,12 +41,16 @@ P = jax.sharding.PartitionSpec class InterpretDistributedTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if jax.device_count() < 4: + self.skipTest(f'requires at least 4 devices, found {jax.device_count()}') - @parameterized.parameters('eager', 'on_wait') - def test_right_permute_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_right_permute_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -61,36 +67,26 @@ class InterpretDistributedTest(jtu.JaxTestCase): right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices)) barrier_sem = pltpu.get_barrier_semaphore() - def _body(ijk): - i, (j, k) = ijk - lax.cond( - (i == 0) | (j == 0), - lambda: pltpu.semaphore_signal( - barrier_sem, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH), - lambda: pltpu.semaphore_signal( - barrier_sem, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH)) - return (i + 1, (j + 1, k + 1)) - lax.while_loop(lambda ijk: ijk[0] < 2, _body, (0, (0, 0))) + pltpu.semaphore_signal( + barrier_sem, + device_id=(left_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH) + pltpu.semaphore_signal( + barrier_sem, + device_id=(right_neighbor,), + device_id_type=pltpu.DeviceIdType.MESH) pltpu.semaphore_wait(barrier_sem, 2) - def _body2(i, a): - remote_copy_op = pltpu.make_async_remote_copy( + remote_copy_op = pltpu.make_async_remote_copy( src_ref=input_ref, dst_ref=output_ref, send_sem=send_sem, recv_sem=recv_sem, device_id=(right_neighbor,), device_id_type=pltpu.DeviceIdType.MESH, - ) - remote_copy_op.start() - remote_copy_op.wait() - - return i + 1, a + 1 - _ = lax.scan(_body2, 0, jnp.arange(4.0), unroll=2) + ) + remote_copy_op.start() + remote_copy_op.wait() out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( @@ -111,7 +107,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): grid_spec=grid_spec, compiler_params=pltpu.TPUCompilerParams(collective_id=13), interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( @@ -133,12 +129,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_all_gather_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_all_gather_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P('x', None) mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -230,7 +228,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -254,12 +252,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_all_reduce_sum_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_all_reduce_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -388,7 +388,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) @@ -413,12 +413,14 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') - def test_reduce_scatter_sum_example(self, dma_execution_mode): + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) + def test_reduce_scatter_sum_example(self, dma_execution_mode, detect_races): num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -670,7 +672,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=True), compiler_params=pltpu.TPUCompilerParams(collective_id=7), )(input_arr)[0] @@ -700,17 +702,19 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) - @parameterized.parameters('eager', 'on_wait') + @parameterized.product( + dma_execution_mode=['eager', 'on_wait'], + detect_races=[True, False]) def test_reduce_scatter_sum_with_emit_pipeline_example( - self, dma_execution_mode): + self, dma_execution_mode, detect_races): self.skipTest('requires a patched pallas.emit_pipeline to specify/fake ' 'the TPU generation') if jax.config.jax_enable_x64: self.skipTest('pallas.emit_pipeline + x64 is not currently supported') num_devices = jax.device_count() - if num_devices < 4: - self.skipTest(f'requires at least 4 devices, found {num_devices}') partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, partition) @@ -972,7 +976,7 @@ class InterpretDistributedTest(jtu.JaxTestCase): out_shape=out_shape, grid_spec=grid_spec, interpret=mosaic_interpret.TPUInterpretParams( - dma_execution_mode=dma_execution_mode), + dma_execution_mode=dma_execution_mode, detect_races=detect_races), compiler_params=pltpu.TPUCompilerParams(collective_id=19), )(input_arr)[0] @@ -1001,6 +1005,95 @@ class InterpretDistributedTest(jtu.JaxTestCase): )(input_arr) np.testing.assert_allclose(xla_result, pallas_result, atol=1e-5) + if detect_races: + self.assertFalse(mosaic_interpret.races.races_found) + + def test_race_detection(self): + num_devices = 4 + mesh = jax.sharding.Mesh(np.array(jax.devices()[:4]), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, P('x', None)) + + input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128)) + input_arr = jax.device_put(input_arr, sharding) + + def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): + # Barrier with all devices before doing any DMAs. + barrier_sem = pltpu.get_barrier_semaphore() + @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) + def _(i, _): + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(jnp.int32(i),), + device_id_type=pltpu.DeviceIdType.MESH, + ) + return None + pltpu.semaphore_wait(barrier_sem, num_devices) + + # Send the specified DMAs. + my_id = lax.axis_index('x') + src_dst_ids = src_dst_ids_ref[:] + recv_count = 0 + for i in range(src_dst_ids.shape[0]): + src_id = src_dst_ids[i, 0] + dst_id = src_dst_ids[i, 1] + @pl.when(src_id == my_id) + def _(): + dma = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=o_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(dst_id,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + dma.start() + dma.wait_send() + recv_count += jnp.where(dst_id == my_id, 1, 0) + + # Wait until we have received all DMAs. + @pl.when(recv_count > 0) + def _(): + fake_dma = pltpu.make_async_remote_copy( + src_ref=x_ref.at[pl.ds(0, 8 * recv_count)], + dst_ref=o_ref.at[pl.ds(0, 8 * recv_count)], + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(my_id,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + fake_dma.wait_recv() + + @jax.jit + def run(src_dst_ids): + return shard_map.shard_map( + pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + interpret=mosaic_interpret.TPUInterpretParams( + dma_execution_mode='eager', + detect_races=True, + ), + ), + mesh=mesh, + in_specs=(P(None), P('x', None)), + out_specs=P('x', None), + check_rep=False, + )(src_dst_ids, input_arr) + + run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + + # Racing writes to device 2. + run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) if __name__ == "__main__": diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 21bfb57d6..632dba949 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -19,11 +19,13 @@ contains only tests that do not use shard_map. """ from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -33,13 +35,14 @@ jax.config.parse_flags_with_absl() class InterpretTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + self.num_devices = jax.device_count() + if self.num_devices > 1: + # Workaround for https://github.com/jax-ml/jax/issues/25671 + self.skipTest(f'requires 1 device, found {self.num_devices}') def test_matmul_example(self): - num_devices = jax.device_count() - if num_devices > 1: - # Workaround for https://github.com/jax-ml/jax/issues/25671 - self.skipTest(f'requires 1 device, found {num_devices}') - def matmul_kernel(x_ref, y_ref, z_ref): z_ref[...] = x_ref[...] @ y_ref[...] @@ -66,11 +69,6 @@ class InterpretTest(jtu.JaxTestCase): np.testing.assert_allclose(z, x @ y, atol=1e-4) def test_dynamic_grid(self): - num_devices = jax.device_count() - if num_devices > 1: - # Workaround for https://github.com/jax-ml/jax/issues/25671 - self.skipTest(f'requires 1 device, found {num_devices}') - def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] @@ -90,5 +88,47 @@ class InterpretTest(jtu.JaxTestCase): y = f(x) np.testing.assert_allclose(y, x) + @parameterized.parameters('eager', 'on_wait') + def test_race_detection(self, dma_execution_mode): + def kernel_without_race(x_ref, o_ref, t_ref, sem): + copy = pltpu.make_async_copy(x_ref, t_ref, sem) + copy.start() + copy.wait() + o_ref[...] = t_ref[...] + 1.0 + + def kernel_with_race(x_ref, o_ref, t_ref, sem): + copy = pltpu.make_async_copy(x_ref, t_ref, sem) + copy.start() + # This read of t_ref races with the above DMA's write of t_ref. + o_ref[...] = t_ref[...] + 1.0 + copy.wait() + + x = jnp.zeros((8, 128), jnp.float32) + y = pl.pallas_call(kernel_without_race, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ], + interpret=mosaic_interpret.TPUInterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode), + )(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, x + 1.0) + + pl.pallas_call(kernel_with_race, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + pltpu.SemaphoreType.DMA, + ], + interpret=mosaic_interpret.TPUInterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode), + )(x).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From c16f37d89daa83dce20cd04a7f82a4c62ca0639d Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 18:08:35 -0800 Subject: [PATCH 036/100] Set `USERPROFILE` for Windows builds to fix CI issue. This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632. From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser): `On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.` PiperOrigin-RevId: 733935305 --- jaxlib/tools/build_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 9c7f61fc2..4c50cff16 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -70,6 +70,12 @@ def build_wheel( env = dict(os.environ) if git_hash: env["JAX_GIT_HASH"] = git_hash + if is_windows() and ( + "USERPROFILE" not in env + and "HOMEDRIVE" not in env + and "HOMEPATH" not in env + ): + env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] + (["-w"] if build_wheel_only else []), From ba5349f8961ed4b878357937eecb7f73c441b2ae Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 5 Mar 2025 19:34:25 -0800 Subject: [PATCH 037/100] Add a note about uneven sharding and with_sharding_constraint. Fixes https://github.com/jax-ml/jax/issues/26946 PiperOrigin-RevId: 733953836 --- jax/_src/pjit.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index eb443f572..dad5c949a 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2536,6 +2536,11 @@ def with_sharding_constraint(x, shardings): This is a strict constraint for the GSPMD partitioner and not a hint. For examples of how to use this function, see `Distributed arrays and automatic parallelization`_. + Inside of a jitted computation, with_sharding_constraint makes it possible to + constrain intermediate values to an uneven sharding. However, if such an + unevenly sharded value is output by the jitted computation, it will come out + as fully replicated, no matter the sharding annotation given. + Args: x: PyTree of jax.Arrays which will have their shardings constrained shardings: PyTree of sharding specifications. Valid values are the same as for From a67ab9fade345cdb60bdece5bcf6b97793938966 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 5 Mar 2025 20:08:54 -0800 Subject: [PATCH 038/100] Just use `jit` as the string in error messages instead of `jit` and `pjit` based on resource_env. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`. PiperOrigin-RevId: 733959962 --- jax/_src/pjit.py | 18 +++++++----------- tests/name_stack_test.py | 6 +++--- tests/pjit_test.py | 22 +++++++++++----------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index dad5c949a..86df66301 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -199,10 +199,9 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if p.params['resource_env'] is None else 'pjit' fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, p.arg_names) + fun_name, fails, args_flat, 'jit', p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names @@ -591,13 +590,12 @@ def _infer_params_impl( in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: - jit_name = 'pjit' if pjit_mesh is not None else 'jit' in_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) + _create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit') for x in ji.in_shardings_leaves) in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name) + _create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit') for x in ji.out_shardings_leaves) out_shardings_treedef = ji.out_shardings_treedef @@ -1760,12 +1758,10 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") - if resource_env is not None: - mesh, api_name = resource_env.physical_mesh, 'pjit' - else: - mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' + mesh = (resource_env.physical_mesh if resource_env is not None else + mesh_lib.get_concrete_mesh()) return pxla.lower_sharding_computation( - jaxpr, api_name, name, in_shardings, out_shardings, + jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, context_mesh=mesh, compiler_options_kvs=compiler_options_kvs, @@ -1929,7 +1925,7 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, tuple(effects), in_shardings, out_shardings, in_layouts, out_layouts, - api_name=('jit' if resource_env is None else 'pjit')) + api_name='jit') tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 270707934..f371c431e 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -263,9 +263,9 @@ class NameStackTransformationTest(jtu.JaxTestCase): return g(x) hlo_text = _get_hlo(f)(2.) - self.assertIn('jvp(pjit(f))/pjit(g)/sin', hlo_text) - self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text) - self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/sin', hlo_text) + self.assertIn('jvp(jit(f))/jit(g)/cos', hlo_text) + self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) def test_remat_appears_in_hlo(self): @ad_checkpoint.remat diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7a7f6b7d6..4c20ac649 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2076,7 +2076,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with global_mesh: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x: x)(input_array) def test_array_lower_compile(self): @@ -2177,7 +2177,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y), out_shardings=(NamedSharding(m1, spec), NamedSharding(m2, spec)))(a1, a1) @@ -2192,7 +2192,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with m1: with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation"): + ValueError, "Received incompatible devices for jitted computation"): pjit( lambda x, y: (x, y), in_shardings=NamedSharding(m2, spec), @@ -2348,7 +2348,7 @@ class ArrayPjitTest(jtu.JaxTestCase): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( RuntimeError, - r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or' + r'jit requires a non-empty mesh if you are passing `PartitionSpec`s or' r' `None` to in_shardings.*'): pjit(lambda x: x, in_shardings=P('x'))(arr) @@ -2396,7 +2396,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with jtu.create_mesh((2, 2), ('x', 'y')): with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation"): + "Received incompatible devices for jitted computation"): pjit(lambda x, y: (x, y))(uarr, carr) def test_pjit_uncommitted_array_multi_devices(self): @@ -2418,7 +2418,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"x of.*\ with shape int.*\[3\] and device ids \[0\].*and " r"argument y of.*\ with shape int.*\[3\] and device ids \[1\].*"): pjit(lambda x, y: (x, y))(a, b) @@ -2430,7 +2430,7 @@ class ArrayPjitTest(jtu.JaxTestCase): b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"x\[0\] of.*\ with shape int.*\[3\] and device ids \[0\].*and " r"argument x\[1\] of.*\ with shape int.*\[3\] and device ids " r"\[1\].*"): @@ -2443,7 +2443,7 @@ class ArrayPjitTest(jtu.JaxTestCase): c = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) - msg = ("Received incompatible devices for pjitted computation. Got " + msg = ("Received incompatible devices for jitted computation. Got " r"argument {} of.* with shape int.*\[3\] and device ids " r"\[0\].*and argument {} of.* with shape int.*\[8,2\] and " r"device ids.*") @@ -2617,9 +2617,9 @@ class ArrayPjitTest(jtu.JaxTestCase): return f(inp1, inp2, inp3) with self.assertRaisesRegex( ValueError, - "Received incompatible devices for pjitted computation. Got argument " + "Received incompatible devices for jitted computation. Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside pjit with device ids.*"): + r"pjit inside jit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, @@ -7236,7 +7236,7 @@ class PJitErrorTest(jtu.JaxTestCase): xshape = (2, 5, 6) x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex( - ValueError, "Received incompatible devices for pjitted computation.*"): + ValueError, "Received incompatible devices for jitted computation.*"): f(x) @parameterized.named_parameters( From fe577b5dc4c775178c017a0d5055ae53a57db81e Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 6 Mar 2025 00:44:23 -0800 Subject: [PATCH 039/100] [Pallas/Mosaic GPU] Enable `ops_test` for Mosaic GPU. For now, most of the tests are skipped. PiperOrigin-RevId: 734026728 --- tests/pallas/BUILD | 37 ++++++++++++++ tests/pallas/ops_test.py | 106 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 35363de8c..9a42d64d3 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -127,6 +127,43 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "ops_test_mgpu", + srcs = [ + "ops_test.py", + ], + disable_configs = [ + "gpu_v100", + "gpu_v100_x32", + "gpu_p100", + "gpu_p100_x32", + "gpu_a100", + "gpu_a100_x32", + ], + enable_backends = [ + "gpu", + ], + enable_configs = [ + "gpu_h100", + "gpu_h100_x32", + ], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "JAX_PALLAS_VERBOSE_ERRORS": "0", + }, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_mosaic_gpu", # build_cleaner: keep + "//jax:pallas_tpu", + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), +) + jax_multiplatform_test( name = "indexing_test", srcs = [ diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index daa384cc5..d79a172ed 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -58,6 +58,7 @@ import hypothesis.strategies as hps jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) +use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) @@ -280,6 +281,10 @@ class PallasBaseTest(jtu.JaxTestCase): def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + def skip_if_mosaic_gpu(self): + if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + self.skipTest("TODO: Mosaic GPU does not support this yet") + class OpsTest(PallasBaseTest): @@ -295,6 +300,8 @@ class OpsTest(PallasBaseTest): ] ) def test_weak_dtype(self, fn, dtype): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), ) @@ -332,6 +339,7 @@ class OpsTest(PallasBaseTest): We don't really expect that the results would be wrong, but rather we want to exercise the lowering rules. """ + self.skip_if_mosaic_gpu() def kernel(x_ref, y_ref, o_ref): x = x_ref[0, 0] @@ -393,6 +401,7 @@ class OpsTest(PallasBaseTest): We don't really expect that the results would be wrong, but rather we want to exercise the lowering rules. """ + self.skip_if_mosaic_gpu() def kernel(x_ref, y_ref, o_ref): x = x_ref[:] @@ -532,6 +541,8 @@ class OpsTest(PallasBaseTest): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): + self.skip_if_mosaic_gpu() + if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") # We want exact equality here to match how JAX lowers to XLA @@ -555,6 +566,8 @@ class OpsTest(PallasBaseTest): @parameterized.product(from_dtype=_DTYPES_32BIT, to_dtype=_DTYPES) @hp.given(hps.data()) def test_cast_from_32bit(self, from_dtype, to_dtype, data): + self.skip_if_mosaic_gpu() + if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): @@ -599,6 +612,8 @@ class OpsTest(PallasBaseTest): # miss bugs that would be hidden due to exhaustive enumeration being in order. @parameterized.product(from_dtype=_DTYPES_SUB_32BIT, to_dtype=_DTYPES, randomize=(False, True)) def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): + self.skip_if_mosaic_gpu() + if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): @@ -717,6 +732,7 @@ class OpsTest(PallasBaseTest): dtype=(jnp.int32, jnp.int16, jnp.int8), ) def test_scalar_map(self, shape, dtype): + self.skip_if_mosaic_gpu() if pltpu is None: self.skipTest("No TPU module available.") if dtype != jnp.int32 and len(shape) < 2: @@ -754,6 +770,7 @@ class OpsTest(PallasBaseTest): self.assertAllClose(f(x).item(), 10.0) def test_concat_constant(self): + self.skip_if_mosaic_gpu() if pltpu is None: self.skipTest("No TPU module available.") axis = 0 @@ -794,6 +811,8 @@ class OpsTest(PallasBaseTest): for value in values ) def test_sign(self, dtype, value): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -825,6 +844,7 @@ class OpsTest(PallasBaseTest): jnp.int32, ) def test_add_constant(self, dtype): + self.skip_if_mosaic_gpu() shape = (256, 256) @@ -844,6 +864,8 @@ class OpsTest(PallasBaseTest): -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) def test_erf_inv(self, value): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), floatx), @@ -935,6 +957,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -997,6 +1021,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_elementwise_scalar(self, fn, dtype): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1044,6 +1070,8 @@ class OpsTest(PallasBaseTest): self.assertAllClose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): + self.skip_if_mosaic_gpu() + # see https://github.com/jax-ml/jax/issues/23191 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx), @@ -1061,6 +1089,8 @@ class OpsTest(PallasBaseTest): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1079,6 +1109,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) def test_integer_pow(self, y): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), ) @@ -1097,6 +1129,8 @@ class OpsTest(PallasBaseTest): ) ) def test_nextafter(self, dtype, x, y): + self.skip_if_mosaic_gpu() + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") @@ -1132,6 +1166,8 @@ class OpsTest(PallasBaseTest): ) ) def test_comparison(self, fn, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") @@ -1159,6 +1195,8 @@ class OpsTest(PallasBaseTest): ) ) def test_comparison_scalar(self, fn, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @@ -1188,6 +1226,8 @@ class OpsTest(PallasBaseTest): self.assertArraysEqual(out, expected) def test_isnan(self): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), ) @@ -1220,6 +1260,8 @@ class OpsTest(PallasBaseTest): ("bfloat16", "bfloat16"), ) def test_true_divide(self, dtype, out_dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6): self.skipTest("bfloat16 is not supported on older TPU generations") @@ -1249,6 +1291,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype): + self.skip_if_mosaic_gpu() + if self.INTERPRET: self.skipTest("No lowering in interpret mode") @@ -1286,6 +1330,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_binary(self, f, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") @@ -1309,6 +1355,8 @@ class OpsTest(PallasBaseTest): for fn, dtype in itertools.product(*args) ) def test_binary_scalar(self, f, dtype): + self.skip_if_mosaic_gpu() + if not jtu.test_device_matches(["tpu"]): self.skipTest("Test only supported on TPU.") if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: @@ -1336,6 +1384,8 @@ class OpsTest(PallasBaseTest): ((8, 16, 2), jnp.int8, 1), ) def test_broadcasted_iota(self, shape, dtype, dimension): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Only 32-bit integer iota supported") @@ -1351,6 +1401,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -1378,6 +1430,8 @@ class OpsTest(PallasBaseTest): ) def test_elementwise_inline_asm(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented: elementwise_inline_asm_p") @@ -1403,6 +1457,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) def test_debug_barrier(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented: debug_barrier_p") @@ -1425,6 +1481,8 @@ class OpsTest(PallasBaseTest): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Test for TPU is covered in tpu_pallas_test.py") @@ -1481,6 +1539,8 @@ class OpsTest(PallasBaseTest): ((64,), (32, 2)), ) def test_reshape(self, in_shape, out_shape): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1512,6 +1572,8 @@ class OpsTest(PallasBaseTest): # fmt: on ) def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + self.skip_if_mosaic_gpu() + # Unsupported implicit dim change: from "32,{0,0},(2,128),-1" to none if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1528,6 +1590,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(f(x), expected) def test_num_programs(self): + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), @@ -1542,6 +1606,8 @@ class OpsTest(PallasBaseTest): ) def test_where_broadcasting(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1570,6 +1636,8 @@ class OpsTest(PallasBaseTest): ((), (2, 2), ()), ) def test_broadcast_in_dim(self, in_shape, out_shape, dims): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1617,6 +1685,8 @@ class OpsTest(PallasBaseTest): trans_y=[False, True], ) def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): + self.skip_if_mosaic_gpu() + # TODO(apaszke): Remove after 12 weeks have passed. if not jtu.if_cloud_tpu_at_least(2024, 12, 19): self.skipTest("Requires libtpu built after 2024-12-19") @@ -1679,6 +1749,8 @@ class OpsTest(PallasBaseTest): block_size=[1, 2, 32, 64, 128], ) def test_masked_load_store(self, size, block_size): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented") @@ -1699,6 +1771,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) def test_masked_oob_load_store_slice(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1723,6 +1797,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): + self.skip_if_mosaic_gpu() + # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, @@ -1735,6 +1811,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(kernel(x), x[::4]) def test_broadcasted_load_store(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Unimplemented primitive: broadcast_to") @@ -1758,6 +1836,8 @@ class OpsTest(PallasBaseTest): ((16, 32), (16, 16)), ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1784,6 +1864,8 @@ class OpsTest(PallasBaseTest): self.fail("Expected exception due to invalid broadcasting") def test_swap(self): + self.skip_if_mosaic_gpu() + # TODO: skipped due to https://github.com/jax-ml/jax/issues/24023 if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU this is only supported in interpret mode") @@ -1807,6 +1889,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out[1], x) def test_masked_swap(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -1830,6 +1914,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) def test_masked_oob_swap_slice(self): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1871,6 +1957,8 @@ class OpsTest(PallasBaseTest): ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), ) def test_scalar_atomic(self, op, value, numpy_op): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1906,6 +1994,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters((0,), (1,)) def test_array_atomic_add(self, axis): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Unimplemented primitive: broadcast_to") @@ -1946,6 +2036,8 @@ class OpsTest(PallasBaseTest): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1968,6 +2060,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1997,6 +2091,8 @@ class OpsTest(PallasBaseTest): @parameterized.parameters(False, True) def test_reduce_only_dim(self, use_store): + self.skip_if_mosaic_gpu() + # The Pallas TPU lowering currently supports only blocks of rank >= 1 if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -2040,6 +2136,8 @@ class OpsTest(PallasBaseTest): ] ]) def test_array_reduce(self, op, dtype, axis): + self.skip_if_mosaic_gpu() + if not isinstance(axis, int): self.skipTest("TODO: tuple axes are not yet supported") @@ -2097,6 +2195,8 @@ class OpsTest(PallasBaseTest): dtype=["float16", "float32", "int32", "uint32"], ) def test_cumsum(self, dtype, axis): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -2134,6 +2234,8 @@ class OpsTest(PallasBaseTest): (-1, jnp.bfloat16), ) def test_triu(self, k, dtype): + self.skip_if_mosaic_gpu() + if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]): # TODO(mvoz): b/376330700 raise unittest.SkipTest('NYI - bf16 select') @@ -2159,6 +2261,8 @@ class OpsTest(PallasBaseTest): (jnp.int32, jnp.uint32), ) def test_bitcast_convert_type(self, in_dtype, out_dtype): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") @@ -2176,6 +2280,8 @@ class OpsTest(PallasBaseTest): np.testing.assert_array_equal(y, y_ref) def test_bitcast_convert_type_scalar(self): + self.skip_if_mosaic_gpu() + if jtu.test_device_matches(["tpu"]): self.skipTest("Not implemented on TPU") From 16bb919020aae0706bbce595c377f55fdcce170a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 6 Mar 2025 02:39:36 -0800 Subject: [PATCH 040/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6e396aae2e534dc7fc5387e2aa8b1a3a8d79a3db. PiperOrigin-RevId: 734059108 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 569518a67..4e3e77ea6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e0e56a1190d5be336f7d3e308457349fbfacebd2" -XLA_SHA256 = "cf64ab04fa47dd86bb222e87dc573d083e869d211c4ee806786fc4da17c1dafc" +XLA_COMMIT = "6e396aae2e534dc7fc5387e2aa8b1a3a8d79a3db" +XLA_SHA256 = "03e73863ff041c57d2fdefb4216c2774bb12faf70dd083dbe8e961ccb9ea42b2" def repo(): tf_http_archive( From d6b97c20264c323f4fd35f4508babcc700967ff3 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 6 Mar 2025 04:07:32 -0800 Subject: [PATCH 041/100] [pallas] Add support for `pl.dot` with `int8` inputs. PiperOrigin-RevId: 734081057 --- jax/_src/pallas/primitives.py | 4 +++- tests/pallas/pallas_test.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 07a6fcf0a..af84c2836 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -695,12 +695,14 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, if precision is not None: raise ValueError("Only one of allow_tf32 and precision can be specified") precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST + dtype = jnp.promote_types(a.dtype, b.dtype) + out_dtype = jnp.int32 if jnp.issubdtype(dtype, jnp.integer) else jnp.float32 return jax.lax.dot_general( a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())), precision=precision, - preferred_element_type=jnp.float32, + preferred_element_type=out_dtype, ) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bb3826dbf..9ee0dfc29 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -733,6 +733,27 @@ class PallasCallTest(PallasBaseTest): ) self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + @parameterized.parameters(jnp.int8, jnp.uint8) + def test_integer_dot(self, dtype): + if jtu.test_device_matches(["tpu"]) and not jtu.is_device_tpu_at_least(5): + self.skipTest("`int8` dot is only supported on v5 TPUs and newer.") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((32, 64), jnp.int32), + ) + def dot_kernel(x_ref, y_ref, o_ref): + o_ref[()] = pl.dot(x_ref[()], y_ref[()]) + + key0, key1 = random.split(random.key(0)) + # FIXME(cjfj): TPU fails with `uint8` values >= 128. + kwargs = dict(minval=jnp.iinfo(dtype).min, maxval=128, dtype=dtype) + # TODO(cjfj): Investigate why this fails on GPU with `k == 16`. + x = random.randint(key0, (32, 128), **kwargs) + y = random.randint(key1, (128, 64), **kwargs) + expected = jnp.dot(x, y, preferred_element_type=jnp.int32) + self.assertAllClose(dot_kernel(x, y), expected, atol=0.0, rtol=0.0) + def test_dot_with_vector(self): if not jtu.test_device_matches(["gpu"]) or self.INTERPRET: self.skipTest( From 2a34019388cf06ab8b6819e014bb8be17cadf579 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 6 Mar 2025 04:08:57 -0800 Subject: [PATCH 042/100] [pallas:mosaic_gpu] Added WG lowering rule for `lax.bitcast_convert_type_p` PiperOrigin-RevId: 734081448 --- jax/_src/pallas/mosaic_gpu/lowering.py | 23 +++++++---- .../mosaic/gpu/dialect_lowering.py | 19 +++++++++ .../mosaic/gpu/layout_inference.py | 1 + tests/pallas/mosaic_gpu_test.py | 39 ++++++++++++------- 4 files changed, 60 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e7331d2d5..5454b826f 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1908,27 +1908,36 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): @register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule( + lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup +) def _bitcast_convert_type_lowering_rule( - ctx: LoweringRuleContext, operand, *, new_dtype + ctx: LoweringRuleContext, x, *, new_dtype ): - # TODO(petebu) Handle case where src and dst types have different bitwidths - [operand_aval] = ctx.avals_in - operand = _ensure_fa(operand, operand_aval.dtype) - src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype) + [x_aval] = ctx.avals_in + src_elem_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype) assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) if src_elem_type.width != dst_elem_type.width: raise NotImplementedError( - f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they" + f"Cannot bitcast from {x_aval.dtype} to {new_dtype} because they" " have different widths" ) + + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + x = _ensure_ir_value(x, x_aval.dtype) + return arith_dialect.bitcast( + ir.VectorType.get(x_aval.shape, dst_elem_type), x + ) + + x = _ensure_fa(x, x_aval.dtype) if ir.IntegerType.isinstance(dst_elem_type): output_is_signed = mgpu_utils.is_signed(new_dtype) else: output_is_signed = None return mgpu.FragmentedArray.bitcast( - operand, dst_elem_type, output_is_signed=output_is_signed + x, dst_elem_type, output_is_signed=output_is_signed ) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 3ca7a8571..c93c01d0d 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -525,6 +525,25 @@ def _cmpf_op_lowering_rule( return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)] +@_register_lowering(arith.BitcastOp) +def _bitcast_op_lowering_rule( + _: LoweringContext, op: arith.BitcastOp +) -> Sequence[ir.Value]: + in_layouts = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if any(in_layout != layout for in_layout in in_layouts): + raise ValueError("Layout mismatch") + in_ = _fragmented_array_from_ir(op.in_, layout) + out_element_type = ir.VectorType(op.result.type).element_type + out = in_.bitcast( + out_element_type, + output_is_signed=False + if ir.IntegerType.isinstance(out_element_type) + else None, + ) + return [_fragmented_array_to_ir(out, op.result.type)] + + @_register_lowering(mgpu.WGMMAOp) def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 5971cfb85..6b2010cf5 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -194,6 +194,7 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: for op in [ arith.AddIOp, arith.AddFOp, arith.AndIOp, + arith.BitcastOp, arith.CmpFOp, arith.CmpIOp, arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 74f9cc617..22337ae68 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1161,29 +1161,38 @@ class PallasCallTest(PallasTest): self.assertEqual(data.count('"name": "store"'), 2) np.testing.assert_array_equal(y, x + x) - @parameterized.parameters( - (jnp.float16, jnp.float16), # Noop - (jnp.int16, jnp.bfloat16), - (jnp.int16, jnp.float16), - (jnp.uint16, jnp.float16), - (jnp.float32, jnp.int32), - (jnp.float32, jnp.uint32), - (jnp.uint32, jnp.int32), - (jnp.int32, jnp.uint32), + @parameterized.product( + dtypes=[ + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ], + thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, in_dtype, out_dtype): + def test_bitcast_convert_type(self, dtypes, thread_semantics): + in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - grid = () - @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid) + @functools.partial( + pl.pallas_call, + out_shape=out_shape, + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), + ) def convert(x_ref, y_ref): y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) - y = convert(x) - y_ref = jax.lax.bitcast_convert_type(x, out_dtype) - np.testing.assert_array_equal(y, y_ref) + np.testing.assert_array_equal( + convert(x), jax.lax.bitcast_convert_type(x, out_dtype) + ) class PallasCallSm90ATest(PallasSm90ATest): From 623865fe9538100d877ba9d36f788d0f95a11ed2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 6 Mar 2025 06:47:28 -0800 Subject: [PATCH 043/100] Build JAX wheels instead of installing it from the source repository This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be be built and consumed in the same way as the other wheels in the continuous jobs. PiperOrigin-RevId: 734123590 --- .github/workflows/pytest_cpu.yml | 24 ++++--------------- .github/workflows/pytest_cuda.yml | 14 ++--------- .github/workflows/wheel_tests_continuous.yml | 16 +++++++++---- .../workflows/wheel_tests_nightly_release.yml | 6 ----- ci/envs/default.env | 7 +----- ci/utilities/install_wheels_locally.sh | 8 +------ 6 files changed, 21 insertions(+), 54 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 3d153f6bc..e64f81809 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,11 +29,6 @@ on: type: string required: true default: "0" - install-jax-current-commit: - description: "Should the 'jax' package be installed from the current commit?" - type: string - required: true - default: "1" gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -62,7 +57,6 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" - JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -88,7 +82,7 @@ jobs: # `*-cp-cp-*`, while free-threaded wheels use # `*-cp-cpt-*`. echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download jaxlib wheel from GCS (non-Windows runs) + - name: Download wheels from GCS (non-Windows runs) id: download-wheel-artifacts-nw # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. Instead, we verify the outcome in the step below so that we can print a more @@ -96,14 +90,10 @@ jobs: continue-on-error: true if: ${{ !contains(inputs.runner, 'windows-x86') }} run: | - mkdir -p $(pwd)/dist && + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - fi - - name: Download jaxlib wheel from GCS (Windows runs) + - name: Download wheels from GCS (Windows runs) id: download-wheel-artifacts-w # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. Instead, we verify the outcome in step below so that we can print a more @@ -115,12 +105,8 @@ jobs: mkdir dist @REM Use `call` so that we can run sequential gsutil commands on Windows @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 + call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ - - @REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if not "${{ inputs.install-jax-current-commit }}"=="1" ( - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ - ) - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index fde109f9e..3dbd5bb0a 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -34,11 +34,6 @@ on: type: string required: true default: "0" - install-jax-current-commit: - description: "Should the 'jax' package be installed from the current commit?" - type: string - required: true - default: "1" gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -66,7 +61,6 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" - JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -86,7 +80,7 @@ jobs: # `*-cp-cp-*`, while free-threaded wheels use # `*-cp-cpt-*`. echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV - - name: Download the wheel artifacts from GCS + - name: Download wheels from GCS id: download-wheel-artifacts # Set continue-on-error to true to prevent actions from failing the workflow if this step # fails. Instead, we verify the outcome in the next step so that we can print a more @@ -94,14 +88,10 @@ jobs: continue-on-error: true run: | mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ - - # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1 - if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index fe1304c14..5c818bf56 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -27,6 +27,16 @@ concurrency: cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: + build-jax-artifact: + uses: ./.github/workflows/build_artifacts.yml + with: + # Note that since jax is a pure python package, the runner OS and Python values do not + # matter. In addition, cloning main XLA also has no effect. + runner: "linux-x86-n2-16" + artifact: "jax" + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + build-jaxlib-artifact: uses: ./.github/workflows/build_artifacts.yml strategy: @@ -66,7 +76,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: build-jaxlib-artifact + needs: [build-jax-artifact, build-jaxlib-artifact] uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -80,7 +90,6 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - install-jax-current-commit: 1 gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} run-pytest-cuda: @@ -88,7 +97,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -111,7 +120,6 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - install-jax-current-commit: 1 # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 33d62db4f..b88b000e4 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -40,9 +40,6 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} - # Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the - # GCS bucket. - install-jax-current-commit: 0 gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -61,7 +58,4 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} - # Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the - # GCS bucket. - install-jax-current-commit: 0 gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env index 66578efac..7a2448944 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -74,9 +74,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. -export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} - -# Installs the JAX package in editable mode at the current commit. Enabled by -# default. Nightly/Release builds disable this flag in the Github action -# workflow files. -export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"} \ No newline at end of file +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f0e245e14..41274b95f 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -19,7 +19,7 @@ # avoid using the Windows version of `find` on Msys. WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) -if [[ -z "$WHEELS" ]]; then +if [[ -z "${WHEELS[@]}" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" exit 1 fi @@ -38,10 +38,4 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") else "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" -fi - -if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then - echo "Installing the JAX package at the current commit..." - # Install JAX package at the current commit. - "$JAXCI_PYTHON" -m uv pip install . fi \ No newline at end of file From 8c89da7cdcf78848674b76a032f9a461be4e6f99 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 6 Mar 2025 06:57:07 -0800 Subject: [PATCH 044/100] Minor bug fixes in error checking PiperOrigin-RevId: 734126415 --- jax/_src/error_check.py | 38 ++++++++++++++++++++------------------ tests/error_check_test.py | 7 ++++++- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 1c009224d..edfcdf3f7 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -33,12 +33,11 @@ class JaxValueError(ValueError): """Exception raised for failed runtime error checks in JAX.""" +#: The default error code for no error. +#: +#: This value is chosen because we can use `jnp.min()` to obtain the +#: first error when performing reductions. _NO_ERROR = jnp.iinfo(jnp.uint32).max -"""The default error code for no error. - -We choose this value because when performing reductions, we can use `min` to -obtain the smallest error code. -""" _error_list_lock = threading.Lock() @@ -62,7 +61,7 @@ def _initialize_error_code_ref() -> None: def set_error_if(pred: jax.Array, msg: str) -> None: - """Set error if pred is true. + """Set error if any element of pred is true. If the error is already set, the new error will be ignored. It will not override the existing error. @@ -74,7 +73,7 @@ def set_error_if(pred: jax.Array, msg: str) -> None: traceback = source_info_util.current().traceback assert traceback is not None with _error_list_lock: - new_error_code = len(_error_list) + new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) pred = pred.any() @@ -86,18 +85,21 @@ def set_error_if(pred: jax.Array, msg: str) -> None: def raise_if_error() -> None: - """Raise error if an error is set.""" - if _error_storage.ref is None: - return # if not initialized, do nothing + """Raise error if an error is set. + + This function should be called after the computation is finished. It should + be used outside jit. + """ + if _error_storage.ref is None: # if not initialized, do nothing + return error_code = _error_storage.ref[...] if error_code == jnp.uint32(_NO_ERROR): return - try: - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) - finally: - _error_storage.ref[...] = jnp.uint32(_NO_ERROR) + _error_storage.ref[...] = jnp.uint32(_NO_ERROR) + + msg, traceback = _error_list[error_code] + exc = JaxValueError(msg) + traceback = traceback.as_python_traceback() + filtered_traceback = traceback_util.filter_traceback(traceback) + raise exc.with_traceback(filtered_traceback) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 8ac435cbb..653f901a6 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -148,7 +148,8 @@ class ErrorCheckTests(jtu.JaxTestCase): with self.assertRaisesRegex(JaxValueError, "x must be less than 10"): error_check.raise_if_error() - def test_error_check_works_with_scan(self): + @parameterized.product(jit=[True, False]) + def test_error_check_works_with_scan(self, jit): def f(carry, x): error_check.set_error_if(x >= 4, "x must be less than 4") return carry + x, x + 1 @@ -156,6 +157,9 @@ class ErrorCheckTests(jtu.JaxTestCase): def body(init, xs): return jax.lax.scan(f, init=init, xs=xs) + if jit: + body = jax.jit(body) + init = jnp.int32(0) xs = jnp.arange(5, dtype=jnp.int32) _ = body(init, xs) @@ -166,5 +170,6 @@ class ErrorCheckTests(jtu.JaxTestCase): _ = body(init, xs) error_check.raise_if_error() # should not raise error + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5d64b3d2dde4b342bd1ae5c8092f2efa568581e7 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 6 Mar 2025 08:39:32 -0800 Subject: [PATCH 045/100] [Mosaic GPU] Fix `scf.ForOp` lowering to put lowered ops at the right place. Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error. The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics. PiperOrigin-RevId: 734157829 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index c93c01d0d..d6156278c 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -708,8 +708,12 @@ def _for_op_lowering_rule( new_args = (new_for_op.induction_variable, *recreated_carry) for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True): old_carry.replace_all_uses_with(new_carry) - for op in ops_to_lower: + + for op in ops_to_lower: + with ir.InsertionPoint(op): ctx.lower_op(op) + + with ir.InsertionPoint(new_for_op.body): new_yield_operands = lower_carry(yield_op.operands) yield_op.erase() scf.yield_(new_yield_operands) From 4b49c0352355c4d71c701aba74e52c678d6627fd Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 6 Mar 2025 13:36:01 -0800 Subject: [PATCH 046/100] Open source TPU-friendly ragged paged attention kernel. Key features: * ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.) * ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly. * ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***! * ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode. * ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine. PiperOrigin-RevId: 734269519 --- .../pallas/ops/tpu/ragged_paged_attention.py | 712 ++++++++++++++++++ tests/pallas/BUILD | 18 + .../pallas/tpu_ragged_paged_attention_test.py | 305 ++++++++ 3 files changed, 1035 insertions(+) create mode 100644 jax/experimental/pallas/ops/tpu/ragged_paged_attention.py create mode 100644 tests/pallas/tpu_ragged_paged_attention_test.py diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py new file mode 100644 index 000000000..5d20dad6b --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -0,0 +1,712 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TPU-Friendly Ragged Paged Attention kernel. + +This kernel offers a highly optimized implementation of ragged paged attention, +specifically designed for TPU and compatible with a wide range of model +specifications. It supports mixed prefill and decoding, enhancing throughput +during inference. +""" + +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sem, + page_indices_ref, # i32[max_num_seqs, pages_per_seq] + offset, # [seq_idx, kv_pages_start] + ): + self._vmem_buf = vmem_buf + seq_id, kv_pages_start = offset + self._async_copies = [ + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]], + vmem_buf.at[i], + sem, + ) + for i in range(vmem_buf.shape[0]) + ] + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def wait(self): + for async_copy in self._async_copies: + async_copy.wait() + return self._vmem_buf + + +def ref_ragged_paged_attention( + queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1], + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, +): + _, _, num_kv_heads, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0 + num_query_per_kv = num_q_heads // num_kv_heads + outputs = [] + for i in range(num_seqs[0]): + q_start = cu_q_lens[i] + q_end = cu_q_lens[i + 1] + q_len = q_end - q_start + kv_len = kv_lens[i] + indices = page_indices[i] + q = queries[q_start:q_end] + k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) + attn *= sm_scale + q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( + jnp.int32, attn.shape, 1 + ) + kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) + attn += jnp.where(q_span < kv_span, mask_value, 0.0) + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) + outputs.append(out) + + return jnp.concatenate(outputs, axis=0) + + +# Expect to run these checkes during runtime. +def validate_inputs_on_runtime( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32[1] +): + check_inputs_shapes( + q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + max_num_batched_tokens = q.shape[0] + page_size = k_pages.shape[1] + max_num_seqs, pages_per_seq = page_indices.shape + if num_seqs[0] > max_num_seqs: + raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") + max_kv_len = jnp.max(kv_lens) + min_pages_per_seq = ceil_div(max_kv_len, page_size) + if pages_per_seq < min_pages_per_seq: + raise ValueError( + f"{pages_per_seq=} must be greater or equal to" + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}." + ) + if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: + raise ValueError( + f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" + f" {max_num_batched_tokens=}." + ) + for i in range(num_seqs[0]): + q_len = cu_q_lens[i + 1] - cu_q_lens[i] + kv_len = kv_lens[i] + if q_len > kv_len: + raise ValueError( + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." + ) + + +# Expect to run these checks during compile time. +def check_inputs_shapes( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32[1] +): + _, num_q_heads, head_dim = q.shape + _, _, num_kv_heads, head_dim_k = k_pages.shape + max_num_seqs, _ = page_indices.shape + if num_seqs.shape != (1,): + raise ValueError(f"{num_seqs.shape=} must be (1,)") + if k_pages.shape != v_pages.shape: + raise ValueError( + f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." + ) + if head_dim_k != head_dim: + raise ValueError( + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." + ) + if kv_lens.shape != (max_num_seqs,): + raise ValueError( + f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`." + ) + if cu_q_lens.shape != (max_num_seqs + 1,): + raise ValueError( + f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" + " `max_num_seqs` is `page_indices.shape[0]`." + ) + if ( + kv_lens.dtype != jnp.int32 + or page_indices.dtype != jnp.int32 + or cu_q_lens.dtype != jnp.int32 + ): + raise ValueError( + "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" + f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," + f" {cu_q_lens.dtype=}." + ) + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + + +def ragged_paged_attention_kernel( + # Prefetch + kv_lens_ref, # [max_num_seqs] + page_indices_ref, # [max_num_seqs, pages_per_seq] + cu_q_lens_ref, # [max_num_seqs + 1] + seq_buf_idx_ref, + # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs. + num_seqs_ref, + # Input + q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + # Output + o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + # Scratch + k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sems, # [2, 2] + l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + *, + sm_scale: float, + mask_value: float, +): + num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + num_seqs = num_seqs_ref[0] + _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + num_kv_per_blk = num_kv_pages_per_blk * page_size + num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk + heads_blk_idx, q_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + ) + num_heads_blks = pl.num_programs(0) + init_seq_idx = seq_buf_idx_ref[0] + init_buf_idx = seq_buf_idx_ref[1] + q_len_start = q_blk_idx * num_q_per_blk + q_len_end = q_len_start + num_q_per_blk + + def create_kv_async_copy_descriptors( + heads_blk_idx, seq_idx, kv_blk_idx, buf_idx + ): + offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) + heads_start = heads_blk_idx * num_kv_heads_per_blk + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], + k_bufs.at[buf_idx], + sems.at[buf_idx, 0], + page_indices_ref, + offset, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], + v_bufs.at[buf_idx], + sems.at[buf_idx, 1], + page_indices_ref, + offset, + ) + return async_copy_k, async_copy_v + + # TODO(jevinjiang): Add these to Mosaic: + # 1. Support arbitrary strided load/store for any dtype. + # 2. Support arbitrary strided load/store for any last dimension. + def strided_load_kv(ref, start, step): + if ref.dtype == jnp.float32: + return ref[start::step, :] + packing = get_dtype_packing(ref.dtype) + assert ref.dtype == jnp.bfloat16 + assert step % packing == 0 + b_start = start // packing + b_offset = start % packing + b_step = step // packing + b_ref = ref.bitcast(jnp.int32) + b = b_ref[b_start::b_step, :] + bw = 32 // packing + b = jnp.right_shift(b, bw * b_offset) + b = jnp.left_shift(b, bw * (packing - 1)) + return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + + @pl.when(heads_blk_idx + q_blk_idx == 0) + def prefetch_first_kv_blk(): + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, init_seq_idx, 0, init_buf_idx + ) + async_copy_k.start() + async_copy_v.start() + + def is_cur_q_blk_needed(q_states): + done, cur_seq_idx, _ = q_states + return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + + def compute_with_cur_q_blk(q_states): + done, cur_seq_idx, cur_buf_idx = q_states + q_start = cu_q_lens_ref[cur_seq_idx] + q_end = cu_q_lens_ref[cur_seq_idx + 1] + q_len = q_end - q_start + kv_len = kv_lens_ref[cur_seq_idx] + + def get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ): + next_kv_blk_idx = kv_blk_idx + 1 + is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len + next_kv_blk_idx = lax.select( + is_last_kv_blk, + 0, + next_kv_blk_idx, + ) + is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end + next_seq_idx = lax.select( + is_last_kv_blk, + lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx), + cur_seq_idx, + ) + is_last_seq = next_seq_idx == num_seqs + next_seq_idx = lax.select( + is_last_seq, + 0, + next_seq_idx, + ) + next_heads_blk_idx = lax.select( + is_last_seq, + heads_blk_idx + 1, + heads_blk_idx, + ) + next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0) + return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + + def flash_attention( + q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim] + k, # [num_kv_per_blk, head_dim] + v, # [num_kv_per_blk, head_dim] + head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + *, + kv_blk_idx, + ): + assert q.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + head_dim, + ) + assert k.shape == ( + num_kv_per_blk, + head_dim, + ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" + assert v.shape == (num_kv_per_blk, head_dim) + assert head_m_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_l_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_o_ref.shape == ( + num_q_per_blk, + num_q_heads_per_kv_head, + head_dim, + ) + kv_len_start = kv_blk_idx * num_kv_per_blk + + def masked_store(ref, val, start, end, group=1): + iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group + mask = jnp.logical_and(iota >= start, iota < end) + pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + + qk = ( + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) + * sm_scale + ) + store_start = jnp.maximum(q_start - q_len_start, 0) + store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) + + @pl.when(kv_blk_idx == 0) + def init_scratch_ref(): + masked_store( + head_m_ref, + jnp.full_like(head_m_ref, -jnp.inf), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_l_ref, + jnp.zeros_like(head_l_ref), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_o_ref, + jnp.zeros_like(head_o_ref), + store_start, + store_end, + ) + + row_ids = ( + (kv_len - q_len) + + q_len_start + - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) + // num_q_heads_per_kv_head + ) + col_ids = kv_len_start + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 1, + ) + causal_mask = row_ids < col_ids + qk += jnp.where(causal_mask, mask_value, 0.0) + m_curr = jnp.max(qk, axis=1, keepdims=True) + s_curr = jnp.exp(qk - m_curr) + qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) + lm_store_shape = head_m_ref.shape + m_curr = jnp.broadcast_to(m_curr, lm_store_shape) + l_curr = jnp.broadcast_to( + s_curr.sum(axis=1, keepdims=True), lm_store_shape + ) + m_prev = head_m_ref[...] + l_prev = head_l_ref[...] + m_next = jnp.maximum(m_prev, m_curr) + masked_store( + head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head + ) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_alpha = alpha * l_prev + l_next = l_alpha + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + masked_store( + head_l_ref, + l_next_safe, + store_start, + store_end, + num_q_heads_per_kv_head, + ) + + def broadcast_to_shape(arr, shape): + if arr.shape == shape: + return arr + assert len(arr.shape) == len(shape) + assert arr.shape[0] == shape[0] + assert shape[1] % arr.shape[1] == 0 + # no-op concatenation. + return jnp.concatenate( + [arr for _ in range(shape[1] // arr.shape[1])], axis=1 + ) + + o_curr = head_o_ref[...].reshape(-1, head_dim) + l_alpha = broadcast_to_shape(l_alpha, qkv.shape) + beta = broadcast_to_shape(beta, qkv.shape) + l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) + out = lax.div( + l_alpha * o_curr + beta * qkv, + l_next_safe, + ).astype(head_o_ref.dtype) + masked_store( + head_o_ref, + out.reshape(head_o_ref.shape), + store_start, + store_end, + ) + + def is_valid_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, _ = kv_states + return kv_blk_idx * num_kv_per_blk < kv_len + + def compute_with_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, cur_buf_idx = kv_states + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( + get_next_prefetch_ids( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) + ) + + @pl.when(next_heads_blk_idx < num_heads_blks) + def prefetch_next_kv_blk(): + # TODO(jevinjiang): reuse the same buffer if it is already prefetched! + # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and + # DMA to fixed size buffer! + next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + ) + next_async_copy_k.start() + next_async_copy_v.start() + + cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx + ) + kv_to_load_shape = ( + num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + head_dim, + ) + k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) + v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) + for kv_head_idx in range(num_kv_heads_per_blk): + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handlig for packed type that can start at + # unaligned position! + q = q_ref[ + :, q_head_idx : q_head_idx + num_q_heads_per_kv_head, : + ].reshape(-1, head_dim) + k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) + v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], + kv_blk_idx=kv_blk_idx, + ) + return kv_blk_idx + 1, next_buf_idx + + _, next_buf_idx = lax.while_loop( + is_valid_kv_blk_in_cur_seq, + compute_with_kv_blk_in_cur_seq, + (0, cur_buf_idx), # (kv_blk_idx, buf_idx) + ) + next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) + done = lax.select(q_end < q_len_end, done, 1) + return done, next_seq_idx, next_buf_idx + + _, seq_idx, buf_idx = lax.while_loop( + is_cur_q_blk_needed, + compute_with_cur_q_blk, + (0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx) + ) + # Reset seq_idx for next kv_heads_blk if run out of seqs! + seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) + seq_buf_idx_ref[1] = buf_idx + + +def ceil_div(a, b): + assert b != 0 + return (a + b - 1) // b + + +def get_dtype_packing(dtype): + if dtype == jnp.float32: + return 1 + if dtype == jnp.bfloat16: + return 2 + if dtype == jnp.int8: + return 4 + if dtype == jnp.int4: + return 8 + raise ValueError(f"Not implemented: unsupported {dtype=}") + + +def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): + q_packing = get_dtype_packing(q_dtype) + kv_packing = get_dtype_packing(kv_dtype) + + def can_be_xla_fully_tiled(x, packing): + if x % packing != 0: + return False + x //= packing + return x in (1, 2, 4, 8) or x % 8 == 0 + + # TODO(jevinjiang): support unaligned number of heads! + if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + raise ValueError( + f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + ) + assert num_q_heads % num_kv_heads == 0 + ratio = num_q_heads // num_kv_heads + # TODO(jevinjiang): we can choose smaller tiling for packed type if large + # second minor tiling is not on. + max_kv_tiling = 8 * kv_packing + min_kv_heads = ( + max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + ) + min_q_heads = min_kv_heads * ratio + if can_be_xla_fully_tiled(min_q_heads, q_packing): + return min_q_heads, min_kv_heads + return num_q_heads, num_kv_heads + + +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "mask_value", + "num_kv_pages_per_block", + "num_queries_per_block", + "vmem_limit_bytes", + ], +) +def ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + # TODO(jevinjiang): create a write_to_kv_cache kernel! + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, + num_kv_pages_per_block: int = 16, + num_queries_per_block: int = 128, + vmem_limit_bytes: int | None = None, +): + """Ragged paged attention that supports mixed prefill and decode. + + Args: + q: concatenated all sequences' queries. + k_pages: paged K cache. Normally in HBM. + v_pages: paged V cache. Normally in HBM. + kv_lens: padded kv lengths. Only the first num_seqs values are valid. + page_indices: the first index indicates which page to use in the kv cache + for each sequence. Only the first num_seqs values are valid. + cu_q_lens: the cumulative sum of the effective query lengths. Similar to + kv_lens, only the first num_seqs+1 values are valid. + num_seqs: the dynamic number of sequences. + sm_scale: the softmax scale which will be applied to the Q@K^T. + mask_value: mask value for causal mask. + num_kv_pages_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + vmem_limit_bytes: the vmem limit for the pallas kernel. + + Returns: + The output of the attention. + """ + check_inputs_shapes( + q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + _, num_q_heads, head_dim = q.shape + _, page_size, num_kv_heads, _ = k_pages.shape + num_q_per_blk = num_queries_per_block + num_kv_pages_per_blk = num_kv_pages_per_block + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_kv_heads, q.dtype, k_pages.dtype + ) + assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 + num_heads_blks = num_q_heads // num_q_heads_per_blk + grid = (num_heads_blks, num_q_blks) + + def q_index_map(heads_blk_idx, q_blk_idx, *_): + return (q_blk_idx, heads_blk_idx, 0) + + q_block_spec = pl.BlockSpec( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + q_index_map, + ) + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + out_specs = q_block_spec + lm_scratch = pltpu.VMEM( + # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support + # unaligned slicing! + (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), + jnp.float32, + ) + double_buf_scratch = pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_blk, + page_size, + num_kv_heads_per_blk, + head_dim, + ), + k_pages.dtype, + ) + scratch_shapes = [ + double_buf_scratch, # k_bufs + double_buf_scratch, # v_bufs + pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + lm_scratch, # l_ref + lm_scratch, # m_ref + ] + scalar_prefetches = ( + kv_lens, + page_indices, + cu_q_lens, + jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx + num_seqs, + ) + kernel = pl.pallas_call( + functools.partial( + ragged_paged_attention_kernel, + sm_scale=sm_scale, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=( + "arbitrary", + "arbitrary", + ), + vmem_limit_bytes=vmem_limit_bytes, + ), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + name="ragged_paged_attention_kernel", + ) + # TODO(jevinjiang): Use f32 acc scratch for output! So we only need + # to transfer output with desired dtype back to HBM. + return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 9a42d64d3..987a3aa9d 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -506,6 +506,24 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "tpu_ragged_paged_attention_test", + srcs = ["tpu_ragged_paged_attention_test.py"], + disable_configs = [ + "tpu_v5p_1x1", + ], + enable_backends = ["tpu"], + shard_count = 24, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], + deps = [ + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "tpu_splash_attention_kernel_test", srcs = [ diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py new file mode 100644 index 000000000..1ed12aecf --- /dev/null +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -0,0 +1,305 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + ragged_paged_attention, + ref_ragged_paged_attention, + validate_inputs_on_runtime, +) +import jax.numpy as jnp + + +jax.config.parse_flags_with_absl() + + +def ceil_div(x, a): + assert a != 0 + return (x + a - 1) // a + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class PagedAttentionKernelTest(jtu.JaxTestCase): + + def _test_ragged_paged_attention( + self, + seq_lens, # List[(q_len, kv_len)] + num_heads, # [num_q_heads, num_kv_heads] + head_dim, + page_size, + dtype, + num_pages, + *, + num_kv_pages_per_block=8, + num_queries_per_block=64, + vmem_limit_bytes=32 * 1024 * 1024, + max_num_batched_tokens=512, + max_num_seq=8, + ): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Expect TPUv4+") + cu_q_lens = [0] + kv_lens = [] + for q_len, kv_len in seq_lens: + assert q_len <= kv_len + cu_q_lens.append(cu_q_lens[-1] + q_len) + kv_lens.append(kv_len) + + max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) + max_num_seq = max(len(seq_lens), max_num_seq) + max_kv_len = max(kv_lens) + pages_per_seq = ceil_div(max_kv_len, page_size) + pages_per_seq = ( + ceil_div(pages_per_seq, num_kv_pages_per_block) + * num_kv_pages_per_block + ) + num_q_heads, num_kv_heads = num_heads + + cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) + kv_lens = jnp.array(kv_lens, dtype=jnp.int32) + cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) + kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) + prng_key = jax.random.key(1234) + k0, k1, k2, k3 = jax.random.split(prng_key, 4) + q = jax.random.normal( + k0, + (max_num_batched_tokens, num_q_heads, head_dim), + dtype=dtype, + ) + k_pages = jax.random.normal( + k1, + (num_pages, page_size, num_kv_heads, head_dim), + dtype=dtype, + ) + v_pages = jax.random.normal( + k2, + (num_pages, page_size, num_kv_heads, head_dim), + dtype=dtype, + ) + page_indices = jax.random.randint( + k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + ) + + num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) + + validate_inputs_on_runtime( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + ) + + output = ragged_paged_attention( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + )[: cu_q_lens[num_seqs[0]]] + + expected = ref_ragged_paged_attention( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs=num_seqs, + ) + tols = { + "float32": 1e-1, + "bfloat16": 2e-1, + } + tol = tols[jnp.dtype(dtype).name] + self.assertAllClose(output, expected, atol=tol, rtol=tol) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_basic(self, dtype): + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_decode_only(self, dtype): + seq_lens = [ + (1, 18), + (1, 129), + (1, 597), + (1, 122), + (1, 64), + (1, 322), + (1, 463), + (1, 181), + (1, 1107), + (1, 123), + (1, 31), + (1, 18), + (1, 1229), + (1, 229), + (1, 87), + (1, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_prefill_only(self, dtype): + seq_lens = [ + (5, 18), + (15, 129), + (120, 597), + (100, 122), + (21, 64), + (32, 322), + (251, 463), + (40, 181), + (64, 1107), + (99, 123), + (10, 31), + (5, 18), + (3, 1229), + (120, 229), + (9, 87), + (2, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_ragged_paged_attention_mixed(self, dtype): + seq_lens = [ + (5, 18), + (1, 129), + (120, 597), + (1, 122), + (1, 64), + (32, 322), + (251, 463), + (1, 181), + (1, 1107), + (99, 123), + (1, 31), + (5, 18), + (3, 1229), + (117, 229), + (1, 87), + (1, 1328), + ] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + num_seqs=[1, 5, 16], + # TODO(jevinjiang): Support more num_heads! + num_heads=[(32, 8), (32, 16), (12, 2)], + dtype=[jnp.float32, jnp.bfloat16], + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + ) + def test_ragged_paged_attention_complex( + self, + num_seqs, + num_heads, + dtype, + num_kv_pages_per_block, + num_queries_per_block, + ): + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From b441b2b7a5003924957e9a781fc4b6838455f776 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 6 Mar 2025 14:38:11 -0800 Subject: [PATCH 047/100] Prevent tracer leaks in scipy.special.expn --- jax/_src/scipy/special.py | 14 +++++++------- tests/lax_scipy_special_functions_test.py | 5 +++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 9c7528d90..a24736ccf 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents): return expi(x), jnp.exp(x) / x * x_dot -def _expn1(n: Array, x: Array) -> Array: +def _expn1(x: Array, n: Array) -> Array: # exponential integral En _c = _lax_const MACHEP = jnp.finfo(x.dtype).eps @@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array: return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"] -def _expn2(n: Array, x: Array) -> Array: +def _expn2(x: Array, n: Array) -> Array: # x > 1. _c = _lax_const BIG = _c(x, 1.44115188075855872e17) @@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array: return d["ans"] * jnp.exp(-x) -def _expn3(n: Array, x: Array) -> Array: +def _expn3(x: Array, n: Array) -> Array: # n >= 5000 _c = _lax_const one = _c(x, 1.0) @@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array: jnp.inf, one / n1, # prevent div by zero jnp.exp(-x) / x, - partial(_expn3, n), - partial(_expn2, n), - partial(_expn1, n), + _expn3, + _expn2, + _expn1, ] - ret = jnp.piecewise(x, conds, vals) + ret = jnp.piecewise(x, conds, vals, n=n) return ret diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 575362895..96d48dcd3 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -273,6 +273,11 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): with self.assertRaises(TypeError): lsp_special.beta(x=1, y=1) + def testExpnTracerLeaks(self): + # Regression test for https://github.com/jax-ml/jax/issues/26972 + with jax.checking_leaks(): + lsp_special.expi(jnp.ones(())) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From cd7f03f2723b36844a20b724ff92ba7604adeaff Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 6 Mar 2025 14:57:18 -0800 Subject: [PATCH 048/100] Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays. Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively. PiperOrigin-RevId: 734299259 --- jax/experimental/colocated_python/func.py | 2 +- .../colocated_python/serialization.py | 83 ++++++++----------- tests/colocated_python_test.py | 12 ++- 3 files changed, 42 insertions(+), 55 deletions(-) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 8e2883ea4..effca1fe7 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun( devices = specialization.devices - def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: + def lowered_fun(*args, **kwargs) -> jax.Array: result = info.fun(*args, **kwargs) result_leaves, out_treedef = tree_util.tree_flatten(result) out_spec_leaves = tuple(_get_spec(x) for x in result_leaves) diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index bfd5ec2e6..1ca29ab12 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -13,11 +13,9 @@ # limitations under the License. """Colocated Python serialization utilities.""" -# TODO(jmudigonda): Use a string-typed array for output structure when it -# becomes available. Using a fixed uint8 array is only for prototyping. - from __future__ import annotations +import base64 import collections import functools import io @@ -37,12 +35,6 @@ import numpy as np DeviceList = xc.DeviceList -# Hard-coded limit for serialized specs size. -# TODO(jmudigonda): Use a string-typed array for output structure when it -# becomes available. Using a fixed uint8 array is only for prototyping. -_MAX_SERIALIZED_SPECS_SIZE = 1048576 - - @jax.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" @@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any: def _make_specs_for_serialized_specs( devices: DeviceList, -) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]: +) -> api.ShapeDtypeStruct: """Makes output specs for serialized specs.""" - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping. mesh = jax.sharding.Mesh(tuple(devices), ("x",)) replicated_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec() ) - return ( - api.ShapeDtypeStruct( - shape=(), dtype=np.int32, sharding=replicated_sharding - ), - api.ShapeDtypeStruct( - shape=(_MAX_SERIALIZED_SPECS_SIZE,), - dtype=np.uint8, - sharding=replicated_sharding, - ), + return api.ShapeDtypeStruct( + shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding # type: ignore ) @@ -209,49 +192,49 @@ def _serialize_specs( specs_treedef: tree_util.PyTreeDef, specs_leaves: tuple[api.ShapeDtypeStruct, ...], devices: DeviceList, -) -> tuple[jax.Array, ...]: - """Serializes the output specs into a tuple of arrays. +) -> jax.Array: + """Serializes the output specs into a jax.Array of string type. DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF colocated_python. See serialize() for details. """ - s = _serialize((specs_treedef, specs_leaves)) - assert ( - len(s) <= _MAX_SERIALIZED_SPECS_SIZE - ), f"Too large serialized spec size: {len(s)}" - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping. - mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + if not hasattr(np.dtypes, "StringDType"): + raise TypeError( + "Serializing Colocated Python requires StringDType. Please use" + " numpy to 2.0.0 or later, or explicityly provide an output spec" + " function." + ) + + s_bytes = _serialize((specs_treedef, specs_leaves)) + s_str = base64.b64encode(s_bytes).decode("ascii") + s_np_array = np.array(s_str, dtype=np.dtypes.StringDType()) # type: ignore + + # TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making + # jax.Array via make_array_from_single_device_arrays. We should then use a + # sharding that spans all the execution devices - not just the addressable + # ones. + addressable_devices = devices.addressable_device_list + mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",)) replicated_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec() ) - len_array = jax.make_array_from_callback( - shape=(), - sharding=replicated_sharding, - data_callback=lambda _: np.array(len(s), dtype=np.int32), + + out_arrays = [ + jax.device_put(s_np_array, device) for device in addressable_devices + ] + return jax.make_array_from_single_device_arrays( + arrays=out_arrays, sharding=replicated_sharding, shape=(), ) - data_array = jax.make_array_from_callback( - shape=(_MAX_SERIALIZED_SPECS_SIZE,), - sharding=replicated_sharding, - data_callback=lambda _: np.frombuffer( - s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)), - dtype=np.uint8, - ), - ) - return len_array, data_array def _deserialize_specs( - serialized_specs: tuple[jax.Array, ...], + serialized_specs: jax.Array, ) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]: """Deserializes the specs from the serialized specs. DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF colocated_python. See serialize() for details. """ - # TODO(jmudigonda): Use a string-typed array for output structure when it - # becomes available. Using a fixed uint8 array is only for prototyping. - len_array, data_array = serialized_specs - length = int(len_array.addressable_shards[0].data) - data = np.asarray(data_array.addressable_shards[0].data).tobytes() - return _deserialize(data[:length]) + data_array = serialized_specs.addressable_shards[0].data + data = base64.b64decode(data_array.item().encode("ascii")) + return _deserialize(data) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index d6abe8bec..52d494904 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -66,6 +66,14 @@ _count_colocated_python_specialization_cache_miss = jtu.count_events( class ColocatedPythonTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if np.lib.NumpyVersion(np.__version__) < "2.0.0": + self.skipTest( + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" + ) + def testMakeColocatedPythonProgram(self): def add_one(x): return x + 1 @@ -382,8 +390,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): del colocated_python._testing_global_state def testStringProcessing(self): - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -425,8 +431,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): ) def testBinaryDataProcessing(self): - if np.lib.NumpyVersion(np.__version__) < "2.0.0": - self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") From e9486920e82d2be50b336ef8b48765109a091797 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 6 Mar 2025 16:09:41 -0800 Subject: [PATCH 049/100] Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with `None`. So that for a 2D input, P('data') continues to work. PiperOrigin-RevId: 734325209 --- jax/_src/core.py | 3 +++ jax/_src/numpy/einsum.py | 3 +++ tests/pjit_test.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 210f8cb68..17b842359 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1815,6 +1815,7 @@ def _make_lengths_same(sharding, ndim): if ndim > len(sharding.spec): return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) if ndim < len(sharding.spec): + assert all(s is None for s in sharding.spec[ndim:]) return sharding.with_spec(sharding.spec[:ndim]) assert False, "unreachable" @@ -1840,6 +1841,8 @@ def _maybe_modify_sharding(sharding, ndim): return sharding if sharding.mesh._are_all_axes_explicit: + if ndim > len(sharding.spec): + return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) return sharding out = sharding.with_spec(modify_spec_for_auto_manual( diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 8aa118959..9d745643b 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -547,6 +547,9 @@ def _einsum( if not last_contraction: dot_general_out_sharding = None elif out_sharding is not None and names != result_names: + if len(result_names) > len(out_sharding.spec): + out_sharding = out_sharding.with_spec( + out_sharding.spec._normalized_spec_for_aval(len(result_names))) spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 4c20ac649..4b267196e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6264,6 +6264,13 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, None, None, None))) + def test_aval_spec_explicit_auto_complete(self): + abstract_mesh = mesh_lib.AbstractMesh( + (('x', 2),), axis_types={AxisTypes.Explicit: 'x'}) + s = NamedSharding(abstract_mesh, P('x')) + out = core.ShapedArray((8, 2), jnp.int32, sharding=s) + self.assertEqual(out.sharding.spec, P('x', None)) + @jtu.with_user_mesh((2, 2), ('x', 'y'), axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')}) def test_full_user_mode(self, mesh): @@ -6318,6 +6325,34 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.shape, (16, 8, 16)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) + @jtu.with_user_mesh((4,), ('data',)) + def test_intermediate_einsum_auto_complete_spec(self, mesh): + s = NamedSharding(mesh, P('data')) + + shape1 = (8, 32, 2*16) + shape2 = (8, 32, 2, 8) + shape3 = (8, 32, 2, 8) + np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) + np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + + arr1 = jax.device_put(np_inp1, s) + arr2 = jax.device_put(np_inp2, s) + arr3 = jax.device_put(np_inp3, s) + + @jax.jit + def f(x, y, z): + x = jnp.reshape(x, (8, 32, 2, 16)) + out = jnp.einsum('bthD, bthi, bthj->ijD', x, y, z, + out_sharding=P('data')) + self.assertEqual(out.shape, (8, 8, 16)) + self.assertEqual(out.aval.sharding.spec, P('data', None, None)) + return out + + out = f(arr1, arr2, arr3) + self.assertEqual(out.shape, (8, 8, 16)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) + def test_where_with_prng_sharded_inp(self): mesh = jax.sharding.Mesh(jax.devices(), axis_names=['batch']) sharding = jax.sharding.NamedSharding( From ff4310f6402438fabe60688182902de3217dde58 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 6 Mar 2025 17:18:39 -0800 Subject: [PATCH 050/100] [Mosaic TPU] Support fp8 upcast to f32 PiperOrigin-RevId: 734345644 --- .../tpu/transforms/apply_vector_layout.cc | 6 ++---- .../tpu/transforms/infer_vector_layout.cc | 3 +-- tests/pallas/ops_test.py | 20 +++++++++++++++++++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 62fa622ae..17e8d12b5 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -684,10 +684,8 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(layouts_in.front().has_value()); TPU_ASSERT_OP(layouts_out.front().has_value()); auto extf_op = cast(op); - if (layouts_in.front()->bitwidth() != 16 || - layouts_out.front()->bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 16-bit to 32-bit conversion supported"); + if (layouts_out.front()->bitwidth() != 32) { + return op.emitOpError("Not implemented: Only support conversion to 32-bit"); } ImplicitLocOpBuilder builder(op.getLoc(), &op); FAILUREOR_ASSIGN_OR_RETURN( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 33912fddf..082e1204c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1674,8 +1674,7 @@ class VectorLayoutInferer { auto some_layout = getLayout(op->getOperand(0)); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); if (dyn_cast(op)) { - TPU_CHECK_OP(src_bitwidth == 16 && dst_bitwidth == 32, - "Only 16-bit to 32-bit extensions supported"); + TPU_CHECK_OP(dst_bitwidth == 32, "Only supported extensions to 32-bit"); } auto &layout = *some_layout; Layout src_layout; diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d79a172ed..0cabb4bfe 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -106,6 +106,9 @@ _DTYPES_SUB_32BIT = ( "uint8", "uint4", "bool", + "float8_e4m3b11fnuz", + "float8_e5m2", + "float8_e4m3fn", ) _DTYPES = (*_DTYPES_32BIT, *_DTYPES_SUB_32BIT) @@ -567,6 +570,11 @@ class OpsTest(PallasBaseTest): @hp.given(hps.data()) def test_cast_from_32bit(self, from_dtype, to_dtype, data): self.skip_if_mosaic_gpu() + if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: + if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + self.skipTest("Not supported on this hardware") + if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + self.skipTest("Test requires libtpu from 2025/3/8 or later") if from_dtype == to_dtype: self.skipTest("Unnecessary test") @@ -633,6 +641,18 @@ class OpsTest(PallasBaseTest): if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}: self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861 + if from_dtype in { + "float8_e4m3b11fnuz", + "float8_e5m2", + "float8_e4m3fn", + } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: + if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + self.skipTest("Not supported on this hardware") + if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + self.skipTest("Test requires libtpu from 2025/3/8 or later") + if to_dtype not in {"float32", "int32", "uint32"}: + self.skipTest("Only fp8 to x32 cast is supported") + from_int = np.issubdtype(np.dtype(from_dtype), np.integer) to_int = np.issubdtype(np.dtype(to_dtype), np.integer) if ( From 8095d842c897edac252c6178ba63df2598ae0ba2 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Thu, 6 Mar 2025 17:43:58 -0800 Subject: [PATCH 051/100] roofline: Support computing flops for unary ops. PiperOrigin-RevId: 734351741 --- jax/experimental/roofline/rooflines.py | 51 +++++++++++++++++++++++++ tests/roofline_test.py | 53 ++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index f80e5c501..74edcb9cd 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -55,6 +55,57 @@ for prim in it.chain( roofline.register_standard_roofline(prim) +def _unary_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + unfused_flops=x.size, + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + +roofline.register_roofline(lax.abs_p)(_unary_p_roofline) +roofline.register_roofline(lax.acos_p)(_unary_p_roofline) +roofline.register_roofline(lax.asin_p)(_unary_p_roofline) +roofline.register_roofline(lax.atan_p)(_unary_p_roofline) +roofline.register_roofline(lax.cbrt_p)(_unary_p_roofline) +roofline.register_roofline(lax.ceil_p)(_unary_p_roofline) +roofline.register_roofline(lax.conj_p)(_unary_p_roofline) +roofline.register_roofline(lax.cos_p)(_unary_p_roofline) +roofline.register_roofline(lax.cosh_p)(_unary_p_roofline) +roofline.register_roofline(lax.exp_p)(_unary_p_roofline) +roofline.register_roofline(lax.expm1_p)(_unary_p_roofline) +roofline.register_roofline(lax.floor_p)(_unary_p_roofline) +roofline.register_roofline(lax.imag_p)(_unary_p_roofline) +roofline.register_roofline(lax.integer_pow_p)(_unary_p_roofline) +roofline.register_roofline(lax.is_finite_p)(_unary_p_roofline) +roofline.register_roofline(lax.log_p)(_unary_p_roofline) +roofline.register_roofline(lax.log1p_p)(_unary_p_roofline) +roofline.register_roofline(lax.logistic_p)(_unary_p_roofline) +roofline.register_roofline(lax.neg_p)(_unary_p_roofline) +roofline.register_roofline(lax.not_p)(_unary_p_roofline) +roofline.register_roofline(lax.real_p)(_unary_p_roofline) +roofline.register_roofline(lax.round_p)(_unary_p_roofline) +roofline.register_roofline(lax.rsqrt_p)(_unary_p_roofline) +roofline.register_roofline(lax.sign_p)(_unary_p_roofline) +roofline.register_roofline(lax.sin_p)(_unary_p_roofline) +roofline.register_roofline(lax.sinh_p)(_unary_p_roofline) +roofline.register_roofline(lax.sqrt_p)(_unary_p_roofline) +roofline.register_roofline(lax.square_p)(_unary_p_roofline) +roofline.register_roofline(lax.tan_p)(_unary_p_roofline) +roofline.register_roofline(special.bessel_i0e_p)(_unary_p_roofline) +roofline.register_roofline(special.bessel_i1e_p)(_unary_p_roofline) +roofline.register_roofline(special.digamma_p)(_unary_p_roofline) +roofline.register_roofline(special.erf_inv_p)(_unary_p_roofline) +roofline.register_roofline(special.erf_p)(_unary_p_roofline) +roofline.register_roofline(special.erfc_p)(_unary_p_roofline) +roofline.register_roofline(special.lgamma_p)(_unary_p_roofline) + def _binary_p_roofline( ctx: roofline.RooflineRuleContext, *args, diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 4ad83556f..2fd3a24d3 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -424,6 +424,59 @@ class RooflineTest(jtu.JaxTestCase): ) self.assertDataclassEqual(bwd_results, expected) + @jtu.parameterized.named_parameters( + ("abs", lax.abs, float), + ("acos", lax.acos, float), + ("asin", lax.asin, float), + ("atan", lax.atan, float), + ("cbrt", lax.cbrt, float), + ("ceil", lax.ceil, float), + ("conj", lax.conj, complex), + ("cos", lax.cos, float), + ("cosh", lax.cosh, float), + ("exp", lax.exp, float), + ("expm1", lax.expm1, float), + ("floor", lax.floor, float), + ("imag", lax.imag, complex), + ("integer_pow", lambda a: lax.integer_pow(a, 5), int), + ("is_finite", lax.is_finite, float), + ("log", lax.log, float), + ("log1p", lax.log1p, float), + ("logistic", lax.logistic, float), + ("neg", lax.neg, float), + ("not", lax.bitwise_not, bool), + ("real", lax.real, complex), + ("round", lax.round, float), + ("rsqrt", lax.rsqrt, float), + ("sign", lax.sign, float), + ("sin", lax.sin, float), + ("sinh", lax.sinh, float), + ("sqrt", lax.sqrt, float), + ("square", lax.square, float), + ("tan", lax.tan, float), + ("bessel_i0e", lax.bessel_i0e, float), + ("bessel_i1e", lax.bessel_i1e, float), + ("digamma", lax.digamma, float), + ("erf_inv", lax.erf_inv, float), + ("erf", lax.erf, float), + ("erfc", lax.erfc, float), + ("lgamma", lax.lgamma, float), + ) + def test_unary_ops(self, f, dtype): + data = jnp.zeros((3, 8), dtype=dtype) + out, result = roofline.roofline( + f, + in_specs=(P()), + out_specs=P(), + )(data) + with self.subTest("flops"): + self.assertEqual(result.unfused_flops, 3 * 8) + with self.subTest("hbm_bytes"): + self.assertEqual( + result.unfused_hbm_bytes, + data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, + ) + def test_binary_ops(self): for f in [ lambda a, b: a ^ b, From ccbe9f7cd60b4f654950da02927a656449c29858 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 7 Mar 2025 04:52:58 +0000 Subject: [PATCH 052/100] Fix lint --- jax/_src/dtypes.py | 6 ++++++ tests/dtypes_test.py | 5 +++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 853fb5d1c..01500c008 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -730,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "promotion path. To avoid unintended promotion, 8-bit floats do not support " "implicit promotion. If you'd like your inputs to be promoted to another type, " "you can do so explicitly using e.g. x.astype('float32')") + elif any(n in _float4_dtypes for n in nodes): + msg = ( + f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype " + "promotion path. To avoid unintended promotion, 4-bit floats do not support " + "implicit promotion. If you'd like your inputs to be promoted to another type, " + "you can do so explicitly using e.g. x.astype('float32')") elif any(n in _intn_dtypes for n in nodes): msg = ( f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype " diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index fca3f4320..87380443f 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -989,7 +989,8 @@ class TestPromotionTables(jtu.JaxTestCase): def testFloat4PromotionError(self): for dtype in fp4_dtypes: if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): - self.skipTest("TPU does not support float4_e2m1fn.") + # TPU does not support float4_e2m1fn. + continue x = jnp.array(1, dtype=dtype) y = jnp.array(1, dtype='float32') with self.assertRaisesRegex(dtypes.TypePromotionError, @@ -1055,7 +1056,7 @@ class TestPromotionTables(jtu.JaxTestCase): if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']): self.skipTest('TPU does not support float8_e8m0fnu.') if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']): - self.skipTest('TPU does not support float4_e2m1fn.') + self.skipTest('TPU does not support float4_e2m1fn.') val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') From bf95bf49d405dbaeae9c9860dc40a00e6fecd86f Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 7 Mar 2025 02:59:39 -0800 Subject: [PATCH 053/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f1213b83af673729b60f5096da5186246568c0fb. PiperOrigin-RevId: 734484617 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4e3e77ea6..4ac4564a7 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6e396aae2e534dc7fc5387e2aa8b1a3a8d79a3db" -XLA_SHA256 = "03e73863ff041c57d2fdefb4216c2774bb12faf70dd083dbe8e961ccb9ea42b2" +XLA_COMMIT = "f1213b83af673729b60f5096da5186246568c0fb" +XLA_SHA256 = "77b886c9700d1f9a2ed65f18c176ddb38ffe6905128690f19e1fd7ca624dbebd" def repo(): tf_http_archive( From e6db7a9d99fbfa1a2de9fe649189611fa2e9b6ee Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 7 Mar 2025 04:00:57 -0800 Subject: [PATCH 054/100] Dedup non-ref constants closed in cond branch functions. PiperOrigin-RevId: 734497907 --- jax/_src/lax/control_flow/common.py | 48 ++++++++++++----------- jax/_src/lax/control_flow/conditionals.py | 3 +- tests/api_test.py | 10 ++--- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index cecd1cdc5..b75cbf6ac 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -27,7 +27,6 @@ from jax._src.lax import lax from jax._src import effects from jax._src import ad_util from jax._src import state -from jax._src import util from jax._src.util import weakref_lru_cache, safe_map, partition_list from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef @@ -144,52 +143,55 @@ def _initial_style_jaxprs_with_common_consts( # b[] <- 2.0 # in () } canonical_ref_indices = [] + canonical_non_ref_indices = [] canonical_refs: list[Any] = [] - tracer_id_to_canonical_id = {} - all_nonref_consts = [] + canonical_non_refs: list[Any] = [] + tracer_id_to_canonical_ref_id = {} + tracer_id_to_canonical_non_ref_id = {} canonical_ref_avals = [] - all_nonref_const_avals = [] + canonical_non_ref_avals = [] for consts, consts_avals in zip(all_consts, all_const_avals): ref_indices = [] - nonref_consts = [] - nonref_const_avals = [] + non_ref_indices = [] for c, aval in zip(consts, consts_avals): + tracer_id = id(c) if isinstance(aval, state.AbstractRef): - tracer_id = id(c) - if tracer_id not in tracer_id_to_canonical_id: + if tracer_id not in tracer_id_to_canonical_ref_id: canonical_id = len(canonical_refs) canonical_refs.append(c) - tracer_id_to_canonical_id[tracer_id] = canonical_id + tracer_id_to_canonical_ref_id[tracer_id] = canonical_id canonical_ref_avals.append(aval) - canonical_id = tracer_id_to_canonical_id[tracer_id] + canonical_id = tracer_id_to_canonical_ref_id[tracer_id] ref_indices.append(canonical_id) else: - nonref_consts.append(c) - nonref_const_avals.append(aval) - all_nonref_consts.append(nonref_consts) - all_nonref_const_avals.append(tuple(nonref_const_avals)) + if tracer_id not in tracer_id_to_canonical_non_ref_id: + canonical_id = len(canonical_non_refs) + canonical_non_refs.append(c) + tracer_id_to_canonical_non_ref_id[tracer_id] = canonical_id + canonical_non_ref_avals.append(aval) + canonical_id = tracer_id_to_canonical_non_ref_id[tracer_id] + non_ref_indices.append(canonical_id) canonical_ref_indices.append(tuple(ref_indices)) + canonical_non_ref_indices.append(tuple(non_ref_indices)) - consts = [*canonical_refs, *util.concatenate(all_nonref_consts)] - jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*all_nonref_const_avals,)) + consts = [*canonical_refs, *canonical_non_refs] + jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,)) for i, jaxpr in enumerate(jaxprs)) return jaxprs, consts, all_out_trees @weakref_lru_cache def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, - all_nonref_const_avals): + canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) newvar = core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) - for const_avals in all_nonref_const_avals] padded_ref_constvars = map(newvar, canonical_ref_avals) + padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var - const_prefix = util.concatenate(unused_const_vars[:i]) - const_suffix = util.concatenate(unused_const_vars[i + 1:]) - constvars = [*padded_ref_constvars, *const_prefix, *nonref_constvars, - *const_suffix] + for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): + padded_non_ref_constvars[canonical_id] = non_ref_var + constvars = [*padded_ref_constvars, *padded_non_ref_constvars] jaxpr = jaxpr.replace(constvars=constvars) effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, jaxpr.eqns) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index e2ad6ced1..63896cc2a 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -281,8 +281,9 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, num_consts = len(consts) out_ = iter(out) + all_inputs = [*consts, *ops] out = [ - next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts]) + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None diff --git a/tests/api_test.py b/tests/api_test.py index 571a33e24..ff729c03d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6446,14 +6446,10 @@ class JaxprTest(jtu.JaxTestCase): e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( - { lambda ; g_:f32[] h:f32[] i:f32[] j:f32[]. let - k:f32[] = sub j h - in (k,) } - { lambda ; l:f32[] m_:f32[] n:f32[] o:f32[]. let - p:f32[] = add n l - in (p,) } + { lambda ; g:f32[] h:f32[] i:f32[]. let j:f32[] = sub i g in (j,) } + { lambda ; k:f32[] l:f32[] m:f32[]. let n:f32[] = add l k in (n,) } ) - ] e a a c d + ] e a c d in (f,) }""" jaxpr = api.make_jaxpr(f)(jnp.float32(3.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) From 85c6b6a128e19fb9358b84fa6a272e314c89e4e6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 7 Mar 2025 05:18:24 -0800 Subject: [PATCH 055/100] [Mosaic GPU] Add support for tiling stores to refs using small tiling The difficulty here is that our register tiling is based on the (64, 8) shape, while the memory tiling is now (8, swizzle // bytewidth). Before, we would assume that each register tile fits neatly within a single memory tile, but now it is obviously not the case. Luckily, it wasn't too hard to add. PiperOrigin-RevId: 734517000 --- .../mosaic/gpu/fragmented_array.py | 149 ++++++++++++++---- jax/experimental/mosaic/gpu/utils.py | 1 + tests/mosaic/gpu_test.py | 7 +- 3 files changed, 125 insertions(+), 32 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index bba93b5de..d5400a641 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -22,6 +22,7 @@ import math from collections.abc import Callable from typing import Iterable, Protocol, Sequence, TypeVar +import itertools import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -115,6 +116,65 @@ class Tiling: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_nested_shape_strides( + self, + shape: tuple[tuple[int, ...], ...], + strides: tuple[tuple[int, ...], ...], + ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: + """A fused version of `tile_shape` and `tile_strides` for nested shapes. + + By nested shape we mean that each logical dimension (i.e. each element of + shape/strides) is actually composed out of multiple physical dimensions. + For example, a row-major array of logical shape (128, 128) that is tiled + into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each + dim is split into two sub-dims) and nested strides of + ((2 * 64 * 64, 64), (64 * 64, 1)). + """ + if len(shape) != len(strides): + raise ValueError( + f"Shape {shape} and strides {strides} must have the same length" + ) + def fail_if(cond, shape=shape): # Capture shape now. + if cond: + raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") + for tile in self.tiles: + fail_if(len(tile) > len(shape)) + untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):] + untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):] + major_dim_shapes, major_dim_strides = [], [] + minor_dim_shapes, minor_dim_strides = [], [] + for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides): + major_dim_shape_rev, major_dim_stride_rev = [], [] + minor_dim_shape_rev, minor_dim_stride_rev = [], [] + for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True): + if d < t: # We will need to tile more dims + fail_if(t % d != 0) + t //= d + minor_dim_shape_rev.append(d) + minor_dim_stride_rev.append(s) + elif t != 1: # Last dim to tile! + fail_if(d % t != 0) + minor_dim_shape_rev.append(t) + minor_dim_stride_rev.append(s) + if d != t: # No need to insert singleton dims. + major_dim_shape_rev.append(d // t) + major_dim_stride_rev.append(s * t) + t = 1 + else: # Done tiling! + major_dim_shape_rev.append(d) + major_dim_stride_rev.append(s) + fail_if(t != 1) + major_dim_shapes.append(major_dim_shape_rev[::-1]) + minor_dim_shapes.append(minor_dim_shape_rev[::-1]) + major_dim_strides.append(major_dim_stride_rev[::-1]) + minor_dim_strides.append(minor_dim_stride_rev[::-1]) + shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) + strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) + return ( + tuple(tuple(d) if d else (1,) for d in shape), + tuple(tuple(d) if d else (1,) for d in strides), + ) + def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: for tile in self.tiles: untiled, tiled = indices[:-len(tile)], indices[-len(tile):] @@ -214,7 +274,7 @@ class TiledLayout: index = ir.IndexType.get() contig_strides = utils.get_contiguous_strides(shape) tile_strides = self.tiling.tile_strides(contig_strides) - dyn_tile_strides = [c(s, i32) for s in tile_strides] + dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]] warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides) lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides) dyn_offset = arith.addi(warp_offset, lane_offset) @@ -246,7 +306,12 @@ class TiledLayout: so the tiled shape always ends with this suffix, no matter what array shape it's applied to. """ - return self.tiling.tile_shape(self.base_tile_shape) + base_tile_shape = self.base_tile_shape + return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):] + + @functools.cached_property + def tiled_tiling_rank(self) -> int: + return len(self.tiled_tiling_shape) @property def vector_length(self) -> int: @@ -292,16 +357,12 @@ class TiledLayout: def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) - tiled_shape = tuple( - d if i == self.warp_dim else 1 - for i, d in enumerate_negative(self.tiled_tiling_shape) - ) - assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP + tiled_shape_rank = len(self.tiled_tiling_shape) warp_idx = arith.remui( arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), c(WARPS_IN_WARPGROUP, i32), ) - indices = [arith.constant(i32, 0)] * len(tiled_shape) + indices = [arith.constant(i32, 0)] * tiled_shape_rank indices[self.warp_dim] = warp_idx return tuple(indices) @@ -1550,7 +1611,9 @@ class FragmentedArray: raise NotImplementedError(f"Unexpected ref space {ref_space}") ptr = utils.memref_ptr(ref, memory_space=memory_space) # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [arith.constant(i32, s) for s in strides] + dyn_strides = [ + arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] + ] warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) dyn_offset = arith.addi(warp_offset, lane_offset) @@ -1752,31 +1815,42 @@ class FragmentedArray: ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type if ref_ty.rank % 2: - raise ValueError("Tiled refence must have even rank") - ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:]) + raise ValueError("Tiled reference must have even rank") + ref_logical_rank = ref_ty.rank // 2 + ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:]) ref_tiling = Tiling((ref_tiling_shape,)) ref_strides, _ = ref_ty.get_strides_and_offset() if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: raise ValueError() - if len(layout.base_tile_shape) > len(ref_tiling_shape): - raise ValueError("Memory tiling must be a multiple of the register tiling") - ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):] - if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)): - raise ValueError( - f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the" - f" register tiling ({layout.base_tile_shape})" - ) + nested_ref_shape = tuple( + (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) + for i in range(ref_logical_rank) + ) + nested_ref_strides = tuple( + (ref_strides[i], ref_strides[i + ref_logical_rank]) + for i in range(ref_logical_rank) + ) + tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( + nested_ref_shape, nested_ref_strides + ) - elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides))) - tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape))) + # We could technically handle this case, but it would be quite complicated. + # If tiling dimensions would have to be expanded into multiple, we'd have to + # adjust the dimension indices in layouts, including expanding some of them + # into multiple indices. Note that for non-tiling dims, we allow the shape + # to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx. + if any( + len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :] + ): + raise NotImplementedError("Memory and register tiling incompatible") + tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) + elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] lane_shape = [tiled_shape[d] for d in layout.lane_dims] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): tiled_shape[d] = 1 - full_tiling = Tiling((ref_tiling_shape, *tiling.tiles)) - full_layout = dataclasses.replace(layout, tiling=full_tiling) element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -1817,9 +1891,11 @@ class FragmentedArray: ) # All offsets are in units of transfer_dtype. - dyn_tiled_strides = [c(s) for s in transfer_tiled_strides] - lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides) - warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides) + dyn_tiled_strides = [ + c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :] + ] + lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) + warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): raise ValueError("Tiled stores can be performed into SMEM") @@ -1847,10 +1923,23 @@ class FragmentedArray: reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype) offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle)) reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype) - reg_idxs = [ - tiling.tile_indices(full_tiling.untile_indices(idx)) - for idx in indices.tolist() - ] + # Here, registers are organized in an array with shape obtained by tiling + # the logical data bounds. But, the reference was tiled and so each + # logical tiled dimension can map to multiple dims in tiled_shape. + # The transform below maps this potentially higher-rank representation + # back to the lower-rank representation used by the register arrays. + def mem_idx_to_reg_idx(idx): + reg_tiled_idx = [] + base_idx = 0 + for dim_shape in tiled_nested_shape[:ref_logical_rank]: + dim_strides = utils.get_contiguous_strides(dim_shape) + dim_idxs = idx[base_idx:base_idx + len(dim_shape)] + base_idx += len(dim_shape) + reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides))) + # We should have fixed up all but the tiling dims. + assert base_idx == len(idx) - layout.tiled_tiling_rank + return (*reg_tiled_idx, *idx[base_idx:]) + reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 191f0fcfc..285185f4b 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1172,4 +1172,5 @@ def getelementptr( def dyn_dot(x, y): + assert len(x) == len(y) return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 3b759a3b7..c0e61f5b6 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2193,15 +2193,18 @@ class LayoutTest(TestCase): dtype=[jnp.int8, jnp.int16, jnp.int32], swizzle=[16, 32, 64, 128], num_col_tiles=[1, 2, 3], + row_tiling=[8, 64], ) - def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles): + def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles, row_tiling): + if (not load_tiled or not load_tiled) and row_tiling != 64: + self.skipTest("Old code path does not support this") mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) col_tiling = swizzle // bw if col_tiling % 8: self.skipTest("WGMMA layout requires col_tiling % 8 == 0") m, n = 128, col_tiling * num_col_tiles - tiling = (64, col_tiling) + tiling = (row_tiling, col_tiling) tiled_layout = fa._tiled_wgmma_layout((m, n)) load_layout = tiled_layout if load_tiled else mgpu.TILED_LAYOUT_WGMMA store_layout = tiled_layout if store_tiled else mgpu.TILED_LAYOUT_WGMMA From b7ecfdfd9568a3d5f53d7fe8745aea51f639c100 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 26 Feb 2025 15:01:57 -0500 Subject: [PATCH 056/100] Update ad.backward_pass to support non-linear functions of constants. --- jax/_src/interpreters/ad.py | 24 +++++++++++++++++++++++- tests/api_test.py | 14 ++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 37ad40d22..c8200fdf9 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -330,12 +330,34 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, # forces primal_in to contain UndefinedPrimals for tangent values! map(write_primal, jaxpr.invars, primals_in) + # Start with a forward pass to evaluate any side-effect-free JaxprEqns that + # only operate on primals. This is required to support primitives with + # linearization rules that include computations on the residuals. + lin_eqns = [] + for eqn in jaxpr.eqns: + # TODO (dfm): The effects check is probably stricter than necessary. + # Consider adding an allowlist of effects here. + if jaxpr.effects or any( + type(x) is not Literal and x not in primal_env for x in eqn.invars): + lin_eqns.append(eqn) + continue + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack), eqn.ctx.manager: + ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) + if eqn.primitive.multiple_results: + map(write_primal, eqn.outvars, ans) + else: + write_primal(eqn.outvars[0], ans) + ct_env: dict[Any, Any] = {} ctx = (source_info_util.transform_name_stack('transpose') if transform_stack else contextlib.nullcontext()) with ctx: map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in jaxpr.eqns[::-1]: + for eqn in lin_eqns[::-1]: if eqn.primitive.ref_primitive: if eqn.primitive is core.mutable_array_p: val_var, = eqn.invars diff --git a/tests/api_test.py b/tests/api_test.py index ff729c03d..35e92f748 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -57,6 +57,7 @@ from jax._src import xla_bridge from jax._src import debugging from jax._src import pjit as pjit_lib from jax._src.ad_checkpoint import saved_residuals +from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled @@ -4732,6 +4733,19 @@ class APITest(jtu.JaxTestCase): check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + def test_deferred_primal_with_direct_linearize(self): + def my_sin_lin(nzs, x): + nz, = nzs + return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + + my_sin_p = core.Primitive("my_sin_p") + my_sin_p.def_impl(lax.sin) + my_sin_p.def_abstract_eval(lambda x: x) + ad_internal.primitive_linearizations[my_sin_p] = my_sin_lin + + with config.use_direct_linearize(True): + jax.grad(my_sin_p.bind)(1.0) # doesn't crash + class RematTest(jtu.JaxTestCase): From f8b98993b85350447df8c1e026783a862d6bbe00 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 7 Mar 2025 07:00:52 -0800 Subject: [PATCH 057/100] Add a divisibility check so that we make sure that sharding evenly divides the shape (until this restriction is lifted) to make sure we don't create bad shardings. Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message) Also make some formatting changes in scan lowering to make it easier to debug. PiperOrigin-RevId: 734542862 --- jax/_src/core.py | 19 +++++++++++++-- jax/_src/interpreters/mlir.py | 2 +- jax/_src/lax/control_flow/loops.py | 37 ++++++++++++++---------------- jax/_src/lax/slicing.py | 5 ++-- tests/pjit_test.py | 8 +++++++ 5 files changed, 46 insertions(+), 25 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 17b842359..767c61089 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1852,9 +1852,22 @@ def _maybe_modify_sharding(sharding, ndim): out = _make_lengths_same(out, ndim) return out +def _check_divisibility(sharding, shape): + mesh = sharding.mesh + for dim, (spec, sh) in enumerate(zip(sharding.spec, shape)): + if spec is None: + continue + spec = spec if isinstance(spec, tuple) else (spec,) + size = math.prod(mesh.shape[s] for s in spec) + _, remainder = divmod(sh, size) + if remainder != 0: + raise ValueError( + f"Sharding spec {spec} implies that array axis {dim} is partitioned" + f" {size} times, but does not evenly divide the dimension size {sh}." + f" Got shape: {shape} and sharding {sharding}") @cache(max_size=4096, trace_context_in_key=True) -def get_sharding(sharding, ndim): +def get_sharding(sharding, shape): """Modifies and checks the sharding. Some modifications/checks include: @@ -1863,6 +1876,7 @@ def get_sharding(sharding, ndim): * Checking for len(spec)-ndim match * Checking if the mesh is an AbstractMesh. """ + ndim = len(shape) if sharding is None: return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim)) @@ -1874,6 +1888,7 @@ def get_sharding(sharding, ndim): if not isinstance(out_s.mesh, mesh_lib.AbstractMesh): raise ValueError("Mesh of an aval must be an AbstractMesh. " f"Got {out_s.mesh} of type {type(out_s.mesh)}") + _check_divisibility(out_s, shape) return out_s @@ -1885,7 +1900,7 @@ class ShapedArray(UnshapedArray): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type - self.sharding = get_sharding(sharding, len(self.shape)) + self.sharding = get_sharding(sharding, self.shape) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7c10c7b8d..e5dabd146 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -939,7 +939,7 @@ def sharded_aval(aval: core.AbstractValue, return aval if not isinstance(aval, (core.ShapedArray, core.DShapedArray)): raise NotImplementedError - return aval.update(sharding.shard_shape(aval.shape)) # type: ignore + return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore def eval_dynamic_shape(ctx: LoweringRuleContext, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 4c0bc6e6f..caabace86 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -443,6 +443,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) num_trips, remainder = divmod(length, unroll) + if unroll != 1 and num_trips == 1 and remainder == 0: # In that case, we explicitly want to fully unroll the loop. Put everything # into the remainder block and avoid lowering to a while loop. @@ -459,26 +460,6 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) - def cond_fun(while_carry): - i, _, _ = while_carry - return i < num_trips - def body_fun(while_carry): - i_, carry, yss = while_carry - i = num_trips - i_ - 1 if reverse else i_ - xs = [ - slicing.dynamic_index_in_dim( - xs, i, keepdims=False, allow_negative_indices=False - ) - for xs in xss - ] - carry, ys = inner(unroll, carry, xs) - yss = [ - slicing.dynamic_update_index_in_dim( - ys, upd, i, 0, allow_negative_indices=False - ) - for ys, upd in zip(yss, ys) - ] - return i_ + 1, carry, yss def inner(n, carry, xs): ys = [] if unroll == 1: @@ -493,6 +474,22 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, ys = list(reversed(ys)) if reverse else ys return carry, _map(_stack, zip(*ys)) + def body_fun(while_carry): + i_, carry, yss = while_carry + i = num_trips - i_ - 1 if reverse else i_ + xs = [slicing.dynamic_index_in_dim(xs, i, keepdims=False, + allow_negative_indices=False) + for xs in xss] + carry, ys = inner(unroll, carry, xs) + yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0, + allow_negative_indices=False) + for y, upd in zip(yss, ys)] + return i_ + 1, carry, yss + + def cond_fun(while_carry): + i, _, _ = while_carry + return i < num_trips + if num_trips: i = lax._const(num_trips, 0) _, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index bb1647d7e..c26de99c7 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1608,8 +1608,9 @@ def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: raise TypeError( "dynamic_update_slice update sharding must be equal to operand" - f" sharding, got update sharding {update.sharding} for operand sharding" - f" {operand.sharding}.") + " sharding, got update sharding" + f" {update.str_short(mesh_axis_types=True)} for operand sharding" + f" {operand.str_short(mesh_axis_types=True)}.") return operand.sharding def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 4b267196e..b0a59027f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7072,6 +7072,14 @@ class ShardingInTypesTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): auto_axes(f, out_shardings=s)(arr) + def test_divisbility_aval_error(self): + abstract_mesh = mesh_lib.AbstractMesh( + (('x', 2),), axis_types={AxisTypes.Explicit: 'x'}) + s = NamedSharding(abstract_mesh, P('x')) + with self.assertRaisesRegex( + ValueError, 'does not evenly divide the dimension size'): + core.ShapedArray((5, 2), jnp.int32, sharding=s) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 65462fe684b3e6f3efa8f153870142907925b99a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 7 Mar 2025 07:40:45 -0800 Subject: [PATCH 058/100] [Mosaic GPU] Add a new layout to help with transposing WGMMA results PiperOrigin-RevId: 734553651 --- .../mosaic/gpu/fragmented_array.py | 32 +++++++++++++++ jax/experimental/mosaic/gpu/utils.py | 27 ++++++++++++ tests/mosaic/gpu_test.py | 41 +++++++++++++++++++ 3 files changed, 100 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index d5400a641..31f918044 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -540,6 +540,12 @@ TILED_LAYOUT_WGMMA = TiledLayout( lane_dims=(-4, -3), vector_dim=-1, ) +WGMMA_TRANSPOSED_LAYOUT = TiledLayout( + Tiling(((64, 8), (16, 8), (8, 8), (2, 8), (2, 2), (2, 1))), + warp_dim=-12, + lane_dims=(-8, -3, -5), + vector_dim=-2, +) @jax.tree_util.register_pytree_node_class @dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) @@ -708,9 +714,35 @@ class FragmentedArray: At the moment, only conversions from ``WGSplatFragLayout`` are supported. """ + i32 = ir.IntegerType.get_signless(32) if self.layout == new_layout: return self shape = self.shape + if ( + self.layout == TILED_LAYOUT_WGMMA + and new_layout == WGMMA_TRANSPOSED_LAYOUT + and utils.bitwidth(self.mlir_dtype) == 16 + ): + is_even_row = arith.cmpi( + arith.CmpIPredicate.eq, + arith.remui(arith.divui(utils.thread_idx(), c(4, i32)), c(2, i32)), + c(0, i32), + ) + perm = arith.select(is_even_row, c(0x5410, i32), c(0x3276, i32)) + new_regs = [] + for reg in self.registers.flat: + reg_ty = reg.type + reg = utils.bitcast(reg, i32) + reg_shfl = utils.shfl_bfly(reg, 4) + new_reg = llvm.inline_asm( + i32, [reg, reg_shfl, perm], "prmt.b32 $0, $1, $2, $3;", "=r,r,r,r" + ) + new_regs.append(utils.bitcast(new_reg, reg_ty)) + return FragmentedArray( + _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)), + _layout=new_layout, + _is_signed=self.is_signed, + ) if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0: tiled_layout = _tiled_wgmma_layout(shape) if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or ( diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 285185f4b..f90f7ff08 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1174,3 +1174,30 @@ def getelementptr( def dyn_dot(x, y): assert len(x) == len(y) return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) + + +def shfl_bfly(x: ir.Value, distance: int | ir.Value): + i32 = ir.IntegerType.get_signless(32) + if isinstance(distance, int): + distance = c(distance, i32) + assert x.type == i32 + return nvvm.shfl_sync( + i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, + ) + + +def bitcast(x: ir.Value, new_type: ir.Type): + if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): + new_type = ir.IntegerType(new_type) + x_ty = ir.VectorType(x.type) + assert new_type.width == bitwidth(x_ty.element_type) * math.prod(x_ty.shape) + i0 = arith.ConstantOp.create_index(0) + return vector.extractelement( + vector.bitcast(ir.VectorType.get((1,), new_type), x), position=i0 + ) + if ir.IntegerType.isinstance(x.type) and ir.VectorType.isinstance(new_type): + new_type = ir.VectorType(new_type) + x_ty = ir.IntegerType(x.type) + assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) + return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) + raise ValueError(f"Can't bitcast {x.type} to {new_type}") diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index c0e61f5b6..a98e3f9bb 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2287,6 +2287,47 @@ class LayoutTest(TestCase): ) np.testing.assert_array_equal(f(x), x) + @parameterized.product( + dtype=[jnp.int16], # TODO(apaszke): More dtypes + # TODO(apaszke): swizzle=64 <- not implemented in transfer_tiled right now + swizzle=[16, 32, 128], + ) + def test_transpose_tiled(self, dtype, swizzle): + mlir_dtype = utils.dtype_to_ir_type(dtype) + bw = bytewidth(mlir_dtype) + col_tiling = swizzle // bw + m, n = 128, 256 + tiling = (8, col_tiling) + transpose_layout = fa.WGMMA_TRANSPOSED_LAYOUT + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_tiled( + smem_in, swizzle=swizzle, is_signed=True, layout=fa.TILED_LAYOUT_WGMMA + ) + smem_out_t = memref_transpose(smem_out, (1, 0, 3, 2)) + t.to_layout(transpose_layout).store_tiled(smem_out_t, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) + ctx.await_async_copy(0) + x = ( + np.arange(m * n, dtype=dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + y_ref = ( + np.arange(m * n, dtype=dtype) + .reshape(m, n) + .T.reshape(n // tiling[0], tiling[0], m // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, y_ref, [x, y_ref, mgpu.TMABarrier()], + )(x) + np.testing.assert_array_equal(y, y_ref) + @dataclasses.dataclass(frozen=True) class Tile: From 928caf83ee9120183b3bd49c276af6399fa4a070 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 7 Mar 2025 07:40:50 -0800 Subject: [PATCH 059/100] [pallas:mosaic_gpu] `copy_smem_to_gmem` now allows skipping `cp.async.commit_group` This feature is necessary to fix the SMEM->GMEM waiting behavior in `emit_pipeline`, which used a pessimistic condition prior to this change, since every copy was its own commit group. PiperOrigin-RevId: 734553668 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 8 ++++ jax/_src/pallas/mosaic_gpu/primitives.py | 41 ++++++++++++++++++- .../mosaic/gpu/dialect_lowering.py | 1 + jax/experimental/mosaic/gpu/launch_context.py | 7 ++-- jax/experimental/pallas/mosaic_gpu.py | 1 + jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 3 +- 6 files changed, 55 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index c1ecd47bd..c2080b9c6 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -91,6 +91,7 @@ class BufferedRef: self.smem_ref.at[slot], # pytype: disable=unsupported-operands self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands predicate=predicate, + commit_group=False, ) @@ -299,6 +300,8 @@ def emit_pipeline( predicate=lax.bitwise_or(slices_changed, is_last_step), ) + gpu_primitives.commit_smem_to_gmem_group() + fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = lax.rem(fetch_step, max_concurrent_steps) @@ -344,6 +347,8 @@ def emit_pipeline( if bref.is_index_invariant: bref.copy_out(last_slot, last_indices, predicate=None) + gpu_primitives.commit_smem_to_gmem_group() + # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) @@ -578,6 +583,7 @@ def emit_pipeline_warp_specialized( bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), indices, predicate=slices_changed) + gpu_primitives.commit_smem_to_gmem_group() next_indices = _inc_grid_by_1(indices, grid) return (next_indices, new_store_slices, next_body_carry) init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) @@ -619,6 +625,8 @@ def emit_pipeline_warp_specialized( if bref.is_index_invariant: bref.copy_out(last_slot, last_indices, predicate=None) + gpu_primitives.commit_smem_to_gmem_group() + # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 547a20451..428e2925f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -86,6 +86,7 @@ def _copy_smem_to_gmem_lowering( src_transforms_treedef, dst_transforms_treedef, has_user_predicate, + commit_group, ): predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: @@ -106,6 +107,7 @@ def _copy_smem_to_gmem_lowering( src_ref=src, dst_ref=dst, predicate=predicate, + arrive=commit_group, **copy_params, ) return () @@ -119,7 +121,12 @@ def _copy_smem_to_gmem_lowering( assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") mgpu.dialect.async_store( - src, dst, indices, slice_lengths, predicate=predicate + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] ) return () @@ -174,7 +181,11 @@ def _extract_smem_copy_params(transforms): def copy_smem_to_gmem( - src: _Ref, dst: _Ref, predicate: jax.Array | None = None + src: _Ref, + dst: _Ref, + predicate: jax.Array | None = None, + *, + commit_group: bool = True, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -183,6 +194,9 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. + commit_group: If ``True``, this and any previously uncommitted copies + are committed to a group and can be awaited jointly via + :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` @@ -209,6 +223,7 @@ def copy_smem_to_gmem( src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, + commit_group=commit_group, ) return None @@ -475,6 +490,28 @@ def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only) +commit_group_p = jax_core.Primitive("commit_group") +commit_group_p.multiple_results = True + + +@commit_group_p.def_effectful_abstract_eval +def _commit_group_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +def _commit_group_lowering(ctx: lowering.LoweringRuleContext): + del ctx # Unused. + nvvm_dialect.cp_async_bulk_commit_group() + return () + + +def commit_smem_to_gmem_group() -> None: + """Commits all issued but uncommited SMEM->GMEM copies to a group.""" + commit_group_p.bind() + + # WGMMA on an accumulator reference wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index d6156278c..40c0b2fce 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -424,6 +424,7 @@ def _mgpu_async_store_op_lowering_rule( gmem_transform=transforms, uniform=True, predicate=ctx.single_thread_per_warpgroup_predicate, + arrive=store_op.commit_group, ) return [] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 7a314a2ea..ce432f26d 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -430,8 +430,8 @@ class LaunchContext: gmem_ref, smem_ref = dst_ref, src_ref if barrier is not None: raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is not None: - raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + if arrive is None: + arrive = True # Commit this copy to the async group by default else: raise ValueError("Only SMEM <-> GMEM copies supported") # TODO(apaszke): This is a very approximate check. Improve it! @@ -683,7 +683,8 @@ class LaunchContext: nvvm.cp_async_bulk_tensor_global_shared_cta( tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate ) - nvvm.cp_async_bulk_commit_group() + if arrive: + nvvm.cp_async_bulk_commit_group() def await_async_copy( self, allow_groups: int, await_read_only: bool = False diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index e13dd11ed..631b4f720 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -35,6 +35,7 @@ from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arri from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem +from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 964039d82..dbc829832 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -347,7 +347,8 @@ def MosaicGPU_AsyncStoreOp : Op:$commit_group ); let assemblyFormat = [{ From 402389290c8abb1c3a5fa6acfec6d0285071227f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 7 Mar 2025 07:58:57 -0800 Subject: [PATCH 060/100] [Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+ PiperOrigin-RevId: 734558364 --- tests/pallas/ops_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0cabb4bfe..20ff22e7e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -648,10 +648,8 @@ class OpsTest(PallasBaseTest): } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: self.skipTest("Not supported on this hardware") - if not jtu.if_cloud_tpu_at_least(2025, 3, 8): - self.skipTest("Test requires libtpu from 2025/3/8 or later") - if to_dtype not in {"float32", "int32", "uint32"}: - self.skipTest("Only fp8 to x32 cast is supported") + if not jtu.if_cloud_tpu_at_least(2025, 3, 9): + self.skipTest("Test requires libtpu from 2025/3/9 or later") from_int = np.issubdtype(np.dtype(from_dtype), np.integer) to_int = np.issubdtype(np.dtype(to_dtype), np.integer) @@ -693,7 +691,9 @@ class OpsTest(PallasBaseTest): if randomize: x = random.randint(random.key(234), (16, 16), 0, 1, jnp.int32) != 0 else: - x = jnp.asarray([[False, True], [True, False]], dtype="bool") + x = jnp.tile( + jnp.asarray([[False, True], [True, False]], dtype="bool"), (8, 8) + ) assert x.dtype == jnp.dtype(from_dtype) # XLA does not specify the float->int conversion result for NaNs. if jnp.issubdtype(from_dtype, jnp.floating): From 1bef8b61af7ec7722fe9270a61a1854a18a49838 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 7 Mar 2025 08:18:33 -0800 Subject: [PATCH 061/100] [Mosaic GPU] Add a better explanation for the transposed layout Thanks to @bchetioui for the discussion! PiperOrigin-RevId: 734564672 --- .../mosaic/gpu/fragmented_array.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 31f918044..9b241ef7b 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -540,10 +540,31 @@ TILED_LAYOUT_WGMMA = TiledLayout( lane_dims=(-4, -3), vector_dim=-1, ) +# This tiled layout is similar to the one above. Above, each warp stores a 8x8 +# submatrix in the following way (we only show the first 4 rows for brevity): +# +# 0 0 1 1 2 2 3 3 +# 4 4 5 5 6 6 7 7 +# 8 8 9 9 10 10 11 11 +# 12 12 13 13 14 14 15 15 +# ... +# +# This tiled layout stores the same 8x8 submatrix in the following way: +# +# 0 4 1 5 2 6 3 7 +# 0 4 1 5 2 6 3 7 +# 8 12 9 13 10 14 11 15 +# 8 12 9 13 10 14 11 15 +# ... +# +# You can see that we have taken 2x2 submatrices from the above layout and +# transposed them. The assigment of lanes to elements is such that in both +# layouts the same two lanes map to a single 2x2 submatrix, making the transpose +# very cheap (one shuffle and permute suffices to change between those layouts). WGMMA_TRANSPOSED_LAYOUT = TiledLayout( - Tiling(((64, 8), (16, 8), (8, 8), (2, 8), (2, 2), (2, 1))), - warp_dim=-12, - lane_dims=(-8, -3, -5), + Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))), + warp_dim=-10, + lane_dims=(-6, -3, -5), vector_dim=-2, ) From eeccc67c0bb2906d87019e1435a41bb240c1c37a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 7 Mar 2025 08:57:46 -0800 Subject: [PATCH 062/100] [mgpu] Debug print arrays. PiperOrigin-RevId: 734576543 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +----- jax/experimental/mosaic/gpu/fragmented_array.py | 7 +++++++ tests/pallas/mosaic_gpu_test.py | 8 ++++---- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5454b826f..5da89f190 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1491,11 +1491,7 @@ def _debug_print_lowering_rule( ) elif len(ctx.avals_in) == 1: [arg] = args - @arg.foreach - def _(val, idx): - idx_fmt = ", ".join(["{}"] * len(idx)) - fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}") - mgpu.debug_print(fmt_str, *idx, val, uniform=False) + arg.debug_print(fmt) else: raise NotImplementedError( "debug_print only supports printing of scalar values, or a single array" diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 9b241ef7b..b3f335155 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1555,6 +1555,13 @@ class FragmentedArray: if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) + def debug_print(self, fmt: str): + idx_fmt = ", ".join(["{}"] * len(self.shape)) + @self.foreach + def _(val, idx): + fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") + utils.debug_print(fmt_str, *idx, val, uniform=False) + def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 22337ae68..756e5d220 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -714,7 +714,7 @@ class PallasCallTest(PallasTest): shape = (128, 64) size = math.prod(shape) def kernel(x_ref, o_ref): - pl.debug_print("{}", x_ref[...]) + pl.debug_print("prefix {}", x_ref[...]) spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) x = jnp.arange(size, dtype=jnp.float32).reshape(shape) f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) @@ -723,8 +723,8 @@ class PallasCallTest(PallasTest): jax.block_until_ready(f(x)) output = get_output() - results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output) - self.assertLen(results, size) + results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) + self.assertLen(results, size, output) for i, j, v in results: i, j, v = map(int, (i, j, v)) self.assertEqual(v, i * shape[1] + j) @@ -774,7 +774,7 @@ class PallasCallTest(PallasTest): with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) + self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) def test_load_scalar(self): @functools.partial( From 9f37b5197f4511d5adb3fc167f60ca478553fd71 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 7 Mar 2025 09:46:40 -0800 Subject: [PATCH 063/100] [sharding_in_types] Fix a bug where `empty_array` in scan was created with the wrong spec when `unroll > 1`. PiperOrigin-RevId: 734591110 --- jax/_src/lax/control_flow/loops.py | 6 +++--- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index caabace86..5179da751 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -450,7 +450,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, num_trips, remainder = 0, length if unroll == 1: xss = xs_ - yss = _map(partial(_empty_array, (length,), None), y_avals) + yss = _map(partial(_empty_array, (length,), (None,)), y_avals) else: if remainder: if not reverse: @@ -458,7 +458,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) + yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals) def inner(n, carry, xs): ys = [] @@ -509,7 +509,7 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec)) + sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), out_sharding=sharding) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b0a59027f..720a77410 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7080,6 +7080,20 @@ class ShardingInTypesTest(jtu.JaxTestCase): ValueError, 'does not evenly divide the dimension size'): core.ShapedArray((5, 2), jnp.int32, sharding=s) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_scan_unroll(self, mesh): + np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) + carry = jnp.ones((8,), dtype=jnp.float32) + + @jax.jit + def f(carry, xs): + def body(carry, x): + return carry + x, x + return jax.lax.scan(body, carry, xs, unroll=2) + + f(carry, arr) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From ccf72782925f930dd3ae2cc8ccf4f6da7b525cf2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 7 Mar 2025 09:49:07 -0800 Subject: [PATCH 064/100] Add the len(arg) to the error message for static_argnums Helps reduce the confusion on what is considered an argnum. Ideally there should be static_argkwg PiperOrigin-RevId: 734591856 --- jax/_src/ad_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f1c0078cd..1aa9f17bc 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -355,7 +355,7 @@ def _remat_static_argnums(fun, static_argnums, args): raise ValueError("the `static_argnums` argument to `jax.checkpoint` / " "`jax.remat` can only take integer values greater than or " "equal to `-len(args)` and less than `len(args)`, but got " - f"{static_argnums}") + f"{static_argnums}, while `len(args)` = {len(args)}") if not static_argnums: return fun, args From 178278863da1aa88c1cf9b97db8ba7a13e371524 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 7 Mar 2025 10:49:09 -0800 Subject: [PATCH 065/100] [JAX] Fix api_benchmark broken by https://github.com/jax-ml/jax/pull/26569 `pjit_check_aval_sharding` expects `names: Sequence[str]`. PiperOrigin-RevId: 734614264 --- benchmarks/api_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 5e2555769..c3be27f4a 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -475,7 +475,7 @@ def bench_pjit_check_aval_sharding(state): aval = jax.core.ShapedArray((8, 2), np.int32) while state: - pjit_check_aval_sharding([s] * 100, [aval] * 100, None, 'benchmark', False) + pjit_check_aval_sharding([s] * 100, [aval] * 100, [''] * 100, 'benchmark', False) @google_benchmark.register From 0e30a3ace9d6a0cd0e4946179c89be4fa36aaf16 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 7 Mar 2025 01:03:22 +0000 Subject: [PATCH 066/100] [mutable-arrays] read values should have the same explicit sharding as ref fixes #26936 --- jax/_src/state/indexing.py | 23 +++++++++++++++++++++++ jax/_src/state/primitives.py | 14 ++++++++++---- jax/_src/state/types.py | 16 ++++++++++++++++ tests/mutable_array_test.py | 17 +++++++++++++++++ 4 files changed, 66 insertions(+), 4 deletions(-) diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 7abaa3185..4b627c1cd 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -259,3 +259,26 @@ class NDIndexer: def transform_dtype(self, dtype): return dtype + + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + # If there are explicit axes, we don't support changing the shape, so we + # don't support int indexers and instead require all slices. + if (self.int_indexer_shape or + not all(isinstance(idx, Slice) for idx in self.indices)): + raise TypeError("sharded ref (array reference) can only be indexed by " + "slices, not integers") + # Moreover, only allow trivial slice(None) slices on explicitly sharded + # axes. Then the sharding stays the same. + _, slice_indexers, _ = unpack_ndindexer(self) + for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)): + if s is None: continue + if not (type(sl.start) is int and sl.start == 0 and + type(sl.size) is int and sl.size == d and + type(sl.stride) is int and sl.stride == 1): + raise ValueError("sharded ref (array reference) can only be sliced " + f"along unsharded axes, but ref of shape {self.shape} " + f"was sliced on axis {i}, which is sharded like {s}") + return sharding diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 2a8b8bcc9..51f9aa3e6 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -206,6 +206,13 @@ def _dtype_after_transforming( return dtype +def _sharding_after_transforming(sharding, transforms): + for transform in transforms: + sharding = transform.transform_sharding(sharding) + assert sharding is not None + return sharding + + def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): transforms = tree_util.tree_unflatten(tree, args) @@ -214,10 +221,9 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) - # TODO(yashkatariya): Transform the sharding too instead of setting it to - # None. - out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype, - sharding=core.get_cur_mesh_sharding()) + out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms) + out_aval = ref_aval.inner_aval.update( + shape=out_shape, dtype=out_dtype, sharding=out_sharding) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 5efc7f1e0..057242f4c 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -119,6 +119,12 @@ class RefBitcaster: del dtype # Unused return self.dtype + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -166,6 +172,12 @@ class RefReshaper: del dtype # Unused return self.dtype + def transform_sharding(self, sharding): + # If there are no explicit axes, do nothing. + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + class Transform(Protocol): @@ -189,6 +201,10 @@ class Transform(Protocol): """ return dtype + def transform_sharding(self, sharding): + if all(p is None for p in sharding.spec): return sharding # no explicit axes + raise NotImplementedError + @dataclasses.dataclass class RefIndexer: diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 9a6c5c167..e32e27d95 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -254,6 +254,23 @@ class MutableArrayTest(jtu.JaxTestCase): self.assertEqual(s, a.sharding) self.assertEqual(s, y.sharding) + def test_explicit_sharding_after_indexing(self): + # https://github.com/jax-ml/jax/issues/26936 + mesh = jax.make_mesh((1, 1), ('x', 'y'), explicit_axes=('x', 'y')) + sharding = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(x_ref): + self.assertEqual(core.get_ty(x_ref).sharding.spec, + core.get_ty(x_ref[...]).sharding.spec) + y = x_ref[...] + 1 + return y + + with jax.sharding.use_mesh(mesh): + x = jnp.zeros((4, 4), jnp.int32, device=sharding) + x_ref = core.mutable_array(x) + y = f(x_ref) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From 7c2f842353c5618d3a82f82258d573fb522189c7 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 5 Feb 2025 01:41:08 +0000 Subject: [PATCH 067/100] shard_map and other fixes to direct-linearize Co-authored-by: Dougal Maclaurin --- jax/_src/api.py | 4 +- jax/_src/core.py | 4 +- jax/_src/interpreters/ad.py | 113 +++++++++++++++++--------- jax/_src/interpreters/partial_eval.py | 25 +++++- jax/_src/interpreters/pxla.py | 4 +- jax/_src/lax/lax.py | 4 +- jax/_src/mesh.py | 2 +- jax/_src/pjit.py | 12 --- jax/_src/state/primitives.py | 13 ++- jax/experimental/shard_map.py | 100 +++++++++++++---------- tests/core_test.py | 3 +- tests/mutable_array_test.py | 68 +++------------- tests/pmap_test.py | 3 + tests/shard_map_test.py | 33 ++++---- 14 files changed, 207 insertions(+), 181 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index cce82aa8b..4b14d8096 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2005,8 +2005,8 @@ def vjp( raise NotImplementedError("reduce_axes argument to vjp is deprecated") del reduce_axes check_callable(fun) - wrapped_fun = lu.wrap_init(fun, - debug_info=debug_info("vjp", fun, primals, {})) + wrapped_fun = lu.wrap_init( + fun, debug_info=debug_info("vjp", fun, primals, {})) return _vjp(wrapped_fun, *primals, has_aux=has_aux) def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): diff --git a/jax/_src/core.py b/jax/_src/core.py index 767c61089..9d8edeb8b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2507,8 +2507,8 @@ class MapPrimitive(Primitive): def get_bind_params(self, params): new_params = dict(params) jaxpr: Jaxpr = new_params.pop('call_jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, - debug_info=jaxpr.debug_info), jaxpr, ()) + subfun = lu.hashable_partial( + lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) axes = new_params.pop('out_axes') new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 37ad40d22..d951b88e4 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, - partition_list) + partition_list, subs_list2) zip = safe_zip map = safe_map @@ -91,6 +91,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, *primals, **params): with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) + tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, tangent_trace.new_arg(get_aval(p).to_tangent_aval())) @@ -104,11 +105,23 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) - residual_avals = map(get_aval, consts) if attrs_tracked: raise NotImplementedError("TODO: attrs") - _store.store((residual_avals, nzs_out, jaxpr)) - return tuple(consts) + tuple(out_primals) + which_env = [(isinstance(c, pe.DynamicJaxprTracer) and + getattr(c._trace, 'tag', None) is _tag) for c in consts] + jaxpr = pe.move_envvars(jaxpr, tuple(which_env)) + res, env = partition_list(which_env, consts) + residual_avals = map(get_aval, res) + # Which residuals are just forwarded inputs? Check object id. + id_map = {id(p): i for i, p in enumerate(primals)} + in_fwd: list[int | None] = [id_map.get(id(r)) for r in res] + # Which residuals are already primal outputs? Check object id. + id_map = {id(p): i for i, p in enumerate(out_primals)} + out_fwd: list[int | None] = [id_map.get(id(r)) for r in res] + # Prune residuals not to include forwarded primal inputs or outputs. + res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None] + _store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)) + return *res, *out_primals @lu.transformation2 def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents): @@ -157,6 +170,7 @@ def _linearize_jaxpr( primal_trace = pe.DynamicJaxprTrace(dbg) tangent_trace = pe.DynamicJaxprTrace(dbg) lin_trace = LinearizeTrace(primal_trace, tangent_trace) + tangent_trace.tag = lin_trace.tag def new_arg(trace, primal_aval, nz): primal = primal_trace.new_arg(primal_aval) @@ -197,6 +211,7 @@ def direct_linearize(traceable: lu.WrappedFun, tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) + tangent_trace.tag = linearize_trace.tag tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] tracers = [t.full_lower() for t in tracers] with (core.set_current_trace(linearize_trace, check_leaks=True), @@ -217,6 +232,10 @@ def direct_linearize(traceable: lu.WrappedFun, out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + consts = [c for c, used in zip(consts, used_consts) if used] out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else pe.PartialVal.known(zeros_like_aval(t.aval)) for t, nz in zip(out_tangents, out_nzs)] @@ -586,7 +605,7 @@ def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() tangent_aval = get_aval(tangent).strip_weak_type() - assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) + assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) @@ -641,6 +660,7 @@ class LinearizeTrace(Trace): return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in), dict(symbolic_zeros=symbolic_zeros)) + @partial(lu.wrap_init, debug_info=f_jvp.debug_info) def _f_jvp(primals, tangents): outs = f_jvp.call_wrapped(*primals, *tangents) primals_out, tangents_out = split_list(outs, [len(outs) // 2]) @@ -651,7 +671,7 @@ class LinearizeTrace(Trace): nonzeros_in = [type(t) is not Zero for t in tangents_in] primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp( _f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros, - f_jvp.debug_info, primals_in, {}) + primals_in, {}) with core.set_current_trace(self.tangent_trace): tangents_out = linearized(residuals, *tangents_in) @@ -690,53 +710,65 @@ class LinearizeTrace(Trace): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not Zero for t in tangents) - f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in, - f.debug_info) + f_primal, linearize_outs_thunk = linearize_subtrace( + f, self.tag, nzs_in, f.debug_info) if isinstance(call_primitive, core.MapPrimitive): - @as_hashable_function(closure=(linearize_outs_thunk)) + out_axes_thunk = params['out_axes_thunk'] + @as_hashable_function(closure=out_axes_thunk) def new_out_axes_thunk(): - residual_avals, _, _ = linearize_outs_thunk() - out_axes = params['out_axes_thunk']() - return (*(0 for _ in residual_avals), *out_axes) + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + out_axes = out_axes_thunk() + return (*(0 for _ in range(num_res_out)), *out_axes) primal_params = dict(params, out_axes_thunk=new_out_axes_thunk) else: primal_params = params all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), primal_params) - residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = all_primal_results[:num_residuals] - primals_out = all_primal_results[num_residuals:] + residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_primal_results[:num_res_out] + primals_out = all_primal_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) if isinstance(call_primitive, core.MapPrimitive): in_axes = params['in_axes'] out_axes = params['out_axes_thunk']() residual_avals = map(get_aval, residuals) - new_in_axes = (*(0 for _ in residual_avals), + residual_axes = [in_axes[f1] if f1 is not None else + out_axes[f2] if f2 is not None else + 0 for f1, f2 in zip(in_fwd, out_fwd)] + new_in_axes = (*residual_axes, *(None for _ in range(len(env))), *(ax for ax, nz in zip(in_axes, nzs_in) if nz)) new_out_axes = (*(ax for ax, nz in zip(out_axes, nzs_out) if nz),) # NOTE: This assumes that the output tangents being zero is a # deterministic function of which input tangents were zero. - @as_hashable_function(closure=(new_out_axes)) + @as_hashable_function(closure=new_out_axes) def new_out_axes_thunk(): return new_out_axes - params = dict(params, - in_axes=new_in_axes, - out_axes_thunk=new_out_axes_thunk) + params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) update_params = call_linearize_param_updaters.get(call_primitive) - new_params = update_params(params, residual_avals, nzs_in) if update_params else params + num_new_args = len(residuals) + len(env) + new_params = update_params(params, num_new_args, nzs_in) if update_params else params + num_residuals = len(residual_avals) + @as_hashable_function(closure=(num_residuals, lin_jaxpr)) def f_tangent(*args): - residuals = args[:num_residuals] + consts = args[:num_residuals] nz_tangents = args[num_residuals:] - return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents) + return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents) + # TODO(mattjj,dougalm): this tag is read by DynamicJaxprTrace.process_map to + # avoid round-tripping the jaxpr and thus getting grad-of-pmap cache misses. + # Remove when we replace the pmap implementation. + f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive) + thing = lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = call_primitive.bind_with_trace( - self.tangent_trace, (lu.wrap_init(f_tangent, - debug_info=lin_jaxpr.debug_info), - *residuals, *nz_tangents_in), new_params) + self.tangent_trace, + (thing, + *residuals, *env, *nz_tangents_in), new_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal) for nz, primal in zip(nzs_out, primals_out)] @@ -762,14 +794,14 @@ def fallback_linearize_rule(_prim: core.Primitive, msg = f"Differentiation rule for '{_prim}' not implemented" raise NotImplementedError(msg) debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params) - return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, - debug_jvp, primals, params) + return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp), + _prim.multiple_results, _nonzeros, False, False, + primals, params) -def linearize_from_jvp(jvp: Callable, +def linearize_from_jvp(jvp: lu.WrappedFun, multiple_results: bool, nonzeros: Sequence[bool], user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool, - debug_info: core.DebugInfo, primals, params): current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: @@ -792,13 +824,18 @@ def linearize_from_jvp(jvp: Callable, tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval) for aval, nz in zip(tangent_avals, nonzeros)) with core.set_current_trace(trace): - out_primals, out_tangents = jvp(primals, tangent_args, **params) + out_primals, out_tangents = jvp.call_wrapped(primals, tangent_args, **params) if not multiple_results: out_primals = [out_primals] out_tangents = [out_tangents] out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals] + if any(p is None for p in out_primals): + raise ValueError( + "Linearization failed to produce known values for all output primals. " + "This is typically caused by attempting to differentiate a function " + "uses an operation that does not support reverse-mode autodiff.") out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known() for t in out_tangents] @@ -806,7 +843,7 @@ def linearize_from_jvp(jvp: Callable, out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info) + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -973,9 +1010,8 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): else: consts = () all_args, in_tree_def = tree_flatten((consts, args, ct)) - fun = lu.hashable_partial(lu.wrap_init(backward_pass, - debug_info=call_jaxpr.debug_info), - call_jaxpr, False) + fun = lu.hashable_partial(lu.wrap_init( + backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) update_params = call_transpose_param_updaters.get(primitive) if update_params: @@ -1013,9 +1049,8 @@ def map_transpose(primitive: core.Primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts # TODO(necula): use the right debug_info for the backwards pass - fun = lu.hashable_partial(lu.wrap_init(backward_pass, - debug_info=call_jaxpr.debug_info), - call_jaxpr, False) + fun = lu.hashable_partial(lu.wrap_init( + backward_pass, debug_info=call_jaxpr.debug_info), call_jaxpr, False) fun, nz_arg_cts = nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6fde73705..ef8e02dda 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -46,7 +46,8 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_structure) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, - as_hashable_function, weakref_lru_cache, subs_list) + as_hashable_function, weakref_lru_cache, subs_list, + HashableFunction) map, unsafe_map = safe_map, map @@ -837,6 +838,11 @@ def tracers_to_jaxpr( # del getvar # needed to avoid cyclic-reference closure, apparently! return jaxpr, const_vals, env_vals +@weakref_lru_cache +def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr: + constvars, envvars = partition_list(which, jaxpr.constvars) + return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars]) + @weakref_lru_cache def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" @@ -1840,7 +1846,7 @@ def _inline_literals( class DynamicJaxprTrace(core.Trace): - __slots__ = ("frame",) + __slots__ = ("frame", "tag") def __init__(self, debug_info: core.DebugInfo): self.frame = JaxprStackFrame(debug_info) @@ -1972,17 +1978,18 @@ class DynamicJaxprTrace(core.Trace): self.frame.add_eqn(eqn) return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - def process_map(self, map_primitive, f: lu.WrappedFun, - tracers: Sequence[core.Tracer], params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] + with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( f, reduced_in_avals) + jaxpr, consts = _linearize_of_pmap_hack(f, jaxpr, consts) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2582,3 +2589,13 @@ def inline_jaxpr_into_trace( return tracer return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env else new_tracer(x) for x in jaxpr.outvars] + +# TODO(mattjj,dougalm): this special handling is to avoid round-tripping the +# jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's +# handling of pmap. Remove when we replace the pmap implementation. +def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, list]: + if (not f.transforms and type(f.f) is HashableFunction and + getattr(f.f, '_pmap_tag', None)): + _, jaxpr = f.f.closure + return convert_constvars_jaxpr(jaxpr), [] + return jaxpr, consts diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9a409e4cb..1b2d85006 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1394,9 +1394,9 @@ def xla_call_jvp_update_params(params, nz_tangents): new_donated_invars = (*donated_invars, *donated_tangents) return dict(params, donated_invars=new_donated_invars) -def _xla_call_linearize_update_params(params, residual_avals, nz_tangents): +def _xla_call_linearize_update_params(params, num_new_inputs, nz_tangents): donated_invars_prev = params['donated_invars'] - donated_invars = (*(False for _ in residual_avals), + donated_invars = (*(False for _ in range(num_new_inputs)), *(d for d, nz in zip(donated_invars_prev, nz_tangents) if nz)) return dict(params, donated_invars=donated_invars) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5dc5abdf9..bb80f7630 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3736,14 +3736,14 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) -def _sin_p_lin(nzs, x): +def _sin_lin(nzs, x): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_p_lin +ad.primitive_linearizations[sin_p] = _sin_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index db0799c5a..c2e39a818 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -200,7 +200,7 @@ class _BaseMesh: _mesh_object_dict = {} # type: ignore -MeshAxisType = dict[AxisTypes, str | tuple[str, ...]] +MeshAxisType = dict[AxisTypes, MeshAxisName | tuple[MeshAxisName, ...]] class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 86df66301..041b8a07c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2043,18 +2043,6 @@ def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): - if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): - jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr) - mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) - primals_in = [*primals_in, *mut_primals] - tangents_in = [*tangents_in, *mut_tangents] - in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut_primals) - in_layouts = (*in_layouts,) + (None,) * len(mut_primals) - donated_invars = (*donated_invars,) + (False,) * len(mut_primals) - - tangents_in = [ad_util.zeros_like_aval(a) if isinstance(a, AbstractRef) else x - for x, a in zip(tangents_in, jaxpr.in_avals)] - is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr( jaxpr, is_nz_tangents_in, instantiate=False) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 2a8b8bcc9..f2e03d04e 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -437,6 +437,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, x_tangent, *_ = tangents + if type(ref_tangent) is ad_util.Zero: + raise Exception("you're an idiot") assert isinstance(ref_tangent.aval, AbstractRef) x_tangent = ad_util.instantiate(x_tangent) return (swap_p.bind(ref_primal, x_primal, *idx, **params), @@ -657,5 +659,14 @@ mlir.register_lowering( # === AD rules for mutable arrays === -ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g)) +def _mut_jvp(primals, tangents): + (init_val,), (init_val_dot,) = primals, tangents + primal_out = core.mutable_array_p.bind(init_val) + if type(init_val_dot) is ad_util.Zero: + tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval)) + else: + tangent_out = core.mutable_array_p.bind(init_val_dot) + return primal_out, tangent_out + +ad.primitive_jvps[core.mutable_array_p] = _mut_jvp ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g)) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 0f161c074..0477e1c90 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -544,6 +544,8 @@ def _shard_map_staging( return out_tracers pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging +# TODO add underscore version, for direct-linearize to consume + def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval @@ -742,9 +744,8 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, in_avals_, in_nodes) - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto - ) + manual_axes = frozenset(mesh.axis_names) - auto + new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) with _extend_axis_env(mesh, auto): out_nodes_, tokens_out = mlir.call_lowering( @@ -895,7 +896,6 @@ def _match_spec(mesh: Mesh, check_rep: bool, def _match(mesh, check_rep, pspec, x): src = P(mesh.axis_names) - # TODO put back (?) needed for rep checking in eager? for now test rewrite return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) def _rem_singleton(x): return jnp.squeeze(x, axis=0) @@ -914,6 +914,7 @@ class ShardMapTrace(core.Trace): __slots__ = ("mesh", "auto", "check", "context_mesh") mesh: Mesh + auto: frozenset[AxisName] check: bool context_mesh: AbstractMesh @@ -927,7 +928,7 @@ class ShardMapTrace(core.Trace): if isinstance(val, ShardMapTracer): return val.val, val.rep elif isinstance(val, Tracer): - raise Exception("Shouldn't have any non-shard_map tracers") + raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) return val_, None @@ -1609,34 +1610,40 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, out_names_thunk, check_rep, rewrite, auto): primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not ad.Zero for t in tangents) - f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, - f.debug_info) + f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) + res_names = _all_newly_manual_mesh_names(mesh, auto, trace) - @as_hashable_function(closure=(linearize_outs_thunk)) + @as_hashable_function(closure=linearize_outs_thunk) def primal_out_names_thunk(): - residual_avals, _, _ = linearize_outs_thunk() + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() out_names = out_names_thunk() - # This is incorrect so we set `check_rep=False` as we do in the JVP rule. - return (*({0: all_names} for _ in residual_avals), *out_names) + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). + return (*({0: res_names} for _ in range(num_res_out)), *out_names) primal_params = dict( mesh=mesh, in_names=in_names, out_names_thunk=primal_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) all_primal_results = shard_map_p.bind_with_trace( - trace.parent_trace, (f_primal,) + tuple(primals), primal_params) - residual_avals, nzs_out, lin_jaxpr = linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = all_primal_results[:num_residuals] - primals_out = all_primal_results[num_residuals:] - args_to_promote = [getattr(aval, 'shape', ()) == () for aval in residual_avals] - lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) + trace.parent_trace, (f_primal, *primals), primal_params) + residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_primal_results[:num_res_out] + primals_out = all_primal_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) + args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None + for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] + with core.extend_axis_env_nd(mesh.shape.items()): + lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) out_names = out_names_thunk() - new_in_names = (*({0: all_names} for _ in residual_avals), + residual_names = [in_names[f1] if f1 is not None else + out_names[f2] if f2 is not None else + {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] + new_in_names = (*residual_names, *({} for _ in range(len(env))), *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - new_out_names = (*(ax for ax, nz in zip(out_names, nzs_out) if nz),) + new_out_names = tuple(ax for ax, nz in zip(out_names, nzs_out) if nz) @as_hashable_function(closure=(new_out_names)) def tangent_out_names_thunk(): return new_out_names @@ -1645,15 +1652,14 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, out_names_thunk=tangent_out_names_thunk, check_rep=False, rewrite=rewrite, auto=auto) + # TODO TODO don't round-trip def f_tangent(*args): - residuals = args[:num_residuals] - nz_tangents = args[num_residuals:] - return core.eval_jaxpr(lin_jaxpr, (), *residuals, *nz_tangents) + return core.eval_jaxpr(lin_jaxpr, (), *args) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace, (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), - *residuals, *nz_tangents_in), tangent_params) + *residuals, *env, *nz_tangents_in), tangent_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) for nz, primal in zip(nzs_out, primals_out)] @@ -1663,13 +1669,13 @@ ad.LinearizeTrace.process_shard_map = _shard_map_linearize @lu.transformation2 def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): ans = f(*args, **kwargs) - residual_avals, _, _ = linearize_outs_thunk() - num_residuals = len(residual_avals) - residuals = ans[:num_residuals] - primals = ans[num_residuals:] - residuals = tuple(jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in residuals) - return residuals + primals + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + residuals = ans[:num_res_out] + primals = ans[num_res_out:] + residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in residuals] + return *residuals, *primals @lu.transformation2 def _promote_scalar_residuals(f: Callable, *args, **kwargs): @@ -1798,10 +1804,10 @@ def _partial_eval_jaxpr_custom_rule( _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) newvar = core.gensym() - params_known, params_staged, all_names = _pe_custom_params( + params_known, params_staged, res_names = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval)) + residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval)) for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1853,10 +1859,10 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # prune inputs to jaxpr_known according to unks_in mesh = params_known['mesh'] auto = params_known['auto'] - all_names = _all_newly_manual_mesh_names(mesh, auto) + res_names_ = _all_newly_manual_mesh_names(mesh, auto) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: all_names}] * sum(which) + out_names_known = out_names_known + [{0: res_names_}] * sum(which) new_params_known = dict(params_known, in_names=tuple(in_names_known), out_names=tuple(out_names_known)) @@ -1864,12 +1870,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, _, in_names_staged = partition_list(inst_in, params_staged['in_names']) res_names = [in_names_known[f1] if f1 is not None else out_names_known[f2] if f2 is not None else - {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] + {0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)] in_names_staged = res_names + in_names_staged _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged, all_names + return new_params_known, new_params_staged, res_names_ # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( @@ -1880,15 +1886,21 @@ def _all_mesh_names_except_spmd( return tuple(name for name in mesh.axis_names if name not in spmd_names and name not in auto) -# TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_newly_manual_mesh_names( mesh: Mesh, auto: frozenset[AxisName], trace=None ) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - spmd_names = axis_env.spmd_axis_names - axis_sizes = axis_env.axis_sizes - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto and name not in axis_sizes) + if not (ctx_mesh := get_abstract_mesh()).empty: + del mesh + already_manual_names = set(ctx_mesh.axis_types.get(AxisTypes.Manual, ())) + return tuple(name for name in ctx_mesh.axis_names + if name not in auto | already_manual_names) + else: + # TODO(mattjj): remove this mechanism when we revise mesh scopes + axis_env = core.get_axis_env() + vmap_spmd_names = set(axis_env.spmd_axis_names) + already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names + return tuple(name for name in mesh.axis_names + if name not in auto | vmap_spmd_names | already_manual_names) # DCE diff --git a/tests/core_test.py b/tests/core_test.py index 5fc906bd3..c46d493bd 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -43,13 +43,14 @@ __ = pe.PartialVal.unknown(ShapedArray((), np.float32)) def call(f, *args): return jit(f)(*args) -@util.curry def core_call(f, *args): args, in_tree = jax.tree.flatten(args) dbg = debug_info("core_call_test", f, args, {}) f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree) out = core.call_p.bind(f, *args) return jax.tree.unflatten(out_tree(), out) +# call = core_call +core_call = util.curry(core_call) @util.curry def core_closed_call(f, *args): diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 9a6c5c167..13151a098 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -131,51 +131,6 @@ class MutableArrayTest(jtu.JaxTestCase): out = f() self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False) - @parameterized.parameters([True, False]) - def test_refs_in_vjps(self, jit): - def gradient_history_calculator_fwd(x, ref): - return x, ref - - def gradient_history_calculator_bwd(amax_history, grad_output): - amax_update = jnp.max(jnp.abs(grad_output)) - shifted = jnp.roll(amax_history[:], 1) - shifted = shifted.at[0].set(amax_update) - amax_history[:] = shifted - amax_from_history = jnp.max(amax_history[:]) - grad_output = grad_output / amax_from_history - return grad_output, None - - @jax.custom_vjp - def gradient_history_calculator(x, ref): - return x - - gradient_history_calculator.defvjp( - gradient_history_calculator_fwd, - gradient_history_calculator_bwd) - - class DotOp: - def __init__(self): - self.amax_history = core.mutable_array(jnp.zeros(5,)) - - def forward(self, x, y): - out = jnp.dot(x, y) - out = gradient_history_calculator(out, self.amax_history) - return out - - dot_op = DotOp() - x_top = jnp.ones((5,)) - y_top = jnp.ones((5,)) - - def loss(x, y): - return dot_op.forward(x, y).sum() - - if jit: - loss = jax.jit(loss) - - for i in range(3): - jax.grad(loss, (0,1))(x_top, y_top) - self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) - @parameterized.parameters([True, False]) def test_scan_internal_mut_array(self, jit): def body_fun(_, x): @@ -371,17 +326,18 @@ class MutableArrayErrorsTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): f(x_ref, x_ref) - @parameterized.parameters([False, True]) - def test_argument_aliases_custom_vjp_fwd(self, jit): - @jax.custom_vjp - def f(x_ref, y_ref): - ... - f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) - if jit: - f = jax.jit(f) - x_ref = core.mutable_array(0.) - with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): - jax.vjp(f, x_ref, x_ref) + # TODO(mattjj): re-enable test after direct-linearize + # @parameterized.parameters([False, True]) + # def test_argument_aliases_custom_vjp_fwd(self, jit): + # @jax.custom_vjp + # def f(x_ref, y_ref): + # ... + # f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None)) + # if jit: + # f = jax.jit(f) + # x_ref = core.mutable_array(0.) + # with self.assertRaisesRegex(ValueError, "x_ref and y_ref"): + # jax.vjp(f, x_ref, x_ref) # TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e29..0bddcaa78 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -51,6 +51,9 @@ from jax._src.interpreters import pxla from jax._src.lax import parallel from jax._src.lib import xla_extension from jax._src.util import safe_map, safe_zip +from jax._src import util +from jax.api_util import flatten_fun_nokwargs, debug_info +from jax._src import linear_util as lu config.parse_flags_with_absl() jtu.request_cpu_devices(8) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 8e51b3153..520fd10c9 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2205,20 +2205,23 @@ class ShardMapTest(jtu.JaxTestCase): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): + # manual: 'i', 'j' return x * x def h(x): + # auto: 'j', manual: 'i' return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) @jax.jit def f(x): + # auto: 'i', 'j' return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2814,7 +2817,7 @@ def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]: name, *case = sample_one(rng, make_gen()) if name not in seen: seen.add(name) - yield name, *case + yield case # To sample one test spec, we run the generator, getting back sequences of # options from it and sending in our choices from those options until finally a @@ -2929,7 +2932,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): def make_mesh(mesh_shape): return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) @@ -2938,7 +2941,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) @@ -2947,9 +2950,9 @@ class ShardMapSystematicTest(jtu.JaxTestCase): expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) - @parameterized.named_parameters( - (name + f'_check_rep={check_rep}', *params, check_rep) - for (name, *params) in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap) + @parameterized.parameters( + (*params, check_rep) + for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap) for check_rep in [True, False] ) @jax.default_matmul_precision("float32") @@ -2961,7 +2964,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) @jax.default_matmul_precision("float32") def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _): @@ -2980,7 +2983,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): return g(*args) jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, partial(sample_shmap_batched, 5))) def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref): @@ -3003,7 +3006,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): tol = 1e-2 if jtu.test_device_matches(['tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) - @parameterized.named_parameters( + @parameterized.parameters( sample(jtu.NUM_GENERATED_CASES.value, partial(sample_shmap_batched, 5))) def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _): From f4f31f89ae7755b219b15f4b0b27a77849a317e5 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 7 Mar 2025 21:35:40 +0000 Subject: [PATCH 068/100] [scan] when num_trips==0, don't generate weird size-zero reshapes --- jax/_src/lax/control_flow/loops.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 5179da751..c7dee3e71 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -457,8 +457,11 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_)) else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) - xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals) + if num_trips: + xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] + yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals) + else: + yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals) def inner(n, carry, xs): ys = [] @@ -493,7 +496,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, if num_trips: i = lax._const(num_trips, 0) _, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss)) - if unroll != 1: + if unroll != 1 and num_trips != 0: ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss] else: ys = yss From 041f5757473e98bfd644b3a93fac8c36c8016e12 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 7 Mar 2025 14:46:16 -0800 Subject: [PATCH 069/100] Support MHA in ragged paged attention for packed type PiperOrigin-RevId: 734695213 --- .../pallas/ops/tpu/ragged_paged_attention.py | 15 ++++++++++++--- tests/pallas/tpu_ragged_paged_attention_test.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 5d20dad6b..30cb20733 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -270,6 +270,15 @@ def ragged_paged_attention_kernel( b = jnp.left_shift(b, bw * (packing - 1)) return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + def fold_on_2nd_minor(vec): + assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 + assert len(vec.shape) >= 2 + last_dim = vec.shape[-1] + packing = get_dtype_packing(vec.dtype) + if vec.shape[-2] % packing != 0: + vec = vec.astype(jnp.float32) + return vec.reshape(-1, last_dim) + @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): async_copy_k, async_copy_v = create_kv_async_copy_descriptors( @@ -495,9 +504,9 @@ def ragged_paged_attention_kernel( q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at # unaligned position! - q = q_ref[ - :, q_head_idx : q_head_idx + num_q_heads_per_kv_head, : - ].reshape(-1, head_dim) + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) flash_attention( diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 1ed12aecf..cca8e3bc8 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -266,7 +266,7 @@ class PagedAttentionKernelTest(jtu.JaxTestCase): @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! - num_heads=[(32, 8), (32, 16), (12, 2)], + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], From 251b93ebd7c9a9206a929929da8e55ed3a943496 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 5 Feb 2025 01:41:08 +0000 Subject: [PATCH 070/100] fixups that we meant to include in #26427 Co-authored-by: Dougal Maclaurin --- jax/_src/interpreters/ad.py | 3 +-- jax/_src/state/primitives.py | 4 ++-- tests/debug_info_test.py | 6 ++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d951b88e4..cc1d19137 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -763,11 +763,10 @@ class LinearizeTrace(Trace): # Remove when we replace the pmap implementation. f_tangent._pmap_tag = isinstance(call_primitive, core.MapPrimitive) - thing = lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] nz_tangents_out = call_primitive.bind_with_trace( self.tangent_trace, - (thing, + (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), *residuals, *env, *nz_tangents_in), new_params) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 3e8a37c75..6f7570a5f 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -443,8 +443,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, x_tangent, *_ = tangents - if type(ref_tangent) is ad_util.Zero: - raise Exception("you're an idiot") + # if type(ref_tangent) is ad_util.Zero: + # raise Exception("you're an idiot") assert isinstance(ref_tangent.aval, AbstractRef) x_tangent = ad_util.instantiate(x_tangent) return (swap_p.bind(ref_primal, x_primal, *idx, **params), diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 2c5d6a772..0e10b255f 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -876,6 +876,8 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) + @unittest.skipIf(config.use_direct_linearize.value, + 'broken with direct-linearize') # TODO(necula) def test_vjp_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): @@ -1285,6 +1287,8 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) + @unittest.skipIf(config.use_direct_linearize.value, + 'broken with direct-linearize') # TODO(necula) def test_grad_scan(self): # Based on control_flow_test:testScanHigherOrderDifferentiation tracer_spy = TracerSpy() @@ -1593,6 +1597,8 @@ class DebugInfoTest(jtu.JaxTestCase): ], ) + @unittest.skipIf(config.use_direct_linearize.value, + 'broken with direct-linearize') # TODO(necula) def test_hessian(self): tracer_spy = TracerSpy() From fe26c19b9225615a1159b9ab0b586949bce78d74 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 8 Mar 2025 01:46:24 +0000 Subject: [PATCH 071/100] [direct-linearize] fix name_stack bugs Surprisingly, the bug was tracked down to #26111 aka cl/730939406, specifically the new implementation of reset_name_stack in source_info_util.py. To repro, use the before-this-commit implementation of reset_name_stack (left commented-out in the file), and run ``` JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py NameStackTransformationTest.test_nested_jit_stack ``` --- jax/_src/interpreters/ad.py | 4 ++-- jax/_src/source_info_util.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d951b88e4..62375211a 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1118,8 +1118,8 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, assert len(jaxpr.in_avals) == len(nonzeros) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) - f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), - nonzeros) + f_jvp, out_nonzeros = f_jvp_traceable( + jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index c05895c11..b1901f44f 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -305,7 +305,16 @@ class SetNameStackContextManager(contextlib.ContextDecorator): set_name_stack = SetNameStackContextManager -reset_name_stack = lambda: SetNameStackContextManager(NameStack()) + + +# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack +# implementation doesn't work. Luckily this context manager isn't called much so +# the performance shouldn't matter. See blame commit message for repro. +# reset_name_stack = lambda: SetNameStackContextManager(NameStack()) +@contextlib.contextmanager +def reset_name_stack() -> Iterator[None]: + with set_name_stack(NameStack()): + yield class TransformNameStackContextManager(contextlib.ContextDecorator): From 0f0636afab3f91cc144818026f390ea47b60dfab Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 7 Mar 2025 18:28:48 -0800 Subject: [PATCH 072/100] [Mosaic TPU][Pallas] Add pl.reciprocal PiperOrigin-RevId: 734749577 --- jax/_src/pallas/mosaic/lowering.py | 8 ++++++++ jax/_src/pallas/primitives.py | 28 ++++++++++++++++++++++++++++ jax/experimental/jax2tf/jax2tf.py | 1 + jax/experimental/pallas/__init__.py | 1 + jaxlib/mosaic/dialect/tpu/tpu.td | 10 ++++++++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 8 ++++++++ tests/pallas/ops_test.py | 20 ++++++++++++++++++++ 7 files changed, 76 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 46294a766..9bc20ed2c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3222,6 +3222,14 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule +def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx): + if not isinstance(x.type.element_type, ir.F32Type): + raise ValueError("Only float32 is supported.") + return tpu.reciprocal(x, approx=approx) + + +lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule + def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index af84c2836..9a850d847 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -705,6 +705,34 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, preferred_element_type=out_dtype, ) +reciprocal_p = jax_core.Primitive("reciprocal") + + +def reciprocal(x, *, approx=False): + return reciprocal_p.bind(x, approx=approx) + + +@reciprocal_p.def_abstract_eval +def _reciprocal_abstract_eval(x, *, approx): + del approx + return x + + +def _reciprocal_lowering_rule( + ctx: mlir.LoweringRuleContext, x, *, approx=False +): + def _reciprocal(x, *, approx=False): + if approx: + return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32) + return jnp.reciprocal(x) + + return mlir.lower_fun(_reciprocal, multiple_results=False)( + ctx, x, approx=approx + ) + + +mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule) + class PrintEffect(effects.Effect): __str__ = lambda self: "Print" diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index d58a1bb0d..af5ec987e 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1545,6 +1545,7 @@ tf_not_yet_impl = [ "symmetric_product", "from_edtype", "to_edtype", + "reciprocal", # Pallas TPU primitives "bitcast", "repeat", diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 963be0381..1e0abacfc 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -53,6 +53,7 @@ from jax._src.pallas.primitives import max_contiguous as max_contiguous from jax._src.pallas.primitives import multiple_of as multiple_of from jax._src.pallas.primitives import num_programs as num_programs from jax._src.pallas.primitives import program_id as program_id +from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 25ba46dfe..4b5ed3493 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -538,6 +538,16 @@ def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { let hasVerifier = 1; } +def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + DefaultValuedAttr:$approx + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { let arguments = (ins Variadic:$input); let results = (outs AnyVectorOfNonZeroRank:$output); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 69c29e51f..5a8042120 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" +#include "mlir/include/mlir/IR/Diagnostics.h" #include "mlir/include/mlir/IR/IRMapping.h" #include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -1164,6 +1165,13 @@ LogicalResult LogBufferOp::verify() { return success(); } +LogicalResult ReciprocalOp::verify() { + if (!getType().getElementType().isF32()) { + return emitOpError("Not implemented: Reciprocal op for non-f32 dtypes"); + } + return success(); +} + void PackSubelementsOp::build(OpBuilder &builder, OperationState &state, const VectorType output_type, const ArrayRef padded_sources, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 20ff22e7e..e7afe4a30 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -2459,6 +2459,26 @@ class PallasPrimitivesTest(PallasBaseTest): wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)]) self.assertIn(expected, jaxpr.pretty_print(use_color=False)) + @parameterized.product(approx=[False, True]) + def test_reciprocal(self, approx): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on non-TPU devices") + if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + self.skipTest("Test requires libtpu from 2025/3/8 or later") + shape = (32, 256) + x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) + + def kernel(x_ref, o_ref): + o_ref[...] = pl.reciprocal(x_ref[...], approx=approx) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) + )(x) + kwargs = {} + if approx: + kwargs.update(dict(atol=2e-5, rtol=2e-5)) + np.testing.assert_allclose(out, jax.lax.reciprocal(x), **kwargs) + class PallasPrimitivesInterpretTest(PallasPrimitivesTest): INTERPRET = True From 04696b4d7bf2ae2c6dcb7af3339e86f5f0347eb3 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 8 Mar 2025 03:26:16 -0800 Subject: [PATCH 073/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/be68e80894862fe97757ea2b6110958ef4244c21. PiperOrigin-RevId: 734851053 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4ac4564a7..97a672c24 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f1213b83af673729b60f5096da5186246568c0fb" -XLA_SHA256 = "77b886c9700d1f9a2ed65f18c176ddb38ffe6905128690f19e1fd7ca624dbebd" +XLA_COMMIT = "be68e80894862fe97757ea2b6110958ef4244c21" +XLA_SHA256 = "02d0b47d3d8866fdf4a9bb68987a16e54d6bd8fbf2fa9736b4110ab75059a189" def repo(): tf_http_archive( From 36d515ed2c2bc552182dd6bd87f2144e50b0c773 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 7 Mar 2025 19:24:31 -0500 Subject: [PATCH 074/100] A few more fixes for debug_info tests with direct_linearize. --- jax/_src/pjit.py | 1 - tests/debug_info_test.py | 14 +++++++------- tests/pmap_test.py | 3 --- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 041b8a07c..69cd8e809 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -29,7 +29,6 @@ import warnings import numpy as np from jax._src import api -from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 0e10b255f..a39b53c3a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -876,8 +876,6 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_vjp_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): @@ -898,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," + if config.use_direct_linearize.value else "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ @@ -1287,8 +1287,6 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_grad_scan(self): # Based on control_flow_test:testScanHigherOrderDifferentiation tracer_spy = TracerSpy() @@ -1328,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" + if config.use_direct_linearize.value else "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ @@ -1597,8 +1597,6 @@ class DebugInfoTest(jtu.JaxTestCase): ], ) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_hessian(self): tracer_spy = TracerSpy() @@ -1614,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", + "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," + if config.use_direct_linearize.value else + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0bddcaa78..af2d03e29 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -51,9 +51,6 @@ from jax._src.interpreters import pxla from jax._src.lax import parallel from jax._src.lib import xla_extension from jax._src.util import safe_map, safe_zip -from jax._src import util -from jax.api_util import flatten_fun_nokwargs, debug_info -from jax._src import linear_util as lu config.parse_flags_with_absl() jtu.request_cpu_devices(8) From b9fb69d1fcd634e0efd9b7214900206fa0d021e0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 9 Mar 2025 03:54:13 -0700 Subject: [PATCH 075/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e89a6b46bcb54e31c12d87893be85bd14720a6ec. PiperOrigin-RevId: 735080797 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 97a672c24..35557975d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "be68e80894862fe97757ea2b6110958ef4244c21" -XLA_SHA256 = "02d0b47d3d8866fdf4a9bb68987a16e54d6bd8fbf2fa9736b4110ab75059a189" +XLA_COMMIT = "e89a6b46bcb54e31c12d87893be85bd14720a6ec" +XLA_SHA256 = "973751b5c2f5f3eac2a4bff331d4ec89b7af2ba630d2ac4f6b860c9140c7adcc" def repo(): tf_http_archive( From 6a718b762f8ae59ee81c3793273d2689fe0b553c Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Sun, 9 Mar 2025 21:35:46 -0700 Subject: [PATCH 076/100] Update stateful-computations.md tree_map -> tree.map --- docs/stateful-computations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index fe84fc0d7..30c626bec 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: # and then use `updates` instead of `grad` to actually update the params. # (And we'd include `new_optimizer_state` in the output, naturally.) - new_params = jax.tree_map( + new_params = jax.tree.map( lambda param, g: param - g * LEARNING_RATE, params, grad) return new_params From 75d8702023fca6fe4a223bf1e08545c1c80581c0 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 10 Mar 2025 02:14:04 -0700 Subject: [PATCH 077/100] [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics. Enable some of the pre-existing Pallas `ops_test`s for testing. PiperOrigin-RevId: 735293084 --- jax/_src/pallas/mosaic_gpu/lowering.py | 43 +++++++++++- .../mosaic/gpu/dialect_lowering.py | 39 ++++++++++- .../mosaic/gpu/fragmented_array.py | 38 ++++++++++- .../mosaic/gpu/layout_inference.py | 68 +++++++++++++++---- tests/mosaic/gpu_dialect_test.py | 38 +++++++++++ tests/pallas/ops_test.py | 63 +++++++++++++---- 6 files changed, 257 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5da89f190..42d9627c1 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1162,6 +1162,9 @@ def _convert_element_type_lowering_rule_wg( cur_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype) new_dtype = mgpu_utils.dtype_to_ir_type(new_dtype) + if cur_dtype == new_dtype: + return x + if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8: raise NotImplementedError("Conversion involving sub-byte types unsupported") @@ -1170,7 +1173,29 @@ def _convert_element_type_lowering_rule_wg( from_integer = ir.IntegerType.isinstance(cur_dtype) to_integer = ir.IntegerType.isinstance(new_dtype) if from_float and to_float: - if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + cur_ty_width = ir.FloatType(cur_dtype).width + new_ty_width = ir.FloatType(new_dtype).width + if cur_ty_width == new_ty_width: + # There is no instruction to perform conversions between two float types + # of the same width. Go through the next-larger standard type. + # TODO(bchetioui): support conversions between float types of width 8. + # Which larger type to pick will depend on the number of bits in the + # smallest exponent. + if cur_ty_width != 16: + raise NotImplementedError( + "Conversion between float types of width other than 16 not" + " supported" + ) + larger_ty = ir.F32Type.get() + if x_aval.shape: + upcast_ty = ir.VectorType.get(x_aval.shape, larger_ty) + else: + upcast_ty = larger_ty + + def convert(ty, x): + return arith_dialect.truncf(ty, arith_dialect.extf(upcast_ty, x)) + + elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: convert = arith_dialect.truncf else: convert = arith_dialect.extf @@ -1190,10 +1215,26 @@ def _convert_element_type_lowering_rule_wg( else: convert = arith_dialect.uitofp elif from_float and to_integer: + dst_width = mgpu_utils.bitwidth(new_dtype) + # We clamp the float value to the min/max integer destination value + # in order to match JAX/XLA casting behavior. Note that this differs + # from numpy casting behavior. if mgpu_utils.is_signed(y_aval.dtype): + maxint = 2 ** (dst_width - 1) - 1 + minint = -(2 ** (dst_width - 1)) convert = arith_dialect.fptosi else: + maxint = 2**dst_width - 1 + minint = 0 convert = arith_dialect.fptoui + + maxint = _ir_constant(maxint, cur_dtype) + minint = _ir_constant(minint, cur_dtype) + if x_aval.shape: + maxint = vector_dialect.splat(x.type, maxint) + minint = vector_dialect.splat(x.type, minint) + x = arith_dialect.minimumf(x, maxint) + x = arith_dialect.maximumf(x, minint) else: raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 40c0b2fce..03acc4c88 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -253,9 +253,9 @@ def _vector_load_op_lowering_rule( element_type = vector_load_op.result.type.element_type is_signed = False if ir.IntegerType.isinstance(element_type) else None - + strided_layout = layouts.from_strided_fragmented_layout_attr(out_layout_attr) fragmented_array = fa.FragmentedArray.load_strided( - vector_load_op.base, is_signed=is_signed + vector_load_op.base, is_signed=is_signed, vec_size=strided_layout.vec_size ) return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)] @@ -429,6 +429,41 @@ def _mgpu_async_store_op_lowering_rule( return [] +def _conversion_op_lowering_rule( + _: LoweringContext, + op: ir.OpView, + source_is_signed: bool | None, + target_is_signed: bool | None, +) -> Sequence[ir.Value]: + [in_layout] = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if in_layout != layout: + raise ValueError("Layout mismatch") + + target_ty = op.result.type.element_type # pytype: disable=attribute-error + operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed) + converted = operand.astype(target_ty, is_signed=target_is_signed) + return [_fragmented_array_to_ir(converted, op.result.type)] + + +for op, source_is_signed, target_is_signed in [ + (arith.ExtFOp, None, None), + (arith.ExtSIOp, True, True), + (arith.ExtUIOp, False, False), + (arith.FPToSIOp, None, True), + (arith.FPToUIOp, None, False), + (arith.SIToFPOp, True, None), + (arith.TruncFOp, None, None), + (arith.TruncIOp, False, False), + (arith.UIToFPOp, False, None), +]: + _lowerings[op.OPERATION_NAME] = functools.partial( + _conversion_op_lowering_rule, + source_is_signed=source_is_signed, + target_is_signed=target_is_signed, + ) + + def _binary_op_lowering_rule( _: LoweringContext, op: Any, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index b3f335155..a52eb329d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -641,13 +641,22 @@ class FragmentedArray: raise NotImplementedError @classmethod - def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): + def load_strided( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + vec_size: int | None = None, + ): if not ir.MemRefType.isinstance(ref.type): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) shape = tuple(ref_ty.shape) - layout = WGStridedFragLayout.from_shaped_type(ref_ty) + if vec_size is None: + layout = WGStridedFragLayout.from_shaped_type(ref_ty) + else: + layout = WGStridedFragLayout(shape=shape, vec_size=vec_size) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) try: # Flattening the reference potentially produces simpler PTX but @@ -1322,7 +1331,30 @@ class FragmentedArray: from_integer = ir.IntegerType.isinstance(cur_dtype) to_integer = ir.IntegerType.isinstance(new_dtype) if from_float and to_float: - if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: + cur_ty_width = ir.FloatType(cur_dtype).width + new_ty_width = ir.FloatType(new_dtype).width + if cur_ty_width == new_ty_width: + # There is no instruction to perform conversions between two float types + # of the same width. Go through the next-larger standard type. + # TODO(bchetioui): support conversions between float types of width 8. + # Which larger type to pick will depend on the number of bits in the + # smallest exponent. + if cur_ty_width != 16: + raise NotImplementedError( + "Conversion between float types of width other than 16 not" + " supported" + ) + larger_ty = ir.F32Type.get() + match self.layout: + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): + shape = ir.VectorType(self.registers.flat[0].type).shape + upcast_ty = ir.VectorType.get(shape, larger_ty) + case WGMMARowFragLayout() | WGSplatFragLayout(): + upcast_ty = larger_ty + case _: + raise NotImplementedError(f"Unsupported layout {self.layout}") + convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x)) + elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: convert = arith.truncf else: convert = arith.extf diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 6b2010cf5..4b69e60aa 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -23,13 +23,16 @@ from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import memref +from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +import math +import numpy as np from . import fragmented_array as fa from . import inference_utils from . import layouts as layouts_lib +from . import utils # mypy: ignore-errors @@ -192,22 +195,38 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: for op in [ - arith.AddIOp, arith.AddFOp, + arith.AddIOp, + arith.AddFOp, arith.AndIOp, arith.BitcastOp, arith.CmpFOp, arith.CmpIOp, - arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp, + arith.ExtFOp, + arith.ExtSIOp, + arith.ExtUIOp, + arith.FPToSIOp, + arith.FPToUIOp, arith.MaximumFOp, - arith.MaxUIOp, arith.MaxSIOp, + arith.MaxUIOp, + arith.MaxSIOp, arith.MinimumFOp, - arith.MinUIOp, arith.MinSIOp, - arith.MulIOp, arith.MulFOp, + arith.MinUIOp, + arith.MinSIOp, + arith.MulIOp, + arith.MulFOp, arith.OrIOp, - arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp, - arith.RemUIOp, arith.RemSIOp, arith.RemFOp, - arith.SubIOp, arith.SubFOp, - arith.TruncFOp, arith.TruncIOp, + arith.FloorDivSIOp, + arith.DivUIOp, + arith.DivFOp, + arith.RemUIOp, + arith.RemSIOp, + arith.RemFOp, + arith.SIToFPOp, + arith.UIToFPOp, + arith.SubIOp, + arith.SubFOp, + arith.TruncFOp, + arith.TruncIOp, arith.XOrIOp, vector.LoadOp, vector.StoreOp, @@ -488,11 +507,36 @@ def infer_layout(module: ir.Module): # propagated. However, it is possible for some operations to remain # unannotated---for example, if there were no annotations on any operation in # the module at the start of this function. We annotate all the remaining ops - # that should be annotated with a strided fragmented layout. + # that should be annotated with a strided fragmented layout, whose vector size + # is derived from the narrowest type and vector size used in the program. We + # make sure to derive a single vector size in order to avoid relayouts at + # lowering time. + default_vector_size = math.inf + + def update_default_vector_size(op: ir.OpView): + nonlocal default_vector_size + for v in list(op.operands) + list(op.results): + if ir.VectorType.isinstance(v.type): + max_vec_size_for_v = ( + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE + ) + desired_vec_size = 8 // utils.bytewidth(v.type.element_type) + default_vector_size = min( + default_vector_size, max_vec_size_for_v, desired_vec_size + ) + + for op in module.body: + traverse_op(op, update_default_vector_size) + + if default_vector_size is None: # Nothing to annotate. + return + def to_default_layout(ty: ir.Type) -> ir.Attribute | None: if not ir.VectorType.isinstance(ty): return None - layout = fa.WGStridedFragLayout.from_shaped_type(ty) + layout = fa.WGStridedFragLayout( + shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size + ) return layouts_lib.to_strided_fragmented_layout_attr(layout) def set_default_layout(op: ir.OpView): diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 94d5d6714..ba9d23fa5 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -18,6 +18,7 @@ from typing import Callable from absl.testing import parameterized import jax +from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter @@ -824,6 +825,43 @@ class DialectLoweringTest(MosaicGpuTest): ) ) + @parameterized.parameters( + (arith.ExtFOp, jnp.bfloat16, jnp.float32), + (arith.ExtSIOp, jnp.int16, jnp.int32), + (arith.ExtUIOp, jnp.int16, jnp.uint32), + (arith.FPToSIOp, jnp.float32, jnp.int32), + (arith.FPToUIOp, jnp.float32, jnp.uint32), + (arith.SIToFPOp, jnp.int16, jnp.float32), + (arith.TruncFOp, jnp.float32, jnp.float16), + (arith.TruncIOp, jnp.int32, jnp.int16), + (arith.UIToFPOp, jnp.uint32, jnp.float32), + ) + def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): + shape = (4, 32) + + with ir.InsertionPoint(self.module.body): + scalar_in_ty = mgpu_utils.dtype_to_ir_type(in_dtype) + scalar_out_ty = mgpu_utils.dtype_to_ir_type(out_dtype) + in_ty = ir.VectorType.get(shape, scalar_in_ty) + out_ty = ir.VectorType.get(shape, scalar_out_ty) + if ir.IntegerType.isinstance(scalar_in_ty): + zero = ir.IntegerAttr.get(scalar_in_ty, 0) + else: + zero = ir.FloatAttr.get(scalar_in_ty, 0) + splat_zero = arith.ConstantOp( + in_ty, ir.DenseElementsAttr.get_splat(in_ty, zero) + ) + op(out_ty, splat_zero) + + mgpu.infer_layout(self.module) + mgpu.lower_mgpu_dialect(self.module, None) + + conversion_ops = find_if(self.module, lambda o: isinstance(o, op)) + # This is a splat, so we expect a single conversion op involving a scalar + # after lowering. + self.assertLen(conversion_ops, 1) + self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e7afe4a30..24ce3b722 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -38,10 +38,15 @@ import jax.numpy as jnp import numpy as np if sys.platform != "win32": - from jax.experimental.pallas import triton as plgpu + try: + from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu + except ImportError: + plgpu_mgpu = None + from jax.experimental.pallas import triton as plgpu_triton from jax.experimental.pallas import tpu as pltpu else: - plgpu = None + plgpu_mgpu = None + plgpu_triton = None pltpu = None try: @@ -99,6 +104,7 @@ _DTYPES_32BIT = ( # TODO(apaszke): Add 8-bit floats. _DTYPES_SUB_32BIT = ( "bfloat16", + "float16", "int16", "int8", "int4", @@ -282,6 +288,13 @@ class PallasBaseTest(jtu.JaxTestCase): @classmethod def pallas_call(cls, *args, **kwargs): + if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + assert plgpu_mgpu is not None + compiler_params = plgpu_mgpu.GPUCompilerParams( + thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + ) + kwargs["compiler_params"] = compiler_params + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) def skip_if_mosaic_gpu(self): @@ -569,13 +582,12 @@ class OpsTest(PallasBaseTest): @parameterized.product(from_dtype=_DTYPES_32BIT, to_dtype=_DTYPES) @hp.given(hps.data()) def test_cast_from_32bit(self, from_dtype, to_dtype, data): - self.skip_if_mosaic_gpu() + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: self.skipTest("Not supported on this hardware") if not jtu.if_cloud_tpu_at_least(2025, 3, 8): self.skipTest("Test requires libtpu from 2025/3/8 or later") - if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): @@ -589,6 +601,10 @@ class OpsTest(PallasBaseTest): self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}: self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861 + if to_dtype == "float16" and not sut_is_mosaic_gpu: + self.skipTest("float16 is only supported with Mosaic GPU") + if sut_is_mosaic_gpu and to_dtype == "bool": + self.skipTest("Sub-byte types are not yet supported with Mosaic GPU") # XLA does not specify the float->int conversion result for NaNs. elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer)) @@ -620,7 +636,7 @@ class OpsTest(PallasBaseTest): # miss bugs that would be hidden due to exhaustive enumeration being in order. @parameterized.product(from_dtype=_DTYPES_SUB_32BIT, to_dtype=_DTYPES, randomize=(False, True)) def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): - self.skip_if_mosaic_gpu() + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu if from_dtype == to_dtype: self.skipTest("Unnecessary test") @@ -640,7 +656,15 @@ class OpsTest(PallasBaseTest): self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}: self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861 - + if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu: + self.skipTest("float16 is only supported with Mosaic GPU") + if sut_is_mosaic_gpu: + unsupported_types = {"bool", "int4", "uint4"} + if to_dtype in unsupported_types or from_dtype in unsupported_types: + self.skipTest("Sub-byte types are not yet supported with Mosaic GPU") + if not randomize: + # TODO(bchetioui): rework the test shapes to make this work. + self.skipTest("Exhaustive tests may run out of SMEM with Mosaic GPU") if from_dtype in { "float8_e4m3b11fnuz", "float8_e5m2", @@ -686,7 +710,14 @@ class OpsTest(PallasBaseTest): else: x = jax.lax.bitcast_convert_type( jnp.arange(1 << from_bitwidth, dtype=from_int_dtype), from_dtype - ).reshape(8, -1) + ) + if sut_is_mosaic_gpu: + # TMA loads only support max 256 elements per dimension, so we make + # sure that all the dimensions don't exceed that. + if x.shape[0] > 256: + x = x.reshape(256, -1) + else: + x = x.reshape(8, -1) else: if randomize: x = random.randint(random.key(234), (16, 16), 0, 1, jnp.int32) != 0 @@ -1437,7 +1468,7 @@ class OpsTest(PallasBaseTest): self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), ) def kernel(x_ref, o_ref): - o_ref[...] = plgpu.approx_tanh(x_ref[...]) + o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) # We upcast to float32 because NumPy <2.0 does not handle custom dtypes @@ -1465,7 +1496,7 @@ class OpsTest(PallasBaseTest): out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), ) def kernel(x_ref, o_ref): - [o_ref[...]] = plgpu.elementwise_inline_asm( + [o_ref[...]] = plgpu_triton.elementwise_inline_asm( "tanh.approx.f16x2 $0, $1;", args=[x_ref[...]], constraints="=r,r", @@ -1491,14 +1522,14 @@ class OpsTest(PallasBaseTest): ) def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] - plgpu.debug_barrier() + plgpu_triton.debug_barrier() x = jnp.array([4.2, 2.4]).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x) @unittest.skipIf( sys.platform == "win32", - "plgpu.TritonCompilerParams unavailable on Windows", + "plgpu_triton.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): self.skip_if_mosaic_gpu() @@ -1513,7 +1544,9 @@ class OpsTest(PallasBaseTest): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) + compiler_params=plgpu_triton.TritonCompilerParams( + num_warps=1, num_stages=1 + ), ) def kernel(x_ref, o_ref): pl.debug_print("It works!") @@ -1527,7 +1560,7 @@ class OpsTest(PallasBaseTest): @unittest.skipIf( sys.platform == "win32", - "plgpu.TritonCompilerParams unavailable on Windows", + "plgpu_triton.TritonCompilerParams unavailable on Windows", ) def test_debug_print_with_values(self): if jtu.test_device_matches(["tpu"]): @@ -1540,7 +1573,9 @@ class OpsTest(PallasBaseTest): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) + compiler_params=plgpu_triton.TritonCompilerParams( + num_warps=1, num_stages=1 + ), ) def kernel(x_ref, o_ref): pl.debug_print("x[0] =", x_ref[0]) From f906d2b6d1da6c65cff90b4c128e1883dab74790 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Mar 2025 04:13:48 -0700 Subject: [PATCH 078/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/efb27eb924fd5d9b20b908a7cadb11d78d2a81a1. PiperOrigin-RevId: 735322004 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 35557975d..a01d837fc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e89a6b46bcb54e31c12d87893be85bd14720a6ec" -XLA_SHA256 = "973751b5c2f5f3eac2a4bff331d4ec89b7af2ba630d2ac4f6b860c9140c7adcc" +XLA_COMMIT = "efb27eb924fd5d9b20b908a7cadb11d78d2a81a1" +XLA_SHA256 = "b3a3e0df6bd5923d081fc1a96df41f2f29497b329da58b9b992f4345abf21c8b" def repo(): tf_http_archive( From 91340ea0a74059e1468b013c68c34fa6c2b0af38 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 10 Mar 2025 05:07:26 -0700 Subject: [PATCH 079/100] [pallas:mosaic_gpu] Added support for math functions to the WG lowering PiperOrigin-RevId: 735333893 --- jax/_src/pallas/mosaic_gpu/lowering.py | 117 ++++++++++++++---- .../mosaic/gpu/dialect_lowering.py | 32 +++++ .../mosaic/gpu/layout_inference.py | 9 +- tests/pallas/mosaic_gpu_test.py | 38 ++++-- 4 files changed, 161 insertions(+), 35 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 42d9627c1..43b9008bb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,7 +17,7 @@ from __future__ import annotations import collections -from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools @@ -25,6 +25,7 @@ import math from typing import Any, Protocol, cast import jax +from jax import api_util from jax import lax from jax._src import core as jax_core from jax._src import linear_util as lu @@ -36,6 +37,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect @@ -837,6 +839,29 @@ def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: ) +def _lower_fun( + fun: Callable[..., Any], *, multiple_results: bool +) -> Callable[..., Any]: + + def lowering_rule(ctx: LoweringRuleContext, *args, **params): + wrapped_fun = lu.wrap_init( + fun + if multiple_results + else lambda *args, **params: (fun(*args, **params),), + params, + debug_info=api_util.debug_info( + "Pallas Mosaic GPU lower_fun", fun, args, params + ), + ) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + out = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, args, consts + ) + return out if multiple_results else out[0] + + return lowering_rule + + @register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): @@ -1247,6 +1272,13 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ lax.not_p: lambda ctx, x: ~x, }) +mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ + lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), + lax.not_p: _lower_fun( + lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + ), +}) + def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) @@ -1383,55 +1415,98 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): - [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - if y == 2: - return x * x - return NotImplementedError + if y != 2: + raise NotImplementedError + return _square_lowering_rule(ctx, x) + @register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - return x * x + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + x = _ensure_fa(x, x_aval.dtype) + return x * x + if jnp.issubdtype(x_aval.dtype, jnp.integer): + return arith_dialect.muli(x, x) + if jnp.issubdtype(x_aval.dtype, jnp.floating): + return arith_dialect.mulf(x, x) + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.rsqrt( + _ensure_ir_value(x, x_aval.dtype), fastmath=fastmath + ) + @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) def _tanh_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.logistic_p, mgpu.ThreadSemantics.Lane) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): - [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return 1. / (1. + (-a).exp(approx=ctx.module_ctx.approx_math)) +def _logistic(x): + return 1.0 / (1 + lax.exp(-x)) + + +mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( + _logistic, multiple_results=False +) +mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( + _lower_fun(_logistic, multiple_results=False) +) + @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) def _exp_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return a.exp(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) def _exp2_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return a.exp2(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) def _log_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - a = _ensure_fa(x, x_aval.dtype) - return a.log(approx=ctx.module_ctx.approx_math) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 03acc4c88..eb5dcbc39 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import math as mlir_math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf @@ -464,6 +465,37 @@ for op, source_is_signed, target_is_signed in [ ) +def _unary_op_lowering_rule( + _: LoweringContext, + op: Any, + impl: Callable[[fa.FragmentedArray], fa.FragmentedArray], + is_signed: bool | None = None, +) -> Sequence[ir.Value]: + in_layouts = inference_utils.in_layouts(op) + [layout] = inference_utils.out_layouts(op) + if any(in_layout != layout for in_layout in in_layouts): + raise ValueError("Layout mismatch") + kwargs = {} + if hasattr(op, "fastmath"): + kwargs = dict( + approx=op.fastmath == ir.Attribute.parse("#arith.fastmath") + ) + a = _fragmented_array_from_ir(op.operand, layout, is_signed) + return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)] + + +for op, impl, is_signed in [ + (mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None), + (mlir_math.ExpOp, fa.FragmentedArray.exp, None), + (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None), + (mlir_math.LogOp, fa.FragmentedArray.log, None), + (mlir_math.TanhOp, fa.FragmentedArray.tanh, None), +]: + _lowerings[op.OPERATION_NAME] = functools.partial( + _unary_op_lowering_rule, impl=impl, is_signed=is_signed + ) + + def _binary_op_lowering_rule( _: LoweringContext, op: Any, diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 4b69e60aa..044e7537d 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -18,15 +18,16 @@ from collections.abc import Callable, Sequence import dataclasses import enum from functools import partial +import math from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import math as mlir_math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector -import math import numpy as np from . import fragmented_array as fa @@ -34,6 +35,7 @@ from . import inference_utils from . import layouts as layouts_lib from . import utils + # mypy: ignore-errors OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None @@ -228,6 +230,11 @@ for op in [ arith.TruncFOp, arith.TruncIOp, arith.XOrIOp, + mlir_math.ExpOp, + mlir_math.Exp2Op, + mlir_math.LogOp, + mlir_math.RsqrtOp, + mlir_math.TanhOp, vector.LoadOp, vector.StoreOp, ]: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 756e5d220..0a3af26de 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -81,25 +81,37 @@ class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest): class PallasCallTest(PallasTest): - @parameterized.named_parameters( - ("add_one", lambda x: x + 1.), - ("logistic", jax.lax.logistic), - ("exp", jax.lax.exp), - ("square", lambda x: x ** 2), - ("rsqrt", jax.lax.rsqrt), - ("tanh", jax.lax.tanh, 1e-6), - ("log", jax.lax.log) + @parameterized.product( + op=[ + lax.neg, + lax.bitwise_not, + lax.logistic, + lax.exp, + lambda x: x**2, + lax.rsqrt, + lax.tanh, + lax.log, + ], + approx_math=[True, False], + thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, unary, rtol=1e-7): + def test_unary_op(self, op, approx_math, thread_semantics): + dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], dtype), + compiler_params=plgpu.GPUCompilerParams( + approx_math=approx_math, thread_semantics=thread_semantics + ), ) def kernel(x_ref, o_ref): - o_ref[...] = unary(x_ref[...]) + o_ref[...] = op(x_ref[...]) - x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol) + x = jnp.arange(256).astype(dtype) + np.testing.assert_allclose( + kernel(x), op(x), rtol=1e-5 if approx_math else 3e-7 + ) @parameterized.product( op=[ From 4eada56027a3d24e66e981855c987b378ef38e88 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 10 Mar 2025 11:03:52 -0400 Subject: [PATCH 080/100] Avoid using array operations within lax.py operations. --- jax/_src/lax/lax.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bb80f7630..9878ddc3c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3723,10 +3723,11 @@ def _sin_complex(x): # 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2 a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = sn * csh, cs * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(sn, csh), mul(cs, snh) # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) @@ -3752,10 +3753,11 @@ def _cos_complex(x): # see also _sin_complex a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = cs * csh, -sn * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) def _cos_lowering(ctx, x): @@ -3769,28 +3771,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) +ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') -ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin)) acos_p = standard_unop(_float | _complex, 'acos') -ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x)))))) mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos)) def atan_impl(x): return atan2(x, _const(x, 1)) atan_p = standard_unop(_float | _complex, 'atan') -ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) +ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x)))) mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan)) atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2') ad.defjvp(atan2_p, - lambda g, x, y: g * (y / (square(x) + square(y))), - lambda g, x, y: g * -x / (square(x) + square(y))) + lambda g, x, y: mul(g, div(y, add(square(x), square(y)))), + lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y))))) mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2)) sinh_p = standard_unop(_float | _complex, 'sinh') @@ -3802,17 +3804,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh)) asinh_p = standard_unop(_float | _complex, 'asinh') -ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) +ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x))))) mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh)) acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, - lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) + lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x)))))) mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh)) atanh_p = standard_unop(_float | _complex, 'atanh') ad.defjvp(atanh_p, - lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) + lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x)))) mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh)) real_p = unop(_complex_basetype, _complex, 'real') @@ -3906,11 +3908,11 @@ def _square_complex(x): a, b = real(x), imag(x) # zero square(x).real is handled explicitly for abs(a)==abs(b) cases # where for finite a, 2 * a is non-finite: - zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b))) # equivalent to a**2 - b**2 but avoids overflow errors for large a # and large b cases: - re = (a - b) * (a + b) - im = a * b * 2 + re = mul(sub(a, b), add(a, b)) + im = mul(mul(a, b), _const(a, 2)) return select(zero_re, complex(_const(a, 0), im), complex(re, im)) def _square_lower_hlo(ctx, x): @@ -5276,7 +5278,7 @@ def _ragged_dot_jvp_rule( if type(dy) is not ad_util.Zero else _zeros(primal_out) ) - tangent_out = dx_out + dy_out + tangent_out = add(dx_out, dy_out) return primal_out, tangent_out From 21884d4a14d364c3b82b312a668079c668cc2836 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 10 Mar 2025 08:17:07 -0700 Subject: [PATCH 081/100] Move (most) jaxlib linalg custom call registration into JAX. My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX. It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future. This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that. PiperOrigin-RevId: 735381736 --- jax/_src/lax/linalg.py | 27 ++++++++++--- jax/experimental/sparse/_base.py | 10 +++++ jaxlib/gpu_linalg.py | 31 ++++++++------- jaxlib/gpu_solver.py | 67 +++++++++++++++----------------- jaxlib/gpu_sparse.py | 22 +++++------ jaxlib/lapack.py | 24 ++++++------ 6 files changed, 100 insertions(+), 81 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index abd104293..c674401fb 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -44,6 +44,10 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int +from jax._src.lib import gpu_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import gpu_sparse +from jax._src.lib import lapack from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo @@ -51,12 +55,23 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec as P from jax._src.typing import Array, ArrayLike -# The following imports may be unused but they are needed to register the -# custom call targets defined in each module. -from jax._src.lib import gpu_linalg # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_solver # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import gpu_sparse # pylint:disable=unused-import # noqa: F401 -from jax._src.lib import lapack # pylint:disable=unused-import # noqa: F401 + +def register_module_custom_calls(module): + if hasattr(module, "registrations"): + for platform, targets in module.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + if hasattr(module, "batch_partitionable_targets"): + for name in module.batch_partitionable_targets(): + ffi.register_ffi_target_as_batch_partitionable(name) + + +register_module_custom_calls(gpu_linalg) +register_module_custom_calls(gpu_solver) +register_module_custom_calls(gpu_sparse) +register_module_custom_calls(lapack) # Top-level functions in alphabetical order. diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 36d84cb0d..7739af029 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,8 +19,18 @@ import math import jax from jax._src import core +from jax._src import ffi from jax._src import util from jax._src.typing import Array +from jax._src.lib import gpu_sparse + + +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) class JAXSparse(util.StrictABC): diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 1acfbaf22..c747c0abb 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -12,25 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") -if _cuda_linalg: - for _name, _value in _cuda_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="CUDA", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "cu_lu_pivots_to_permutation") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _hip_linalg: - for _name, _value in _hip_linalg.registrations().items(): - xla_client.register_custom_call_target( - _name, _value, platform="ROCM", api_version=1 - ) - xla_client.register_custom_call_as_batch_partitionable( - "hip_lu_pivots_to_permutation") + +def batch_partitionable_targets() -> list[str]: + targets = [] + if _cuda_linalg: + targets.append("cu_lu_pivots_to_permutation") + if _hip_linalg: + targets.append("hip_lu_pivots_to_permutation") + return targets diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index a40c6bf93..efb58f9a4 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jaxlib import xla_client +from typing import Any from .plugin_support import import_from_plugin @@ -24,45 +24,39 @@ _hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") -if _cublas: - for _name, _value in _cublas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") -if _cusolver: - for _name, _value in _cusolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: + if module: + registrations[platform].extend( + (*i, 0) for i in module.registrations().items()) + for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items() + ) + for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: + if module: + registrations[platform].extend( + (*i, 1) for i in module.registrations().items()) + return registrations # pytype: disable=bad-return-type -if _cuhybrid: - for _name, _value in _cuhybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=1) -if _hipblas: - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") +def batch_partitionable_targets() -> list[str]: + targets = [] + for module in [_cusolver, _hipsolver]: + if module: + targets.extend( + name for name in module.registrations() + if name.endswith("_ffi") + ) + for module in [_cuhybrid, _hiphybrid]: + if module: + targets.extend(name for name in module.registrations()) + return targets -if _hipsolver: - for _name, _value in _hipsolver.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - -if _hiphybrid: - for _name, _value in _hiphybrid.registrations().items(): - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=1) def initialize_hybrid_kernels(): if _cuhybrid: @@ -70,6 +64,7 @@ def initialize_hybrid_kernels(): if _hiphybrid: _hiphybrid.initialize() + def has_magma(): if _cuhybrid: return _cuhybrid.has_magma() diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d397557df..d8645041c 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -17,13 +17,12 @@ cusparse wrappers for performing sparse matrix computations in JAX import math from functools import partial +from typing import Any import jaxlib.mlir.ir as ir import numpy as np -from jaxlib import xla_client - from .hlo_helpers import custom_call, mk_result_types_and_shapes from .plugin_support import import_from_plugin @@ -31,17 +30,14 @@ from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") -if _cusparse: - for _name, _value in _cusparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hipsparse: - for _name, _value in _hipsparse.registrations().items(): - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations # pytype: disable=bad-return-type cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index c5a59e314..330fcb992 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -12,23 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +from typing import Any -from jaxlib import xla_client +import numpy as np from .cpu import _lapack from .cpu._lapack import eig from .cpu._lapack import schur -for _name, _value in _lapack.registrations().items(): - api_version = 0 - if _name.endswith("_ffi"): - api_version = 1 - xla_client.register_custom_call_as_batch_partitionable(_name) - xla_client.register_custom_call_target( - _name, _value, platform="cpu", api_version=api_version - ) - EigComputationMode = eig.ComputationMode SchurComputationMode = schur.ComputationMode @@ -43,6 +34,17 @@ LAPACK_DTYPE_PREFIX = { } +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + return {"cpu": [ + (name, value, int(name.endswith("_ffi"))) + for name, value in _lapack.registrations().items() + ]} + + +def batch_partitionable_targets() -> list[str]: + return [name for name in _lapack.registrations() if name.endswith("_ffi")] + + def prepare_lapack_call(fn_base, dtype): """Initializes the LAPACK library and returns the LAPACK target name.""" _lapack.initialize() From d2bf034c4731902fc1f88cf1c1cc07a3b8f286ee Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 10 Mar 2025 08:25:03 -0700 Subject: [PATCH 082/100] [Mosaic GPU] Test the wgmma_op lowering when a is in registers. I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test. PiperOrigin-RevId: 735384104 --- .../mosaic/gpu/dialect_lowering.py | 32 +++++++++++++------ tests/mosaic/gpu_test.py | 21 ++++++++++-- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index eb5dcbc39..368b47df4 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -235,11 +235,6 @@ def _vector_load_op_lowering_rule( ir.ArrayAttr, vector_load_op.attributes["out_layouts"] ) - if not layouts.is_strided_fragmented_layout(out_layout_attr): - raise ValueError( - f"{vector_load_op} has an unsupported layout: {out_layout_attr}" - ) - for i in vector_load_op.indices: index_defining_op = i.owner.opview if ( @@ -254,10 +249,29 @@ def _vector_load_op_lowering_rule( element_type = vector_load_op.result.type.element_type is_signed = False if ir.IntegerType.isinstance(element_type) else None - strided_layout = layouts.from_strided_fragmented_layout_attr(out_layout_attr) - fragmented_array = fa.FragmentedArray.load_strided( - vector_load_op.base, is_signed=is_signed, vec_size=strided_layout.vec_size - ) + + if layouts.is_strided_fragmented_layout(out_layout_attr): + strided_layout = layouts.from_strided_fragmented_layout_attr( + out_layout_attr + ) + fragmented_array = fa.FragmentedArray.load_strided( + vector_load_op.base, + is_signed=is_signed, + vec_size=strided_layout.vec_size, + ) + elif layouts.is_wgmma_fragmented_layout(out_layout_attr): + layout = ir.MemRefType(vector_load_op.base.type).layout + swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) + transformed_ref = transform_memref(vector_load_op.base, transforms) + fragmented_array = fa.FragmentedArray.load_tiled( + transformed_ref, + swizzle=swizzle, + is_signed=is_signed + ) + else: + raise ValueError( + f"{vector_load_op} has an unsupported layout: {out_layout_attr}" + ) return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)] diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index a98e3f9bb..cc654eb2b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2755,6 +2755,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): transforms_b: tuple[Tile | Transpose | Swizzle, ...] = () transpose_a: bool = False transpose_b: bool = False + load_a_in_registers: bool = False result = [] for swizzle in [ @@ -2786,6 +2787,13 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): transforms_a=[Tile([64, k]), Swizzle(swizzle)], transforms_b=[Tile([k, k]), Swizzle(swizzle)], ), + TestCaseInput( + shape_a=[groups_m * 64, groups_k * k], + shape_b=[groups_k * k, groups_n * k], + shape_res=[groups_m * 64, groups_n * k], + transforms_a=[Tile([64, k]), Swizzle(swizzle)], + load_a_in_registers=True, + ), ]) # The below only works for 128-byte swizzling. Regardless of transposing, # TMA needs the size of the last dimension to be compatible with the @@ -2849,6 +2857,14 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) + # SMEM -> Registers + a_operand = a_smem_ref + zero_index = arith.constant(ir.IndexType.get(), 0) + if test_case.load_a_in_registers: + a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type) + zero_vector_indices = [zero_index] * len(test_case.shape_a) + a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices) + # Computation shape_result = ir.MemRefType(result_gmem_ref.type).shape result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type @@ -2860,7 +2876,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): ) result = mgpu_dialect.wgmma( accumulator, - a_smem_ref, + a_operand, b_smem_ref, transpose_a=test_case.transpose_a, transpose_b=test_case.transpose_b, @@ -2870,8 +2886,7 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): nvvm.wgmma_wait_group_sync_aligned(0) # Registers -> SMEM - zero_index = arith.constant(ir.IndexType.get(), 0) - vector.store(result, result_smem_ref, [zero_index, zero_index]) + vector.store(result, result_smem_ref, [zero_index] * len(shape_result)) # SMEM -> GMEM mgpu_dialect.async_store( From 1bab037ca0b3ae29a0fbf1c8b784937c8f299076 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 10 Mar 2025 17:37:00 +0100 Subject: [PATCH 083/100] Add file and zip to tsan.yaml --- .github/workflows/tsan.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 1de765df0..2940d3dd2 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -35,7 +35,7 @@ jobs: apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ - libffi-dev liblzma-dev + libffi-dev liblzma-dev file zip - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax From 5cb29949d413cea6adbf936ff249049fa81ba1c6 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 10 Mar 2025 10:37:11 -0700 Subject: [PATCH 084/100] Warn the user if transparent huge pages aren't enabled. PiperOrigin-RevId: 735431881 --- jax/_src/cloud_tpu_init.py | 15 ++++++++++- jax/_src/hardware_utils.py | 53 ++++++++++++++++++++++++++------------ pyproject.toml | 3 +++ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 82219886b..0539e4253 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,6 +15,7 @@ import datetime import os import re +import warnings from jax import version from jax._src import config from jax._src import hardware_utils @@ -72,7 +73,19 @@ def cloud_tpu_init() -> None: # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. libtpu_path = get_tpu_library_path() - num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] + num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if ( + tpu_id is not None + and tpu_id >= hardware_utils.TpuVersion.v5e + and not hardware_utils.transparent_hugepages_enabled() + ): + warnings.warn( + 'Transparent hugepages are not enabled. TPU runtime startup and' + ' shutdown time should be significantly improved on TPU v5e and newer.' + ' If not already set, you may need to enable transparent hugepages in' + ' your VM image (sudo sh -c "echo always >' + ' /sys/kernel/mm/transparent_hugepage/enabled")' + ) if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 81ef07a71..84ad9edf9 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -12,25 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum import os import pathlib import glob _GOOGLE_PCI_VENDOR_ID = '0x1ae0' -_TPU_PCI_DEVICE_IDS = [ - # TPU v2, v3 - '0x0027', - # No public name (plc) - '0x0056', - # TPU v4 - '0x005e', - # TPU v5p - '0x0062', - # TPU v5e - '0x0063', - # TPU v6e - '0x006f', -] _NVIDIA_GPU_DEVICES = [ '/dev/nvidia0', @@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [ '/dev/dxg', # WSL2 ] + +class TpuVersion(enum.IntEnum): + # TPU v2, v3 + v2 = 0 + v3 = 1 + # No public name (plc) + plc = 2 + # TPU v4 + v4 = 3 + # TPU v5p + v5p = 4 + # TPU v5e + v5e = 5 + # TPU v6e + v6e = 6 + + +_TPU_PCI_DEVICE_IDS = { + '0x0027': TpuVersion.v3, + '0x0056': TpuVersion.plc, + '0x005e': TpuVersion.v4, + '0x0062': TpuVersion.v5p, + '0x0063': TpuVersion.v5e, + '0x006f': TpuVersion.v6e, +} + def num_available_tpu_chips_and_device_id(): """Returns the device id and number of TPU chips attached through PCI.""" num_chips = 0 - device_id = '' + tpu_version = None for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'): vendor_id = pathlib.Path(vendor_path).read_text().strip() if vendor_id != _GOOGLE_PCI_VENDOR_ID: @@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id(): device_path = os.path.join(os.path.dirname(vendor_path), 'device') device_id = pathlib.Path(device_path).read_text().strip() if device_id in _TPU_PCI_DEVICE_IDS: + tpu_version = _TPU_PCI_DEVICE_IDS[device_id] num_chips += 1 - return num_chips, device_id + return num_chips, tpu_version def has_visible_nvidia_gpu() -> bool: """True if there's a visible nvidia gpu available on device, False otherwise.""" return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES) + + +def transparent_hugepages_enabled() -> bool: + # See https://docs.kernel.org/admin-guide/mm/transhuge.html for more + # information about transparent huge pages. + path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled') + return path.exists() and path.read_text().strip() == '[always] madvise never' diff --git a/pyproject.toml b/pyproject.toml index e32b14a89..a1b9e7dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ filterwarnings = [ # https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358 "ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning", + # TODO(b/401588349): Remove this once transparent hugepages are enabled. + "ignore:Transparent hugepages", + # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also # check for warnings and do not check this list. Most likely, you should From d41e96835b310a534205023c5c95d481e7f3a1ac Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 10 Mar 2025 11:09:40 -0700 Subject: [PATCH 085/100] Modify version test to consider "rc" versions as well I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels. Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497 PiperOrigin-RevId: 735444828 --- tests/version_test.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/version_test.py b/tests/version_test.py index 1036d958f..b78e61ae0 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -24,11 +24,15 @@ import jax from jax._src.lib import check_jaxlib_version from jax._src import test_util as jtu -# This is a subset of the full PEP440 pattern; for example we skip pre & post releases +# This is a subset of the full PEP440 pattern; for example we skip post releases VERSION_PATTERN = re.compile(r""" ^ # start of string (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?: + (?:rc(?P[0-9]+))? # optional rc version; like 'rc1' + | # or + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + )? (?:\+(?P[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' $ # end of string """, re.VERBOSE) @@ -170,6 +174,18 @@ class JaxVersionTest(unittest.TestCase): self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE="1", + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3") From 007fc7a6f12dcab3d057c184a64ea05519cf6366 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Mar 2025 11:34:16 -0700 Subject: [PATCH 086/100] Remove version limit for `setuptools` dependency. PiperOrigin-RevId: 735453796 --- build/requirements.in | 3 +-- build/requirements_lock_3_10.txt | 6 +++--- build/requirements_lock_3_11.txt | 6 +++--- build/requirements_lock_3_12.txt | 6 +++--- build/requirements_lock_3_13.txt | 6 +++--- build/test-requirements.txt | 3 +-- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index e122aaa4a..d4e13d943 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -18,5 +18,4 @@ ml_dtypes>=0.4.0 opt_einsum zstandard etils[epath] -# TODO(ybaturina): remove setuptools version -setuptools<71.0.0 +setuptools diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index ccffa247f..290c7e732 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -634,9 +634,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 7f3ee61ff..f73065950 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -623,9 +623,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index bf22c3623..feebc33dc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -623,9 +623,9 @@ zstandard==0.22.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 9fa78c062..0a32888f6 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -747,9 +747,9 @@ zstandard==0.23.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==70.3.0 \ - --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ - --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc +setuptools==76.0.0 \ + --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ + --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 # via # -r build/requirements.in # -r build/test-requirements.txt diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 19d713532..3b36900c0 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -12,8 +12,7 @@ portpicker; python_version<"3.13" pytest-xdist wheel rich -# TODO(ybaturina): remove setuptools version -setuptools<71.0.0 +setuptools # matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement # below. matplotlib~=3.8.4; python_version=="3.10" From 8ecadfdf9d6ed4c4dc181012ce5d8b2db8b12526 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 10 Mar 2025 11:37:50 -0700 Subject: [PATCH 087/100] Internal: make it easier to detect the vmap sentinel --- jax/_src/api_util.py | 5 ++++- tests/api_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index b9fd505e3..a42141b96 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -382,6 +382,9 @@ def is_hashable(arg): return False +SENTINEL = object() + + def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # given an axis spec tree axis_tree (a pytree with integers and Nones at the # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of @@ -389,7 +392,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # and return the flattened result # TODO(mattjj,phawkins): improve this implementation proxy = object() - dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) + dummy = tree_unflatten(treedef, [SENTINEL] * treedef.num_leaves) axes = [] add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) try: diff --git a/tests/api_test.py b/tests/api_test.py index 35e92f748..c9cf28e0a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -20,6 +20,7 @@ from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy +import dataclasses import enum import functools from functools import partial @@ -1986,6 +1987,37 @@ class APITest(jtu.JaxTestCase): ): jax.vmap(f)(jnp.ones(4), jnp.ones(2), jnp.ones(2)) + def test_vmap_sentinel(self): + + @jax.tree_util.register_dataclass + @dataclasses.dataclass + class Foo: + x: jax.Array + + def __init__(self, x): + nonlocal saw_sentinel + if x is jax._src.api_util.SENTINEL: + saw_sentinel += 1 + self.x = x + + x = jnp.arange(10) + + # assert that sentinel is seen once for vmap in_axes + saw_sentinel = 0 + jax.vmap(lambda f: f.x)(Foo(x)) + self.assertEqual(saw_sentinel, 1) + + # assert that sentinel is seen once for vmap out_axes + saw_sentinel = 0 + jax.vmap(Foo)(x) + self.assertEqual(saw_sentinel, 1) + + # assert that sentinel is seen twice with vmap in_axes and out_axes + saw_sentinel = 0 + jax.vmap(lambda f: Foo(f.x + 1))(Foo(x)) + self.assertEqual(saw_sentinel, 2) + + def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) From 73d20cd62ac58b0df2cb8409362352b2bd711dd9 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 10 Mar 2025 11:39:14 -0700 Subject: [PATCH 088/100] [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args). PiperOrigin-RevId: 735455671 --- jax/_src/pallas/mosaic/interpret.py | 4 +++- tests/pallas/tpu_pallas_interpret_test.py | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 4f2d0b4f3..e2086d6af 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1346,7 +1346,9 @@ def interpret_pallas_call( output_buffer_ids = [] output_buffer_shapes = [] output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, args, input_output_aliases) + grid_mapping.block_mappings_output, + scalars + input_args, + input_output_aliases) num_outputs = grid_mapping.num_outputs output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] for out_val, bs in zip(output_vals, output_block_shapes): diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 632dba949..1219f37fb 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -68,25 +68,30 @@ class InterpretTest(jtu.JaxTestCase): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) - def test_dynamic_grid(self): - def kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] + def test_dynamic_grid_and_aliasing(self): + def kernel(s_ref, x_ref, o_ref): + o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) @jax.jit - def f(x): + def f(s, x): return pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=(iters,), - in_specs=(pl.BlockSpec(x.shape, lambda i: (0, 0)),), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(x.shape, lambda i: (0, 0)), + ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), + input_output_aliases={1: 0}, interpret=mosaic_interpret.TPUInterpretParams() - )(x) + )(s, x) + s = jnp.array([1], dtype=jnp.int32) x = jnp.arange(32 * 128.).reshape((32, 128)) - y = f(x) - np.testing.assert_allclose(y, x) + y = f(s, x) + np.testing.assert_allclose(y, x + 1.0) @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): From b6d4fe53876a4a2ec785a3ab55dc7ff5b5fd8812 Mon Sep 17 00:00:00 2001 From: Praveen Narayanan Date: Mon, 10 Mar 2025 12:24:38 -0700 Subject: [PATCH 089/100] Define lax.ragged_dot_general and express lax.ragged_dot in terms of it. PiperOrigin-RevId: 735471245 --- jax/_src/lax/lax.py | 681 +++++++++++++++++++++++------- jax/experimental/jax2tf/jax2tf.py | 1 + jax/lax/__init__.py | 2 + tests/lax_test.py | 223 ++++++++++ 4 files changed, 761 insertions(+), 146 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9878ddc3c..99760099d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -16,6 +16,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence +import dataclasses import enum import functools from functools import partial @@ -2227,10 +2228,122 @@ def ragged_dot( Results: (m, n) shaped array with preferred_element_type element type. """ - return ragged_dot_p.bind(lhs, rhs, group_sizes, - precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, - group_offset=group_offset) + return ragged_dot_general( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) + + +@dataclasses.dataclass(frozen=True) +class RaggedDotDimensionNumbers(): + """Describes ragged, group, and dot dimensions for ragged dot general. + + Args: + dot_dimension_numbers: a tuple of tuples of sequences of ints of the form + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims))`. + lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged + dimensions. + rhs_group_dimensions: a sequence of ints indicating the 'rhs' group + dimensions. + """ + dot_dimension_numbers: DotDimensionNumbers + lhs_ragged_dimensions: Sequence[int] + rhs_group_dimensions: Sequence[int] + + def __init__( + self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions + ): + super().__setattr__( + 'dot_dimension_numbers', + tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers), + ) + super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions)) + super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions)) + + +def _from_maybe_ragged( + dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers, +) -> DotDimensionNumbers: + return ( + dot_dimension_numbers.dot_dimension_numbers + if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers) + else dot_dimension_numbers + ) + + +# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.) +_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], +) + + +def ragged_dot_general( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + group_offset: Array | None = None, +) -> Array: + """Ragged matrix multiplication. + + Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and + a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and + ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally, + ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at + most one group dimension. + + Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has + three modes, depending on the kind of the lhs ragged dimension: + 1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``, + and `x...` are the lhs non-contracting dims outer to the ragged dim. + 2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`. + Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and + ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim. + 3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and + ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim. + If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according + to the rules above. + + Args: + lhs: an array + rhs: an array + group_sizes: an array with integer element type + ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to + specify the dot dimension numbers, lhs ragged dimension, and rhs group + dimension. + precision: Optional. Consistent with precision argument for + :func:`jax.lax.dot`. + preferred_element_type: Optional. Consistent with precision argument for + :func:`jax.lax.dot`. + group_offset: Optional. (1,) shaped array that indicates the group in + group_sizes to start computing from. If not specified, defaults to [0]. + + Results: + An array whose shape is the same as that produced by `dot_general`, with an + extra leading dimension of size `g` in the case where the lhs ragged + dimension is a contracting dimension. + """ + return ragged_dot_general_p.bind( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None @@ -4593,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, out_sharding): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): msg = ("dot_general requires lhs dimension numbers to be nonnegative and " @@ -4654,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers) def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) batch_shape = tuple(lhs_shape[i] for i in lhs_batch) lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch))) lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) - rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch))) - rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch) + rhs_group = () + if isinstance(dimension_numbers, RaggedDotDimensionNumbers): + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_contract_or_batch_or_group = tuple( + sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) + ) + rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group) return batch_shape + lhs_tensored_shape + rhs_tensored_shape @@ -4723,7 +4841,7 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - out_sharding): + out_sharding, name: str = 'lax.dot_general'): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError del dimension_numbers # unused @@ -4744,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, result_dtype = rhs.dtype else: if lhs.dtype != rhs.dtype: - raise TypeError( - f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") + raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}') result_dtype = lhs.dtype has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)) return _maybe_upcast(result_dtype, preferred_element_type, @@ -4884,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): # explicitly present dimensions that this dot_general is zipping together. lbd, rbd = batch_dims assert lbd is not None or rbd is not None - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) + is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers) def bump_dims(dims, b): return tuple(np.add(dims, np.greater_equal(dims, b))) @@ -4908,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): elif (type(rbd) is int and lbd is None): # The right vmapped dimension becomes an additional tensor dimension in the # batched dot_general. - rhs_tensor = [d for d in range(rhs_ndim) - if d not in rhs_batch and d not in rhs_contract] + rhs_tensor = list( + remaining( + range(rhs_ndim), + rhs_batch, + rhs_contract, + dimension_numbers.rhs_group_dimensions if is_ragged_dot else [], + ) + ) result_batch_dim = (lhs_ndim - len(lhs_contract) + int(sum(np.less(rhs_tensor, rbd)))) rhs_batch = bump_dims(rhs_batch, rbd) @@ -4919,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): assert False new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) + if is_ragged_dot: + new_dimension_numbers = RaggedDotDimensionNumbers( + dot_dimension_numbers=new_dimension_numbers, + lhs_ragged_dimensions=bump_dims( + dimension_numbers.lhs_ragged_dimensions, lbd + ), + rhs_group_dimensions=bump_dims( + dimension_numbers.rhs_group_dimensions, rbd + ), + ) return new_dimension_numbers, result_batch_dim def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *, @@ -5010,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims): lbd, rbd = batch_dims return (lbd, rbd) -# DotDimensionNumbers used in the dot_general call for ragged_dot(). -_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([2, 0], [1, 0]), - ([], []), -) -_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([3, 1], [2, 1]), - ([0], [0]), -) ad.defbilinear(dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs) @@ -5186,58 +5311,181 @@ for platform in ["cpu", "tpu"]: platform=platform) -def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape: - if len(lhs.shape) == 3: - # Batched case - b, m, k = lhs.shape - b2, group_count, rk, n = rhs.shape - b3 = group_sizes.shape[0] - if b != b2: - raise TypeError( - f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and' - f' {b2}.' - ) - if b3 != b: - raise TypeError( - 'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got' - f' {b3} and {b}.' - ) - if k != rk: - raise TypeError( - f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and' - f' {rk}.' - ) - num_groups = group_sizes.shape[1] - if group_count != num_groups: - raise TypeError( - 'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got' - f' {group_count} and {num_groups}.' - ) - return (b, m, n) +class RaggedDotMode(enum.Enum): + RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n] + RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n] + RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n] - m, k = lhs.shape - group_count, rk, n = rhs.shape - if k != rk: - raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.") - num_groups = group_sizes.shape[0] - if group_count != num_groups: - raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") - return (m, n) -def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, - **_) -> np.dtype: +def _ragged_dot_mode_and_dim( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> tuple[RaggedDotMode, int]: + assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1 + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + (lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers + lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch) + if lhs_ragged_dim in lhs_noncontracting: + mode = RaggedDotMode.RAGGED_NONCONTRACTING + elif lhs_ragged_dim in lhs_contracting: + mode = RaggedDotMode.RAGGED_CONTRACTING + elif lhs_ragged_dim in lhs_batch: + mode = RaggedDotMode.RAGGED_BATCH + else: + raise TypeError( + f'lhs_ragged_dim {lhs_ragged_dim} not found in ' + f'lhs_noncontracting {lhs_noncontracting}, ' + f'lhs_contracting {lhs_contracting}, or ' + f'lhs_batch {lhs_batch}.' + ) + return mode, lhs_ragged_dim + + +def _ragged_dot_mode( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> RaggedDotMode: + return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0] + + +def _is_ragged_contracting( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> bool: + return ( + _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers) + == RaggedDotMode.RAGGED_CONTRACTING + ) + + +def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract): + batch, contract = map(list, (batch, contract)) + noncontract = remaining(range(rank), contract, batch) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + return batch + noncontract[: noncontract.index(ragged_dim)] + case RaggedDotMode.RAGGED_CONTRACTING: + return batch + contract[: contract.index(ragged_dim)] + case RaggedDotMode.RAGGED_BATCH: + return batch[: batch.index(ragged_dim)] + + +def _ragged_dot_general_shape_rule( + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +): + def _check_in_range(dim, rank, dim_name, arg_name): + if dim < 0 or dim >= rank: + raise TypeError( + f'ragged_dot_general requires {dim_name} numbers to be nonnegative ' + f'and less than the number of axes of the {arg_name} value, ' + f'got {dim} for {arg_name} of rank {rank}.' + ) + + # Validate the lhs ragged dimension, and find out which mode we're in. + if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1: + raise TypeError( + 'ragged_dot_general expects exactly one lhs ragged dimension.' + ) + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs') + mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers) + + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + + # Validate the shape of group_sizes, if it is something other than [g]. + if group_sizes.ndim == 0: + raise TypeError('expected rank of group_sizes to be >=1.') + if group_sizes.ndim != 1: + # Construct the expected shape [b...,x...,g] of group_sizes. + prefix_dims = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting + ) + expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims) + expected_gs_shape += (group_sizes.shape[-1],) + # TODO(pravnar): Permit other broadcastable shapes. + if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape): + raise TypeError( + 'expected group_sizes to have shape ' + f'{expected_gs_shape}, got {group_sizes.shape}.' + ) + num_groups = group_sizes.shape[-1] + + # Validate properties of the rhs group dimension(s). + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + match mode: + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + if len(rhs_group_dims) != 0: + raise TypeError( + 'ragged_dot_general requires zero group dimensions in the rhs ' + 'when lhs ragged dimension is contracting or batch.' + ) + case RaggedDotMode.RAGGED_NONCONTRACTING: + if len(rhs_group_dims) != 1: + raise TypeError( + 'ragged_dot_general requires exactly one rhs group dimension ' + 'when lhs ragged dimension is noncontracting.' + ) + rhs_group_dim = rhs_group_dims[0] + _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs') + if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting: + raise TypeError( + 'ragged_dot_general requires rhs group dimension numbers to be ' + 'distinct from contracting and batch dimensions.' + ) + if rhs.shape[rhs_group_dim] != num_groups: + raise TypeError( + 'expected rhs group dimension size to be ' + f'{num_groups}, got {rhs.shape[rhs_group_dim]}.' + ) + + out_shape = _dot_general_shape_rule( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + ) + if mode == RaggedDotMode.RAGGED_CONTRACTING: + out_shape = (num_groups,) + out_shape + return out_shape + + +def _ragged_dot_general_dtype_rule( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): - raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") - # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. + raise TypeError( + 'ragged_dot_general requires that ' + 'group_sizes.dtype is subtype of np.integer.' + ) + # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl. return _dot_general_dtype_rule( - lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type, - out_sharding=None) + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + name='lax.ragged_dot_general', + ) -def _ragged_dot_jvp_rule( - primals, tangents, precision, preferred_element_type, group_offset +def _ragged_dot_general_jvp_rule( + primals, tangents, ragged_dot_dimension_numbers, + precision, preferred_element_type, group_offset ): # note - we could ostensibly just get this by passing on the # value to ragged_dot below, but, this feels cleaner. @@ -5247,20 +5495,22 @@ def _ragged_dot_jvp_rule( dx, dy, _ = tangents # no tan on the gs # primal - primal_out = ragged_dot( + primal_out = ragged_dot_general( x, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) # tangent dx_out = ( - ragged_dot( + ragged_dot_general( dx, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) @@ -5268,10 +5518,11 @@ def _ragged_dot_jvp_rule( else _zeros(primal_out) ) dy_out = ( - ragged_dot( + ragged_dot_general( x, dy, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) @@ -5283,58 +5534,111 @@ def _ragged_dot_jvp_rule( return primal_out, tangent_out -def _ragged_to_dense(x, y, group_sizes): - from jax._src.lax import control_flow # avoid circular imports - shape = (y.shape[0], x.shape[0], x.shape[1]) - x = broadcast_in_dim(x, shape, [1, 2]) - iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = control_flow.cumsum(group_sizes) - group_starts = concatenate( - [_zeros(group_sizes)[:1], group_ends[:-1]], - dimension=0, - ) - group_ends = broadcast_in_dim(group_ends, shape, (0,)) - group_starts = broadcast_in_dim(group_starts, shape, (0,)) - mask = bitwise_and(group_starts <= iota, iota < group_ends) - x = select(mask, x, _zeros(x)) - return x - - -def _ragged_dot_transpose_rule( - ct, *operands, precision, preferred_element_type, group_offset +def _ragged_dot_general_transpose_rule( + ct, + x, + y, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + group_offset: Array | None, ): - x, y, gs = operands if group_offset is not None: raise NotImplementedError('Unimplemented group_offset support.') - if ad.is_undefined_primal(y): - grad_x = None - else: - y_t = _matrix_transpose(y) - grad_x = ragged_dot( - ct, - y_t, - gs, - precision=precision, - preferred_element_type=preferred_element_type, - ) + (x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers + x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x) + y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y) + x_kept = remaining(range(x_ndim), x_contract, x_batch) + y_group = ragged_dot_dimension_numbers.rhs_group_dimensions + y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group) + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + x_ndim, ragged_dot_dimension_numbers + ) - if ad.is_undefined_primal(x): - grad_y = None - else: - y = y.aval if ad.is_undefined_primal(y) else y - x_dense = _ragged_to_dense(x, y, group_sizes=gs) - ct_dense = _ragged_to_dense(ct, y, group_sizes=gs) - dimension_numbers = (([1], [1]), ([0], [0])) - grad_y = dot_general( - x_dense, - ct_dense, - dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type, - ) + unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError( + f'Unimplemented {fn_name} for ragged dot general in mode ' + f'{ragged_dot_mode.name}.' + ) - return grad_x, grad_y, None + # This is a hack to ensure we continue to emit the `_matrix_transpose` for the + # grad_x case. This isn't strictly necessary since we have dot_dim_nums. + # TODO(pravnar): Remove this once we no longer care to emit the transpose. + _is_basic_ragged_dot = ( + x_ndim == 2 + and y_ndim == 3 + and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS + ) + + def grad_x_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) + dims = ( + ragged_dot_dimension_numbers + if _is_basic_ragged_dot + else RaggedDotDimensionNumbers( + dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)), + lhs_ragged_dimensions=[ + len(x_batch) + x_kept.index(lhs_ragged_dim) + ], + rhs_group_dimensions=y_group, + ) + ) + x_contract_sorted_by_y = list( + np.take(x_contract, np.argsort(y_contract)) + ) + unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_x_dims', mode) + return dims, unsorted_axes + + def grad_y_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept) + dims = RaggedDotDimensionNumbers( + dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)), + lhs_ragged_dimensions=[lhs_ragged_dim], + rhs_group_dimensions=[], + ) + y_contract_sorted_by_x = list( + np.take(y_contract, np.argsort(x_contract)) + ) + unsorted_axes = ( + list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept + ) + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_y_dims', mode) + return dims, unsorted_axes + + def _ragged_dot_grad(lhs, rhs, dims_fn, aval): + dims, unsorted_axes = dims_fn() + ragged_dot_general_out = ragged_dot_general( + lhs, rhs, group_sizes, dims, precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset) + result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes))) + if result.dtype != aval.dtype: + result = _convert_element_type(result, aval.dtype, aval.weak_type) + return result + + x_bar = ( + None + if ad.is_undefined_primal(y) + else _ragged_dot_grad(ct, + _matrix_transpose(y) if _is_basic_ragged_dot else y, + grad_x_dims, + x.aval) + ) + y_bar = ( + None + if ad.is_undefined_primal(x) + else _ragged_dot_grad(x, ct, grad_y_dims, y.aval) + ) + return x_bar, y_bar, None def _ragged_dot_batch_unpack_args(batched_args): @@ -5349,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims): return (lbd, rbd) -def _ragged_dot_invoke_prim( +def _ragged_dot_general_invoke_prim( group_sizes, lhs, rhs, - new_dimension_numbers, + new_ragged_dot_dimension_numbers, precision, preferred_element_type, out_sharding, ): del out_sharding - return ragged_dot( + return ragged_dot_general( lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) -def _ragged_dot_batch_rule( +def _ragged_dot_general_batch_rule( axis_data, batched_args, batch_dims, *, + ragged_dot_dimension_numbers, precision, preferred_element_type: DTypeLike | None, **_, ): - invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2]) - - return _dot_batch_rule( + invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2]) + batched_out, result_batch_dim = _dot_batch_rule( _ragged_dot_batch_unpack_args, _ragged_dot_batch_unpack_dims, invoke, axis_data, batched_args, batch_dims, - dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=None, ) + if _is_ragged_contracting(batched_args[0].ndim - 1, + ragged_dot_dimension_numbers): + result_batch_dim += 1 + return batched_out, result_batch_dim -ragged_dot_p = standard_primitive(_ragged_dot_shape_rule, - _ragged_dot_dtype_rule, 'ragged_dot') -ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p)) -ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule -ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule -batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule -batching.skippable_batchers[ragged_dot_p] = lambda _: () +ragged_dot_general_p = standard_primitive( + _ragged_dot_general_shape_rule, + _ragged_dot_general_dtype_rule, + 'ragged_dot_general', +) +ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule +ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule +batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule +batching.skippable_batchers[ragged_dot_general_p] = lambda _: () -def _ragged_dot_impl( + +def _ragged_dot_general_impl( lhs: Array, rhs: Array, group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, group_offset: Array | None = None, @@ -5412,24 +5725,100 @@ def _ragged_dot_impl( if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - if len(lhs.shape) == 3: - ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS - ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0)) - else: - ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS - ragged_to_dense = _ragged_to_dense + def ragged_to_dense(x: Array, gs: Array, *, dim: int): + from jax._src.lax import control_flow # avoid circular imports + assert gs.ndim == 1 + shape = gs.shape + x.shape + x = broadcast_in_dim(x, shape, list(range(1, len(shape)))) + iota = broadcasted_iota(gs.dtype, shape, dim+1) + group_ends = control_flow.cumsum(gs) + group_starts = concatenate( + [_zeros(gs)[:1], group_ends[:-1]], + dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + x = select(mask, x, _zeros(x)) + return x - lhs = ragged_to_dense(lhs, rhs, group_sizes) + def batched_ragged_to_dense(dim, *x_in_axes: int): + if not x_in_axes: + return partial(ragged_to_dense, dim=dim) + x_axis, *rest = x_in_axes + decr = lambda d: d - 1 if d >= x_axis else d + return api.vmap( + batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]), + in_axes=(x_axis, 0), + ) - return dot_general( - lhs, - rhs, - dimension_numbers=ragged_dot_dims, + incr = lambda dims: [d + 1 for d in dims] + + # Expand the ragged `dim` of `x`, given its batching `axes`. + # The group axis from `gs` becomes the outermost axis of the result. + # Some examples: + # x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k] + # x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k] + def expand(x, dim, gs, *axes): + expanded = batched_ragged_to_dense(dim, *axes)(x, gs) + unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes)) + return transpose(expanded, np.argsort(unsorted_dims)) + + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + lhs.ndim, ragged_dot_dimension_numbers + ) + (l_contract, r_contract), (l_batch, r_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + l_prefix = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract + ) + + _dot_general = partial( + dot_general, precision=precision, preferred_element_type=preferred_element_type, ) + # TODO(pravnar): Permit other broadcastable shapes. + if group_sizes.ndim == 1: + group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix]) -mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + assert len(rhs_group_dims) == 1 + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + rhs, + dimension_numbers=( + (incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]), + (incr(l_batch), r_batch), + ), + ) + case RaggedDotMode.RAGGED_CONTRACTING: + rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)] + r_prefix = _ragged_dot_prefix_dims( + mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract + ) + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix), + dimension_numbers=( + (incr(l_contract), incr(r_contract)), + ([0] + incr(l_batch), [0] + incr(r_batch)), + ), + ) + case RaggedDotMode.RAGGED_BATCH: + return _dot_general( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + ) + + +mlir.register_lowering(ragged_dot_general_p, + mlir.lower_fun(_ragged_dot_general_impl, + multiple_results=False)) def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index af5ec987e..1809f211f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1541,6 +1541,7 @@ tf_not_yet_impl = [ "assert_consumed_value", "consume", "ragged_dot", + "ragged_dot_general", "cholesky_update", "symmetric_product", "from_edtype", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index a26d15c14..4e376fb66 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -17,6 +17,7 @@ from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, + RaggedDotDimensionNumbers as RaggedDotDimensionNumbers, Precision as Precision, PrecisionLike as PrecisionLike, DotAlgorithm as DotAlgorithm, @@ -158,6 +159,7 @@ from jax._src.lax.lax import ( pow as pow, pow_p as pow_p, ragged_dot as ragged_dot, + ragged_dot_general as ragged_dot_general, real as real, real_p as real_p, reciprocal as reciprocal, diff --git a/tests/lax_test.py b/tests/lax_test.py index 8497bf389..ad6b2a0bc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + "ragged_dot_dimension_numbers": ragged_dot_dimension_numbers, + "err_msg": err_msg, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0, 1], + rhs_group_dimensions=[0], + ), + "ragged_dot_general expects exactly one lhs ragged dimension", + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires lhs ragged dimension numbers to " + "be nonnegative and less than the number of axes of the lhs" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + r"expected group_sizes to have shape \(3,\), got \(2, 3\)", + ), + ( + [19, 17, 11, 5], + [3, 19, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0], [1])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [19, 11, 17, 5], + [19, 17, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])), + lhs_ragged_dimensions=[3], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [17, 19, 11, 5], + [17, 19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(17, 3\), " + r"got \(19, 3\)" + ), + ), + ( + [19, 11, 5], + [19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [1]), ([0], [0])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 3], + [3, 3, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[1], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + "expected rhs group dimension size to be 2, got 3", + ), + ( + [2, 11, 5], + [3, 2, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [2]), ([0], [1])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[], + ), + ( + "ragged_dot_general requires exactly one rhs group dimension " + "when lhs ragged dimension is noncontracting" + ), + ), + ] + ) + def test_ragged_dot_general_shape_inference_failure( + self, lhs_shape, rhs_shape, group_sizes_shape, + ragged_dot_dimension_numbers, err_msg): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + with self.assertRaisesRegex(TypeError, err_msg): + lax.ragged_dot_general(lhs, rhs, group_sizes, + ragged_dot_dimension_numbers) + + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + "ragged_dnums": ragged_dnums, + "out_shape": out_shape, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + (11, 7), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + (3, 11, 7), + ), + ] + ) + def test_ragged_dot_general_shape_inference_success( + self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + self.assertEqual( + lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, + out_shape, + ) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From affe2e734e94918354f1d5d7fe1bb2b2ce5ea9e9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Mar 2025 13:03:59 -0700 Subject: [PATCH 090/100] Rename `dot_with_no_batch_dims_saveable` to `dots_with_no_batch_dims_saveable` for internal consistency PiperOrigin-RevId: 735484326 --- jax/_src/ad_checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 1aa9f17bc..c2868cf7c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -69,7 +69,7 @@ def dots_saveable(prim, *_, **__) -> bool: lax_convolution.conv_general_dilated_p} checkpoint_dots = dots_saveable -def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool: +def dots_with_no_batch_dims_saveable(prim, *_, **params) -> bool: # This is a useful heuristic for transformers. if prim is lax_internal.dot_general_p: (_, _), (lhs_b, rhs_b) = params['dimension_numbers'] @@ -160,8 +160,8 @@ checkpoint_policies = types.SimpleNamespace( nothing_saveable=nothing_saveable, dots_saveable=dots_saveable, checkpoint_dots=dots_saveable, - dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable, - checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable, + dots_with_no_batch_dims_saveable=dots_with_no_batch_dims_saveable, + checkpoint_dots_with_no_batch_dims=dots_with_no_batch_dims_saveable, offload_dot_with_no_batch_dims=offload_dot_with_no_batch_dims, save_anything_except_these_names=save_anything_except_these_names, save_any_names_but_these=save_any_names_but_these, From 87272fbe9373cab40e95e149f708a9527fc4b8d9 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 10 Mar 2025 14:06:35 -0700 Subject: [PATCH 091/100] [Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr PiperOrigin-RevId: 735505460 --- jax/_src/pallas/fuser/jaxpr_fusion.py | 44 ++++++++++++++++++--------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index f98175510..3d36b8f3e 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -32,28 +32,42 @@ def _get_aval(x): return jax_core.raise_to_shaped(jax_core.get_aval(x)) -def fuse(f, *, physicalize: bool = False): +def fuse(f=None, *, physicalize: bool = False, debug: bool = False): """Fuses a function into a single fusable. + Args: + f: The function to fuse. + physicalize: (experimental) whether to physicalize the function. + debug: Whether to print debug information. + There should be a single call to a `fusable` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into the fusable and invoke it. """ - def wrapper(*args, **kwargs): - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - debug_info = api_util.debug_info('fuse', f, args, kwargs) - flat_fun, out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(f, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) - return wrapper + def decorator(f): + def wrapper(*args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + debug_info = api_util.debug_info("fuse", f, args, kwargs) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, debug_info=debug_info), in_tree + ) + flat_avals = [_get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + if debug: + print("Jaxpr before fusion:") + print(jaxpr) + out_tree = out_tree_thunk() + out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + + if physicalize: + wrapper = fusable_dtype.physicalize(wrapper) + return wrapper + + if f is not None: + return decorator(f) + return decorator _fusable: dict[jax_core.Primitive, Any] = {} From 81dde225b06a55ecc46860db9970829dbf917b6c Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 10 Mar 2025 14:22:25 -0700 Subject: [PATCH 092/100] [Pallas/Fuser] Add select_n push rule PiperOrigin-RevId: 735510713 --- jax/_src/pallas/fuser/block_spec.py | 48 +++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 913ef09cd..83b485107 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -25,8 +25,8 @@ from typing import Any, Callable, Protocol, Sequence import jax from jax import lax -from jax._src import api_util from jax._src import ad_util +from jax._src import api_util from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu @@ -351,7 +351,7 @@ def _pull_block_spec( jaxpr.constvars, jaxpr.invars, needed_invars, - jaxpr.eqns[:jaxpr.eqns.index(eqn)], + jaxpr.eqns[: jaxpr.eqns.index(eqn)], debug_info=jaxpr.debug_info, ) scalar_prefetch_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( @@ -426,6 +426,7 @@ def make_kernel_function( return tuple(s for s in shape if s is not None) _no_aval = object() + def _get_block_aval(bs, aval): if bs is pallas_core.no_block_spec or bs is None: return _no_aval @@ -441,10 +442,12 @@ def make_kernel_function( unflat_arg_usages, unflat_kwarg_usages = tree_util.tree_unflatten( in_tree, invar_usages ) + def sds_like(x): if x is _no_aval: return _no_aval return jax.ShapeDtypeStruct(x.shape, x.dtype) + kernel_in_type = jax.tree.map( sds_like, (unflat_in_block_arg_avals, unflat_in_block_kwarg_avals) ) @@ -688,8 +691,10 @@ def _eltwise_eval_rule(prim, ctx, x, **params): def _eltwise_pull_rule( - prim: core.Primitive, ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, - **params + prim: core.Primitive, + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **params, ) -> Sequence[pallas_core.BlockSpec]: del prim, ctx, params return [block_spec] @@ -702,7 +707,9 @@ def _eltwise_usage_rule( return [used_out] -def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.BlockSpec: +def _bcast_block_spec( + block_spec: pallas_core.BlockSpec, i: int +) -> pallas_core.BlockSpec: def new_index_map(i, *args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) @@ -710,7 +717,9 @@ def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core. return idx new_block_shape = util.tuple_update(block_spec.block_shape, i, 1) - return pallas_core.BlockSpec(new_block_shape, functools.partial(new_index_map, i)) + return pallas_core.BlockSpec( + new_block_shape, functools.partial(new_index_map, i) + ) def _binop_usage_rule(prim, ctx, used_out: set[Usage]): @@ -945,7 +954,9 @@ def _dynamic_slice_rule( return block_indices new_block_spec = pallas_core.BlockSpec(block_spec.block_shape, new_index_map) - return [new_block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1) + return [new_block_spec] + [pallas_core.no_block_spec] * ( + len(ctx.avals_in) - 1 + ) @register_eval_rule(lax.concatenate_p) @@ -1348,7 +1359,8 @@ def _push_block_spec_jaxpr( return env[atom] def _write_block_spec( - atom: core.Atom, block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec + atom: core.Atom, + block_spec: pallas_core.BlockSpec | pallas_core.NoBlockSpec, ): if isinstance(atom, core.Literal): return @@ -1374,7 +1386,9 @@ def _push_block_spec_jaxpr( util.safe_map(_write_block_spec, eqn.outvars, out_block_specs) out_block_specs = tuple(util.safe_map(_read_block_spec, jaxpr.outvars)) - valid_block_spec = [bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec][0] + valid_block_spec = [ + bs for bs in flat_block_specs if bs is not pallas_core.no_block_spec + ][0] out_block_specs = tuple( valid_block_spec if obs is pallas_core.no_block_spec else obs for obs in out_block_specs @@ -1491,6 +1505,18 @@ def _convert_element_type_push_rule( return block_spec +@register_push_block_spec_rule(lax.select_n_p) +def _select_n_push_rule( + ctx: PushRuleContext, + *args: pallas_core.BlockSpec, +): + del ctx + block_specs = [b for b in args if b is not pallas_core.no_block_spec] + if len(block_specs) > 1: + raise NotImplementedError('select_n with multiple inputs not supported yet') + return block_specs[0] + + @register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p) def _custom_jvp_call_push_rule( ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_ @@ -1500,9 +1526,7 @@ def _custom_jvp_call_push_rule( @register_push_block_spec_rule(pjit.pjit_p) -def _pjit_push_rule( - ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_ -): +def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_): assert not jaxpr.consts return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs) From 802cb33bf8ac0f36ff49d2d57b3362a8350d7efc Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 10 Mar 2025 14:48:15 -0700 Subject: [PATCH 093/100] [Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest. PiperOrigin-RevId: 735519526 --- tests/pallas/pallas_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9ee0dfc29..faa75d455 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -2125,13 +2125,15 @@ class PallasOutOfBoundsInterpretTest(PallasBaseTest): # TODO(justinfu): This test has low precision on GPU. Improve precision. if jtu.test_device_matches(["gpu"]): atol = 1e-2 + rtol = 5e-3 else: atol = 1e-5 + rtol = 1e-7 # With a masked matmul implementation, uninitialized values will be # masked before computation. This should return the correct result. with self.subTest('MaskedOutputIsCorrect'): - np.testing.assert_allclose(out, expected, atol=atol) + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) class PallasCheckifyTest(PallasBaseTest): From aceae84fab5eb804111520fc3501e7a621d5e4ea Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Mar 2025 15:13:12 -0700 Subject: [PATCH 094/100] [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU. PiperOrigin-RevId: 735527650 --- jax/_src/pallas/mosaic/BUILD | 3 + jax/_src/pallas/mosaic/interpret.py | 110 +++++++++++++++++----- tests/pallas/tpu_pallas_interpret_test.py | 21 +++++ 3 files changed, 108 insertions(+), 26 deletions(-) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 3ef324372..24e834104 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -159,6 +159,9 @@ py_library( ":core", ":primitives", "//jax", + "//jax:core", + "//jax:source_info_util", + "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index e2086d6af..a731bfdfd 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -69,18 +69,24 @@ class TPUInterpretParams: Attributes: dma_execution_mode: If "eager", DMAs are executed as soon as they are - issued. If "on_wait", DMA reads or writes are only executed when a - device is waiting on a DMA semaphore that will be signaled when the read - or write is complete. + issued. If "on_wait", DMA reads or writes are only executed when a device + is waiting on a DMA semaphore that will be signaled when the read or write + is complete. Default: "on_wait". detect_races: If True, a dynamic, happens-before race detector will be used to detect data races during kernel interpretation. If any races are detected, a message will be printed and `races.races_found` will be set to True. Default: False. + skip_floating_point_ops: If True, operations that produce only floating + point values will not be interpreted; instead, their results will be + replaced with arrays all of `jnp.inf`. Additionaly any floating point + operands to any operation will be replaced with (arrays of) `jnp.inf`. + Default: False. """ dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" detect_races: bool = False + skip_floating_point_ops: bool = False VectorClock = np.ndarray @@ -954,16 +960,32 @@ def _is_any(memory_space): return ((memory_space == mosaic_core.TPUMemorySpace.ANY) or (memory_space == pallas_core.MemorySpace.ANY)) +def _is_float(dtype): + return jnp.issubdtype(dtype, jnp.floating) + +_SENTINEL = jnp.inf + +@dataclasses.dataclass(frozen=True) +class Placeholder: + """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`.""" + shape: tuple[int, ...] + dtype: jnp.dtype + def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): env = {} def read(var): if isinstance(var, jax_core.Literal): - return var.val + result = var.val else: - return env[var] + result = env[var] + if isinstance(result, Placeholder): + result = jax.lax.full(result.shape, _SENTINEL, result.dtype) + return result def write(var, value): + if interpret_params.skip_floating_point_ops and _is_float(value.dtype): + value = Placeholder(value.shape, value.dtype) env[var] = value jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) @@ -987,11 +1009,16 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): with source_info_util.user_context( eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): prim = eqn.primitive - invals = jax.util.safe_map(read, eqn.invars) + # We defer reading the values for `eqn.invars` into each of the branches + # of the if-elif-else statement below. This is because the else branch may + # not need to do any reads if `interpret_params.skip_floating_point_ops` + # is True. If this is the case, we want to avoid materializing the read + # array into the jaxpr when this function is traced. + deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( - eqn.params['args_tree'], invals) + eqn.params['args_tree'], deferred_invals()) if mask is not None: raise NotImplementedError('masked load_p') out = callback.io_callback( @@ -1005,7 +1032,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is primitives.swap_p: (ref, transforms, val, mask) = jax.tree.unflatten( - eqn.params['args_tree'], invals) + eqn.params['args_tree'], deferred_invals()) out = callback.io_callback( functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, @@ -1023,6 +1050,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is lax.cond_p: def _make_branch(jaxpr): return lambda *args: _interpret(jaxpr, *args) + invals = deferred_invals() out = lax.switch( invals[0], [_make_branch(branch_jaxpr.jaxpr) @@ -1031,7 +1059,9 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is lax.scan_p: consts, init_carry, xs = split_list( - invals, [eqn.params['num_consts'], eqn.params['num_carry']]) + deferred_invals(), + [eqn.params['num_consts'], eqn.params['num_carry']], + ) def _scan_body(c, a): return split_list( _interpret(eqn.params['jaxpr'].jaxpr, *consts, *c, *a), @@ -1041,8 +1071,10 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): out = carry + out elif prim is lax.while_p: - cond_consts, body_consts, init_vals = split_list( - invals, [eqn.params['cond_nconsts'], eqn.params['body_nconsts']]) + cond_consts, body_consts, init_vals = split_list( + deferred_invals(), + [eqn.params['cond_nconsts'], eqn.params['body_nconsts']], + ) out = lax.while_loop( lambda args: _interpret( eqn.params['cond_jaxpr'].jaxpr, *cond_consts, *args)[0], @@ -1056,6 +1088,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is pjit.pjit_p: def f(*args, jaxpr): return _interpret(jaxpr.jaxpr, *jaxpr.consts, *args) + invals = deferred_invals() in_avals = tuple(jax_core.shaped_abstractify(i) for i in invals) new_jaxpr = _to_jaxpr( lu.wrap_init(functools.partial(f, jaxpr=eqn.params['jaxpr']), @@ -1084,7 +1117,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): primitives.uninitialized_value(v.aval.shape, v.aval.dtype), ordered=True)) - out = _interpret(eqn.params['jaxpr'], *invals, *allocs) + out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) for a in allocs: if isinstance(a, tuple): @@ -1106,6 +1139,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): pass elif prim is state_primitives.get_p: + invals = deferred_invals() out = callback.io_callback( functools.partial(get, source_info=eqn.source_info), eqn.outvars[0].aval, @@ -1116,6 +1150,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): ordered=True) elif prim is state_primitives.swap_p: + invals = deferred_invals() out = callback.io_callback( functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, @@ -1128,11 +1163,17 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): ordered=True) elif prim is mosaic_primitives.dma_start_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) target_device_id = _device_id_to_logical( target_device_id, eqn.params['device_id_type'], axis_sizes) (orig_src_ref, _, orig_dst_ref, *_ @@ -1152,11 +1193,17 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): out = [] elif prim is mosaic_primitives.dma_wait_p: - (src, src_transforms, - dst, dst_transforms, - dst_sem, dst_sem_transforms, - src_sem, src_sem_transforms, - target_device_id) = jax.tree.unflatten(eqn.params['tree'], invals) + ( + src, + src_transforms, + dst, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + target_device_id, + ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) read_shape, read_dtype = _compute_transformed_shape_and_dtype( eqn.invars[0].aval.shape, eqn.invars[0].aval.dtype, src_transforms) callback.io_callback( @@ -1178,7 +1225,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is mosaic_primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( target_device_id, eqn.params['device_id_type'], axis_sizes) callback.io_callback( @@ -1194,7 +1241,7 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): elif prim is mosaic_primitives.semaphore_wait_p: sem, sem_transforms, value = ( - jax.tree.unflatten(eqn.params['args_tree'], invals)) + jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) callback.io_callback( semaphore_wait, (), @@ -1211,8 +1258,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): raise NotImplementedError('atomic_cas_p') else: - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - out = prim.bind(*subfuns, *invals, **bind_params) + if interpret_params.skip_floating_point_ops and all( + _is_float(ovar.aval.dtype) for ovar in eqn.outvars + ): + # Skip `prim.bind` since `prim` only produces floating-point values. + # It is safe to populate `out` with avals since mapping `write` over + # `out` below only relies on the shape and dtype (for writing + # `Placeholder`s). + out = [ovar.aval for ovar in eqn.outvars] + if not prim.multiple_results: + out = out[0] + else: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + out = prim.bind(*subfuns, *deferred_invals(), **bind_params) out = out if prim.multiple_results else [out] jax.util.safe_map(write, eqn.outvars, out) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 1219f37fb..71e91a697 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -134,6 +134,27 @@ class InterpretTest(jtu.JaxTestCase): )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) + def test_skip_floating_point_ops(self): + def matmul_kernel(x_ref, y_ref, z_ref): + z_ref[...] = x_ref[...] @ y_ref[...] + + def matmul(x: jax.Array, y: jax.Array): + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + interpret=mosaic_interpret.TPUInterpretParams( + skip_floating_point_ops=True + ), + )(x, y) + + k1, k2 = jax.random.split(jax.random.key(0)) + x = jax.random.normal(k1, (1024, 1024)) + y = jax.random.normal(k2, (1024, 1024)) + z = jax.jit(matmul)(x, y) + np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) + + lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") + self.assertNotIn("dot_general", lowered) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 988a1208a956c519658087776f6caa20c95658c7 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 10 Mar 2025 16:54:21 -0700 Subject: [PATCH 095/100] Better error message when `raise_if_error()` is called within a traced context PiperOrigin-RevId: 735557928 --- jax/_src/error_check.py | 7 ++++++- tests/error_check_test.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index edfcdf3f7..11e65a7dd 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -88,12 +88,17 @@ def raise_if_error() -> None: """Raise error if an error is set. This function should be called after the computation is finished. It should - be used outside jit. + not be called within a traced context, such as within a jitted function." """ if _error_storage.ref is None: # if not initialized, do nothing return error_code = _error_storage.ref[...] + if isinstance(error_code, core.Tracer): + raise ValueError( + "raise_if_error() should not be called within a traced context, such as" + " within a jitted function." + ) if error_code == jnp.uint32(_NO_ERROR): return _error_storage.ref[...] = jnp.uint32(_NO_ERROR) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 653f901a6..5cdde30b1 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -170,6 +170,26 @@ class ErrorCheckTests(jtu.JaxTestCase): _ = body(init, xs) error_check.raise_if_error() # should not raise error + @parameterized.product(jit=[True, False]) + def test_raise_if_error_fails_in_traced_context(self, jit): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), 1, dtype=jnp.int32) + f(x) + with self.assertRaises( + ValueError, + msg=( + "raise_if_error() should not be called within a traced context," + " such as within a jitted function." + ), + ): + jax.jit(error_check.raise_if_error)() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 02505fa757efd64e31e3e2ddb27ec07a7970c204 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Mar 2025 17:18:49 -0700 Subject: [PATCH 096/100] [Pallas TPU] Remove `next_slot` SMEM tensor from pipeline emitter PiperOrigin-RevId: 735564365 --- jax/_src/pallas/mosaic/pipeline.py | 115 ++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0a79de771..2044d3d18 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -207,6 +207,8 @@ class BufferedRef: is_accumulator: whether this BufferedRef is an accumulator. is_input_output: whether this BufferedRef is an input/output without automatic accumulation. + swap: Tracks whether the BufferedRef slots need to be swapped before next + copy. """ spec: pl.BlockSpec # static metadata dtype: Any # static metadata @@ -214,9 +216,14 @@ class BufferedRef: window_ref: REF | None accum_ref: REF | None current_slot: ArrayRef | None + # TODO(ramiroleal): Unused by class. Remove argument from + # BufferedRef instantiations. next_slot: ArrayRef | None sem_recvs: SemaphoreTuple | None sem_sends: SemaphoreTuple | None + # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid + # using this ref. + swap: ArrayRef | None def tree_flatten(self): return ( @@ -227,6 +234,7 @@ class BufferedRef: self.next_slot, self.sem_recvs, self.sem_sends, + self.swap, ), (self.spec, self.dtype, self.buffer_type), ) @@ -240,7 +248,7 @@ class BufferedRef: return BufferType @classmethod - def create(cls, spec, dtype, buffer_type) -> BufferedRef: + def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: """Create a BufferedRef. Args: @@ -248,6 +256,7 @@ class BufferedRef: dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. + needs_swap_ref: whether a swap slots tracker needs to be allocated. Returns: Initialized BufferedRef @@ -271,6 +280,7 @@ class BufferedRef: next_slot=None, sem_recvs=None, sem_sends=None, + swap=None, ) else: memory_space = SMEM if spec.memory_space == SMEM else VMEM @@ -281,7 +291,7 @@ class BufferedRef: window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), - next_slot=SMEM((1,), jnp.int32), + next_slot=None, sem_recvs=( None if buffer_type is BufferType.OUTPUT @@ -292,23 +302,24 @@ class BufferedRef: if buffer_type is BufferType.INPUT else SemaphoreType.DMA((2,)) ), + swap=SMEM((1,), jnp.bool) if needs_swap_ref else None, ) @classmethod - def input(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.INPUT) + def input(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref) @classmethod - def output(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.OUTPUT) + def output(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref) @classmethod - def accumulator(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.ACCUMULATOR) + def accumulator(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref) @classmethod - def input_output(cls, spec, dtype): - return cls.create(spec, dtype, BufferType.INPUT_OUTPUT) + def input_output(cls, spec, dtype, needs_swap_ref=True): + return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref) @property def block_shape(self): @@ -329,7 +340,7 @@ class BufferedRef: if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: - return self.window_ref.at[(self.current_slot[0], *buffer_slice)] + return self.window_ref.at[(self.current_slot_index, *buffer_slice)] @property def is_input(self): @@ -355,6 +366,14 @@ class BufferedRef: def is_input_output(self): return self.buffer_type == BufferType.INPUT_OUTPUT + @property + def current_slot_index(self): + return self.current_slot[0] + + @property + def next_slot_index(self): + return lax.rem(self.current_slot_index + 1, 2) + def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" if self.memory_space == VMEM: @@ -373,12 +392,15 @@ class BufferedRef: """Initialize slot indices.""" if self.memory_space == VMEM: return self.current_slot[0] = 0 - self.next_slot[0] = 0 + if self.swap is not None: + self.swap[0] = False def swap_slots(self): """Switch to the next slot.""" if self.memory_space == VMEM: return - self.current_slot[0] = self.next_slot[0] + self.current_slot[0] = self.next_slot_index + if self.swap is not None: + self.swap[0] = False def get_dma_slice(self, src_shape, src_dtype, grid_indices): # We need to handle blocks that might go OOB in the src array. An in bounds @@ -441,8 +463,9 @@ class BufferedRef: """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return - next_slot = lax.rem(self.current_slot[0] + 1, 2) - self.next_slot[0] = next_slot + if self.swap is not None: + self.swap[0] = True + next_slot = self.next_slot_index src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( @@ -455,8 +478,9 @@ class BufferedRef: """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return - slot = self.current_slot[0] - self.next_slot[0] = lax.rem(slot + 1, 2) + if self.swap is not None: + self.swap[0] = True + slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( @@ -471,7 +495,7 @@ class BufferedRef: if self.memory_space == VMEM: return src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) - current_slot = self.current_slot[0] + current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter self.window_ref.at[current_slot].at[ @@ -484,7 +508,8 @@ class BufferedRef: """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return - prev_slot = lax.rem(self.current_slot[0] + 1, 2) + # In a double buffer, previous slot is the same as next slot. + prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( @@ -671,10 +696,7 @@ class Scheduler: def _start(): if buffered_ref.is_input: buffered_ref.copy_in(src_ref, self.indices) - - # In the prologue this makes it so we wait on the prologue copy to finish. - # In other iterations this is the regular swap. - buffered_ref.swap_slots() + buffered_ref.swap_slots() def wait_in(self, buffered_ref, src_ref, schedule=None): if schedule is None: @@ -780,9 +802,32 @@ class Scheduler: @self._named_scope("ep_finalize") def _end(): if buffered_ref.is_output: - buffered_ref.swap_slots() # formally correct, not actually necessary. buffered_ref.wait_out(dst_ref, self.indices) + def swap_slots(self, buffered_ref, hbm_ref, schedule=None): + if buffered_ref.swap is not None: + swap = buffered_ref.swap[0] + else: + # If we are not using an SMEM `swap` tensor to keep track of + # swaps needed, then all the copies into and out of BufferedRefs + # are done by direct calls to the `copy_in` and `copy_out` + # methods in the pipeline loop. To determine if the BufferedRef + # needs a swap of slots, we recalculate the copy-in/copy-out + # conditions. + if schedule is None: + schedule = _default_schedule + pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) + pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) + + copied_in = pred_in & buffered_ref.is_input & ~self.last_step + copied_out = pred_out & buffered_ref.is_output + swap = copied_in | copied_out + + @pl.when(swap) + @self._named_scope("ep_swap") + def _swap(): + buffered_ref.swap_slots() + # END SCHEDULE -------------------------------------------------------------- @@ -875,6 +920,7 @@ def make_pipeline_allocations( in_specs=None, out_specs=None, should_accumulate_out=False, + needs_swap_ref=True, ): """Create BufferedRefs for the pipeline. @@ -887,6 +933,7 @@ def make_pipeline_allocations( out_specs: output pallas block specs should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. + needs_swap_ref: whether a swap slots tracker needs to be allocated. Returns: A list of BufferedRefs, one corresponding to each ref specified in the @@ -905,12 +952,12 @@ def make_pipeline_allocations( in_refs = refs[:num_in_specs] out_refs = refs[num_in_specs:] def make_input_bref(in_spec, in_ref): - return BufferedRef.input(in_spec, in_ref.dtype) + return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref) in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs) def make_output_bref(out_spec, out_ref, accumulate): if accumulate: - return BufferedRef.accumulator(out_spec, out_ref.dtype) - return BufferedRef.output(out_spec, out_ref.dtype) + return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref) + return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref) out_brefs = jax.tree.map( make_output_bref, out_specs, out_refs, should_accumulate_out) return (*in_brefs, *out_brefs) @@ -1109,6 +1156,14 @@ def emit_pipeline( scratches = () if allocations is None: # run with inline scoped allocations + + # Prefetch and postyeet are arbitrary functions that can copy + # into or out of any of the BufferedRefs. Thus, we need a ref + # for the scheduler to mark when the prefetch or postyeet + # functions perform a copy and the slots need to be + # swapped. Without prefetch and postyeet, the swapping logic can + # be performed without the need for state. + needs_swap_ref = prefetch is not None or postyeet is not None return primitives.run_scoped( lambda allocations: pipeline( *refs, @@ -1125,7 +1180,9 @@ def emit_pipeline( *refs, in_specs=in_specs, out_specs=out_specs, - should_accumulate_out=should_accumulate_out), + should_accumulate_out=should_accumulate_out, + needs_swap_ref=needs_swap_ref, + ), ) if isinstance(allocations, list): allocations = tuple(allocations) @@ -1184,6 +1241,8 @@ def emit_pipeline( lax.cond(step == 0, lambda: postyeet(*brefs, scheduler), lambda: None) + + map_brefs(scheduler.swap_slots, brefs, refs, schedule) map_brefs(scheduler.finalize, brefs, refs, schedule) return _next_index(indices, grid) From 76dec382861a96f6e4e7b3c7dd00f60254e08ded Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 10 Mar 2025 20:20:20 -0700 Subject: [PATCH 097/100] Under pjit the `with mesh:` context will use `use_mesh(mesh): jit` instead of tracking separately using `resource_env`. This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested. This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally. PiperOrigin-RevId: 735602187 --- jax/_src/checkify.py | 4 +- jax/_src/custom_partitioning.py | 3 +- jax/_src/interpreters/pxla.py | 3 +- jax/_src/pjit.py | 129 ++++++++++++--------------- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/sparse/transform.py | 4 +- tests/pjit_test.py | 28 +----- 7 files changed, 65 insertions(+), 108 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5a6561afa..74ea53714 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -901,7 +901,7 @@ error_checks[lax.while_p] = while_loop_error_check def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, inline, keep_unused, + donated_invars, ctx_mesh, name, inline, keep_unused, compiler_options_kvs): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) @@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, out_shardings=new_out_shardings, in_layouts=new_in_layouts, out_layouts=new_out_layouts, - resource_env=resource_env, donated_invars=new_donated_invars, + ctx_mesh=ctx_mesh, name=name, inline=inline, keep_unused=keep_unused, diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index b835c4c83..658a6f7a2 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( *tiled_args ) - if closed_jaxpr.out_avals != tiled_results: + if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != + [(t.shape, t.dtype) for t in tiled_results]): raise ValueError( "Mismatch in result shapes. %s vs %s" % (repr(closed_jaxpr.out_avals), repr(tiled_results)) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1b2d85006..f97ee5414 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1663,7 +1663,7 @@ class MismatchType(enum.Enum): elif self.name == 'OUT_SHARDING': return 'explicit output sharding' elif self.name == 'CONTEXT_DEVICES': - return 'devices' + return 'context mesh' return f'{self.name}' @@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys: in_layouts_leaves: tuple[Any, ...] | None = None out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] | None = None - use_resource_env: bool = False compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None @functools.cached_property diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 69cd8e809..06892aa9f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -357,7 +357,6 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env, compiler_options_kvs=jit_info.compiler_options_kvs) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, @@ -544,8 +543,7 @@ class PjitParams(NamedTuple): def _infer_params_impl( fun: Callable, ji: PjitInfo, - pjit_mesh: mesh_lib.Mesh | None, - resource_env: mesh_lib.ResourceEnv | None, + ctx_mesh: mesh_lib.Mesh | None, dbg: core.DebugInfo, args: tuple[Any, ...], kwargs: dict[str, Any], @@ -557,8 +555,8 @@ def _infer_params_impl( raise ValueError( "pjit does not support kwargs when in_shardings is specified.") - if pjit_mesh is not None: - if (ji.backend or ji.device) and not pjit_mesh.empty: + if ctx_mesh is not None: + if (ji.backend or ji.device) and not ctx_mesh.empty: raise ValueError( "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") @@ -590,11 +588,11 @@ def _infer_params_impl( in_shardings_treedef = out_shardings_treedef = treedef else: in_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'in_shardings', 'jit') + _create_sharding_for_array(ctx_mesh, x, 'in_shardings', 'jit') for x in ji.in_shardings_leaves) in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( - _create_sharding_for_array(pjit_mesh, x, 'out_shardings', 'jit') + _create_sharding_for_array(ctx_mesh, x, 'out_shardings', 'jit') for x in ji.out_shardings_leaves) out_shardings_treedef = ji.out_shardings_treedef @@ -652,8 +650,8 @@ def _infer_params_impl( out_shardings=out_shardings_flat, in_layouts=in_layouts_flat, out_layouts=out_layouts_flat, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=fun_qual_name(flat_fun), keep_unused=ji.keep_unused, inline=ji.inline, @@ -683,38 +681,30 @@ def _infer_params_cached( jit_info: PjitInfo, signature: jax_jit.ArgumentSignature, in_avals: tuple[core.AbstractValue, ...], - pjit_mesh: mesh_lib.Mesh | None, - resource_env: mesh_lib.ResourceEnv | None, + ctx_mesh: mesh_lib.Mesh | None, ) -> InferParamsCacheEntry: return InferParamsCacheEntry() -def disallow_use_mesh_and_legacy_mesh_ctx_mgr_together(): - if (not mesh_lib.thread_resources.env.physical_mesh.empty and - mesh_lib.get_concrete_mesh() is not None): - raise ValueError( - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed.') def _infer_params( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] -) -> tuple[PjitParams, list[Any]]: - disallow_use_mesh_and_legacy_mesh_ctx_mgr_together() + ) -> tuple[PjitParams, list[Any]]: if ji.use_resource_env: - # We need to fetch the mesh from inside the wrapped function, because - # meshes are dynamically scoped (i.e., with a context manager). - resource_env = mesh_lib.thread_resources.env - pjit_mesh = resource_env.physical_mesh - else: - resource_env = None - pjit_mesh = None + with mesh_lib.use_mesh(mesh_lib.thread_resources.env.physical_mesh): + return _infer_params_internal(fun, ji, args, kwargs) + return _infer_params_internal(fun, ji, args, kwargs) +def _infer_params_internal( + fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[PjitParams, list[Any]]: + ctx_mesh = mesh_lib.get_concrete_mesh() dbg = debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache - p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, dbg, + p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=None) return p, p.consts + args_flat @@ -722,10 +712,11 @@ def _infer_params( args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, ji.static_argnames, tree_util.default_registry) avals = _infer_input_type(fun, dbg, dynargs) - entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env) + entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) + if entry.pjit_params is None: p, args_flat = _infer_params_impl( - fun, ji, pjit_mesh, resource_env, dbg, args, kwargs, in_avals=avals) + fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) if p.attrs_tracked: # if attrs, don't popoulate the cache return p, p.consts + args_flat entry.pjit_params = p @@ -1616,7 +1607,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, - out_layouts, resource_env, donated_invars, name, keep_unused, inline, + out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs): in_shardings = _resolve_in_shardings(args, in_shardings) @@ -1624,8 +1615,8 @@ def _resolve_and_lower( jaxpr.in_avals) out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals) return _pjit_lower( - jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, name, keep_unused, inline, compiler_options_kvs, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) @@ -1634,7 +1625,7 @@ _pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: @@ -1659,8 +1650,8 @@ def _pjit_call_impl_python( compiled = _resolve_and_lower( args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, - out_layouts=out_layouts, resource_env=resource_env, - donated_invars=donated_invars, name=name, keep_unused=keep_unused, + out_layouts=out_layouts, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), pgle_profiler=pgle_profiler, @@ -1691,7 +1682,7 @@ def _pjit_call_impl_python( @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, - out_layouts, resource_env, donated_invars, name, + out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to @@ -1705,14 +1696,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _pjit_call_impl(*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled, pgle_profiler = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, - out_layouts=out_layouts, resource_env=resource_env, - donated_invars=donated_invars, name=name, keep_unused=keep_unused, + out_layouts=out_layouts, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, @@ -1721,7 +1712,7 @@ def _pjit_call_impl(*args, jaxpr, f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) cache_key = pxla.JitGlobalCppCacheKeys( @@ -1730,8 +1721,7 @@ def _pjit_call_impl(*args, jaxpr, in_shardings_treedef=None, in_shardings_leaves=in_shardings, out_shardings_treedef=None, out_shardings_leaves=out_shardings, in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts, - use_resource_env=resource_env is not None) + out_layouts_treedef=None, out_layouts_leaves=out_layouts) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], cache_key, tree_util.dispatch_registry, pxla.cc_shard_arg, @@ -1746,8 +1736,8 @@ def _pjit_lower( out_shardings, in_layouts: pxla.MaybeLayout, out_layouts: pxla.MaybeLayout, - resource_env, donated_invars, + ctx_mesh, name: str, keep_unused: bool, inline: bool, @@ -1757,12 +1747,10 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") - mesh = (resource_env.physical_mesh if resource_env is not None else - mesh_lib.get_concrete_mesh()) return pxla.lower_sharding_computation( jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), - keep_unused=keep_unused, context_mesh=mesh, + keep_unused=keep_unused, context_mesh=ctx_mesh, compiler_options_kvs=compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, @@ -1914,8 +1902,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, jaxpr: core.ClosedJaxpr, in_shardings, - out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, keep_unused, inline, compiler_options_kvs): + out_shardings, in_layouts, out_layouts, donated_invars, + ctx_mesh, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1945,23 +1933,20 @@ def _pjit_batcher(axis_data, vals_in, dims_in: tuple[int, ...], jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) - if resource_env is not None: - mesh = resource_env.physical_mesh - else: - mesh = None - # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( - _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh, + aval.ndim) if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh, + aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) # TODO(yashkatariya): Figure out layouts should change under vmap. @@ -1977,8 +1962,8 @@ def _pjit_batcher(axis_data, vals_in, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2000,8 +1985,8 @@ def _insert_axis_partitions(spec, dim, val): def _pjit_batcher_for_sharding( s: Sharding | UnspecifiedValue, - dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh, - ndim: int): + dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, + mesh, ndim: int): if isinstance(s, UnspecifiedValue): return s hlo_s = s._to_xla_hlo_sharding(ndim) @@ -2040,7 +2025,7 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr( @@ -2057,8 +2042,8 @@ def _pjit_jvp(primals_in, tangents_in, out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)), in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)), out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)), - resource_env=resource_env, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2074,7 +2059,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp def _pjit_linearization(nzs, *primals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) # constvars will become residuals. Move them to the end of the ordinary args. @@ -2090,8 +2075,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, out_shardings=_filter_zeros(nzs_out, out_shardings), in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts, out_layouts=_filter_zeros(nzs_out, out_layouts), - resource_env=resource_env, donated_invars=_filter_zeros(nzs, donated_invars) + res_donated, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2110,8 +2095,8 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, out_shardings=(*res_shardings, *out_shardings), in_layouts=in_layouts, out_layouts=(*res_layouts, *out_layouts), - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2126,7 +2111,7 @@ ad.primitive_linearizations[pjit_p] = _pjit_linearization def _pjit_partial_eval(trace: pe.JaxprTrace, *in_tracers, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, + in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] @@ -2193,8 +2178,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), out_shardings=known_out_shardings, in_layouts=keep_where(in_layouts, known_ins), - out_layouts=known_out_layouts, resource_env=resource_env, + out_layouts=known_out_layouts, donated_invars=keep_where(donated_invars, known_ins), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) @@ -2225,9 +2211,9 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, out_shardings=keep_where(out_shardings, unknown_outs), in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), out_layouts=keep_where(out_layouts, unknown_outs), - resource_env=resource_env, donated_invars=(keep_where(donated_invars, unknown_ins) + (False,) * num_residuals), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2313,7 +2299,7 @@ def _pjit_transpose_trace(fun: lu.WrappedFun, def _pjit_transpose(cts_in, *primals_in, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline, + donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -2362,8 +2348,8 @@ def _pjit_transpose(cts_in, *primals_in, out_shardings=transpose_out_shardings, in_layouts=transpose_in_layouts, out_layouts=transpose_out_layouts, - resource_env=resource_env, donated_invars=(False,) * len(primals_and_nz_cts_in), + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, @@ -2447,9 +2433,8 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, del params['out_layouts'] if not params['keep_unused']: del params['keep_unused'] - if (params['resource_env'] is None or - params['resource_env'].physical_mesh.empty): - del params['resource_env'] + if params['ctx_mesh'] is None or params['ctx_mesh'].empty: + del params['ctx_mesh'] if not params['compiler_options_kvs']: del params['compiler_options_kvs'] @@ -2549,8 +2534,6 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint layouts", tree, layouts)) del layouts - disallow_use_mesh_and_legacy_mesh_ctx_mgr_together() - context_mesh = ( mesh_lib.get_abstract_mesh() if mesh_lib.get_concrete_mesh() is not None else mesh_lib.thread_resources.env.physical_mesh) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 1809f211f..7f98ce433 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3573,8 +3573,8 @@ def _pjit(*args: TfVal, in_shardings: Sequence[sharding.Sharding], out_shardings: Sequence[sharding.Sharding], in_layouts, out_layouts, - resource_env: mesh.ResourceEnv, donated_invars, + ctx_mesh, name: str, keep_unused: bool, inline: bool, diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index bd72850bc..582fdf411 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -775,7 +775,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, name, + in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, - resource_env=resource_env, donated_invars=donated_invars, + ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 720a77410..bd7954d60 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1205,8 +1205,7 @@ class PJitTest(jtu.BufferDonationTestCase): with self.assertRaisesRegex( ValueError, r"One of with_sharding_constraint.*Sharding " - r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), " - r"spec=PartitionSpec\(None, 'mdl', None, None\).*\) is only " + r"NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\) is only " "valid for values of rank at least 4, but was applied to a value of rank 1"): pjit_f(jnp.array([1, 2, 3])) @@ -6873,31 +6872,6 @@ class ShardingInTypesTest(jtu.JaxTestCase): ' axis_types are `Auto`'): NamedSharding(mesh, P(P.UNCONSTRAINED)) - def test_use_mesh_legacy_mesh_ctx_mgr_mix_error(self): - mesh = jtu.create_mesh((1, 1), ('x', 'y')) - - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jax.jit(lambda x: x)(jnp.arange(8)) - - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jnp.zeros((8, 2), dtype=jnp.int32) - - x = jnp.arange(8) - with self.assertRaisesRegex( - ValueError, - 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' - ' together is not allowed'): - with jax.sharding.use_mesh(mesh), mesh: - jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) - def test_pspec_einsum_no_context_mesh(self): mesh = jtu.create_mesh((1, 1), ('x', 'y'), axis_types={AxisTypes.Explicit: ('x', 'y')}) From cb2eb15739459d86e7297a058997473600c2f782 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Tue, 11 Mar 2025 03:16:21 -0700 Subject: [PATCH 098/100] PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true. Imported from GitHub PR https://github.com/openxla/xla/pull/22800 Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off. The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here. Copybara import of the project: -- e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay : Change the default value of print_operand_shape_ to false and print_large_constants_ to true. Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off. The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here. -- 7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay : Handle tests -- b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay : Fix more tests -- d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay : Fix more tests Merging this change closes #22800 PiperOrigin-RevId: 735690598 --- tests/memories_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index acb13336e..a08c5f36c 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1834,7 +1834,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase): self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") - self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") + self.assertIn("dynamic-slice-start", compiled_text) compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: From b6da46ecda106c9b78b2ef4704c35d6eb9b6b9b6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 11 Mar 2025 04:41:13 -0700 Subject: [PATCH 099/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fae64d49aa41e774922ca46e94cd754c800b6240. PiperOrigin-RevId: 735709684 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a01d837fc..6710de12a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "efb27eb924fd5d9b20b908a7cadb11d78d2a81a1" -XLA_SHA256 = "b3a3e0df6bd5923d081fc1a96df41f2f29497b329da58b9b992f4345abf21c8b" +XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240" +XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603" def repo(): tf_http_archive( From 7fd32ecc04078ae9a77de25bbe3107aae9303f92 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 11 Mar 2025 07:10:22 -0700 Subject: [PATCH 100/100] [Pallas/Mosaic GPU] Explicitly disable `ops_test` on Mosaic GPU pre-Hopper. PiperOrigin-RevId: 735744473 --- tests/pallas/ops_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 24ce3b722..907f4601a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -283,6 +283,9 @@ class PallasBaseTest(jtu.JaxTestCase): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") + if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Mosaic GPU requires capability >= sm90") super().setUp()