From 47e8effdcea5c17dd9f974f1020cfd6bf4630f76 Mon Sep 17 00:00:00 2001
From: Jacob Burnim <jburnim@google.com>
Date: Tue, 18 Mar 2025 10:59:17 -0700
Subject: [PATCH] Adds option to initialize buffers to NaNs or zeros in TPU
 interpret mode.

---
 jax/_src/pallas/mosaic/interpret.py       | 44 +++++++++++++++++------
 tests/pallas/tpu_pallas_interpret_test.py | 31 ++++++++++++++++
 2 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py
index e92de91f4..1ad7be815 100644
--- a/jax/_src/pallas/mosaic/interpret.py
+++ b/jax/_src/pallas/mosaic/interpret.py
@@ -83,10 +83,15 @@ class TPUInterpretParams:
       replaced with arrays all of `jnp.inf`. Additionaly any floating point
       operands to any operation will be replaced with (arrays of) `jnp.inf`.
       Default: False.
+    uninitialized_memory: If "nan", allocated buffers are initialized to
+      to contain all NaNs (or to their maximum possible value for integers).
+      If "zero", allocated buffers are initialized to all zeros.
+      Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"
 
 
 VectorClock = np.ndarray
@@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
                 jax.ShapeDtypeStruct((), jnp.int16),
                 device_id,
                 TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-                primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+                _uninitialized_value(
+                    v.aval.shape, v.aval.dtype, interpret_params),
                 ordered=True))
 
         out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
 
 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals
 
 def _compute_start_indices(block_mapping, loop_idx, *args):
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                                                             dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.
 
   For example, if values has a shape of (33, 2, 5) with a block_shape of
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value
 
@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]
 
@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
           jax.ShapeDtypeStruct((), jnp.int16),
           device_id,
           TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-          primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+          _uninitialized_value(
+              var.aval.shape, var.aval.dtype, interpret_params),
           ordered=True))
 
   _, input_ids, kernel_output_ids, _  = split_list(
diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py
index 71e91a697..bc589855b 100644
--- a/tests/pallas/tpu_pallas_interpret_test.py
+++ b/tests/pallas/tpu_pallas_interpret_test.py
@@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
     lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
     self.assertNotIn("dot_general", lowered)
 
+  @parameterized.parameters('nan', 'zero')
+  def test_uninitialized_memory(self, uninitialized_memory):
+    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
+      o1_ref[...] = t1_ref[...]
+      o2_ref[...] = t2_ref[...]
+
+    x, y, z = pl.pallas_call(
+        kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
+            jax.ShapeDtypeStruct((8, 128), jnp.int16),
+            jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        ],
+        in_specs=[],
+        scratch_shapes=[
+            pltpu.VMEM((8, 128), jnp.bfloat16),
+            pltpu.VMEM((8, 128), jnp.int16),
+        ],
+        interpret=mosaic_interpret.TPUInterpretParams(
+            uninitialized_memory=uninitialized_memory),
+    )()
+    if uninitialized_memory == 'nan':
+      self.assertTrue(jnp.isnan(x).all())
+      np.testing.assert_equal(np.array(y), 32767)
+      self.assertTrue(jnp.isnan(z).all())
+    if uninitialized_memory == 'zero':
+      np.testing.assert_equal(np.array(x), 0)
+      np.testing.assert_equal(np.array(y), 0)
+      np.testing.assert_equal(np.array(z), 0)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())