Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4

PiperOrigin-RevId: 738235363

Commit: e9ce8fb92d
@@ -83,10 +83,15 @@ class TPUInterpretParams:
       replaced with arrays all of `jnp.inf`. Additionaly any floating point
       operands to any operation will be replaced with (arrays of) `jnp.inf`.
       Default: False.
+    uninitialized_memory: If "nan", allocated buffers are initialized to
+      to contain all NaNs (or to their maximum possible value for integers).
+      If "zero", allocated buffers are initialized to all zeros.
+      Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"


 VectorClock = np.ndarray
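For readers following the change, here is a minimal usage sketch of the new option (not part of the diff; the kernel, shapes, and import aliases below are illustrative and mirror the test added at the end of this change). A kernel that copies never-written scratch memory into its output makes the fill policy observable:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu
    from jax._src.pallas.mosaic import interpret as mosaic_interpret  # assumed import path

    def kernel(o_ref, scratch_ref):
      # Copy uninitialized scratch memory into the output so the fill is visible.
      o_ref[...] = scratch_ref[...]

    out = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
        scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
        interpret=mosaic_interpret.TPUInterpretParams(uninitialized_memory="zero"),
    )()
    # With uninitialized_memory="zero" the result is all zeros; with the
    # default "nan" it would be all NaNs.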
@@ -1114,7 +1119,8 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):
               jax.ShapeDtypeStruct((), jnp.int16),
               device_id,
               TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-              primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+              _uninitialized_value(
+                  v.aval.shape, v.aval.dtype, interpret_params),
               ordered=True))

       out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params):

 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals

 def _compute_start_indices(block_mapping, loop_idx, *args):
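As a side note on the aliasing path that the threaded interpret_params bypasses: an output that aliases an input reuses that input's value and never receives the uninitialized fill. A tiny illustration of the oi_map lookup (the alias pair below is made up):

    input_output_aliases = ((0, 1),)                    # input 0 aliases output 1
    oi_map = {v: k for k, v in input_output_aliases}    # {1: 0}
    # Output 1 reuses input_args[0]; every other output is created with
    # _uninitialized_value(shape, dtype, interpret_params).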
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                      dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)

-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.

   For example, if values has a shape of (33, 2, 5) with a block_shape of
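For reference, a quick sketch of what the 'nan' branch of the new helper produces per dtype family (plain jax.numpy calls, independent of this change; the int16 maximum of 32767 is the value the new test asserts):

    import jax.numpy as jnp

    jnp.full((8, 128), jnp.nan, jnp.float32)                 # floating dtypes: all NaN
    jnp.full((8, 128), jnp.iinfo(jnp.int16).max, jnp.int16)  # integer dtypes: 32767 for int16
    jnp.full((8, 128), False, jnp.bool_)                     # booleans: all False
    # With uninitialized_memory == 'zero', every dtype instead gets
    # jnp.full(shape, 0, dtype).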
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value

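The signature change above means the padding elements now follow the same fill policy as freshly allocated buffers; a small standalone illustration of the jnp.pad pattern used here (the array and block size are made up for the example):

    import numpy as np
    import jax.numpy as jnp

    value = jnp.arange(5, dtype=jnp.float32)                 # shape (5,)
    padded_shape = (int(np.ceil(value.shape[0] / 4)) * 4,)   # pad up to (8,)
    pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
    pad_value = jnp.full((), jnp.nan, value.dtype)           # the 'nan' policy's fill
    padded = jnp.pad(value, pad_width, constant_values=pad_value)
    # padded -> [0., 1., 2., 3., 4., nan, nan, nan]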
@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]

@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
         jax.ShapeDtypeStruct((), jnp.int16),
         device_id,
         TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-        primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+        _uninitialized_value(
+            var.aval.shape, var.aval.dtype, interpret_params),
         ordered=True))

   _, input_ids, kernel_output_ids, _ = split_list(
@@ -156,5 +156,36 @@ class InterpretTest(jtu.JaxTestCase):
     lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
     self.assertNotIn("dot_general", lowered)

+  @parameterized.parameters('nan', 'zero')
+  def test_uninitialized_memory(self, uninitialized_memory):
+    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
+      o1_ref[...] = t1_ref[...]
+      o2_ref[...] = t2_ref[...]
+
+    x, y, z = pl.pallas_call(
+        kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
+            jax.ShapeDtypeStruct((8, 128), jnp.int16),
+            jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        ],
+        in_specs=[],
+        scratch_shapes=[
+            pltpu.VMEM((8, 128), jnp.bfloat16),
+            pltpu.VMEM((8, 128), jnp.int16),
+        ],
+        interpret=mosaic_interpret.TPUInterpretParams(
+            uninitialized_memory=uninitialized_memory),
+    )()
+    if uninitialized_memory == 'nan':
+      self.assertTrue(jnp.isnan(x).all())
+      np.testing.assert_equal(np.array(y), 32767)
+      self.assertTrue(jnp.isnan(z).all())
+    if uninitialized_memory == 'zero':
+      np.testing.assert_equal(np.array(x), 0)
+      np.testing.assert_equal(np.array(y), 0)
+      np.testing.assert_equal(np.array(z), 0)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())