diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py
index e92de91f4..1ad7be815 100644
--- a/jax/_src/pallas/mosaic/interpret.py
+++ b/jax/_src/pallas/mosaic/interpret.py
@@ -83,10 +83,15 @@ class TPUInterpretParams:
       replaced with arrays all of `jnp.inf`.  Additionaly any floating point
       operands to any operation will be replaced with (arrays of) `jnp.inf`.
       Default: False.
+    uninitialized_memory: If "nan", allocated buffers are initialized to
+      contain all NaNs (or to their maximum possible value for integers).
+      If "zero", allocated buffers are initialized to all zeros.
+      Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"
 
 
 VectorClock = np.ndarray
@@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
               jax.ShapeDtypeStruct((), jnp.int16),
               device_id,
               TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-              primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+              _uninitialized_value(
+                  v.aval.shape, v.aval.dtype, interpret_params),
               ordered=True))
 
       out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
 
 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals
 
 def _compute_start_indices(block_mapping, loop_idx, *args):
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                                                             dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.
 
   For example, if values has a shape of (33, 2, 5) with a block_shape of
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value
 
@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]
 
@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
         jax.ShapeDtypeStruct((), jnp.int16),
         device_id,
         TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-        primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+        _uninitialized_value(
+            var.aval.shape, var.aval.dtype, interpret_params),
         ordered=True))
 
   _, input_ids, kernel_output_ids, _ = split_list(
diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py
index 71e91a697..bc589855b 100644
--- a/tests/pallas/tpu_pallas_interpret_test.py
+++ b/tests/pallas/tpu_pallas_interpret_test.py
@@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
     lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
     self.assertNotIn("dot_general", lowered)
 
+  @parameterized.parameters('nan', 'zero')
+  def test_uninitialized_memory(self, uninitialized_memory):
+    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
+      o1_ref[...] = t1_ref[...]
+      o2_ref[...] = t2_ref[...]
+
+    x, y, z = pl.pallas_call(
+        kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
+            jax.ShapeDtypeStruct((8, 128), jnp.int16),
+            jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        ],
+        in_specs=[],
+        scratch_shapes=[
+            pltpu.VMEM((8, 128), jnp.bfloat16),
+            pltpu.VMEM((8, 128), jnp.int16),
+        ],
+        interpret=mosaic_interpret.TPUInterpretParams(
+            uninitialized_memory=uninitialized_memory),
+    )()
+    if uninitialized_memory == 'nan':
+      self.assertTrue(jnp.isnan(x).all())
+      np.testing.assert_equal(np.array(y), 32767)
+      self.assertTrue(jnp.isnan(z).all())
+    if uninitialized_memory == 'zero':
+      np.testing.assert_equal(np.array(x), 0)
+      np.testing.assert_equal(np.array(y), 0)
+      np.testing.assert_equal(np.array(z), 0)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
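Not part of the patch: a minimal usage sketch of the new `uninitialized_memory` option, mirroring the test above. It assumes the same aliases the test file uses (`pl` for `jax.experimental.pallas`, `pltpu` for `jax.experimental.pallas.tpu`, and `mosaic_interpret` for `jax._src.pallas.mosaic.interpret`); exact import paths may differ across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax._src.pallas.mosaic import interpret as mosaic_interpret  # assumed path


def copy_scratch_kernel(o_ref, scratch_ref):
  # The scratch buffer is never written, so the output exposes whatever
  # fill value the interpreter used for "uninitialized" memory.
  o_ref[...] = scratch_ref[...]


out = pl.pallas_call(
    copy_scratch_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
    interpret=mosaic_interpret.TPUInterpretParams(
        uninitialized_memory="zero"),  # default is "nan"
)()

print(bool(jnp.all(out == 0)))  # True: buffers start zero-filled under "zero"
```

With the default `uninitialized_memory="nan"`, the same kernel would instead surface NaNs for floating-point buffers (and the dtype's maximum value for integer buffers), which makes accidental reads of uninitialized memory easy to spot.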