From e35e7f8e205632c6914cabaea3f54b89c35985b5 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 29 Oct 2024 17:58:12 -0700
Subject: [PATCH] Allow sparsecore compute with T(8) layout via the layout API
 and `compute_on` API.

To annotate compute on sparsecore, use `@compute_on('tpu_sparsecore')`.

PiperOrigin-RevId: 691225280
---
 jax/_src/compute_on.py        |  7 ++++---
 jax/_src/interpreters/mlir.py |  6 +++++-
 tests/BUILD                   |  3 +++
 tests/layout_test.py          | 24 ++++++++++++++++++++++++
 4 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py
index 4495d38f9..b5194ddad 100644
--- a/jax/_src/compute_on.py
+++ b/jax/_src/compute_on.py
@@ -46,9 +46,10 @@ def current_compute_type() -> str | None:
   return compute_on_context.stack[-1] if compute_on_context.stack else None
 
 def _check_valid(c_type: str):
-  if c_type not in {'device_host', 'device'}:
-    raise ValueError('Invalid compute type received. Current supported values '
-                     f'are `device_host` and `device`. Got {c_type}')
+  if c_type not in {'device_host', 'device', 'tpu_sparsecore'}:
+    raise ValueError(
+        'Invalid compute type received. Current supported values '
+        f'are `device_host`, `device` and `tpu_sparsecore`. Got {c_type}')
 
 @contextmanager
 def compute_on(compute_type: str):
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 2adeb4b16..c71e52385 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1878,6 +1878,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
     return ()
   if eqn_ctx.compute_type == 'device_host':
     return ('cpu',)
+  if eqn_ctx.compute_type == 'tpu_sparsecore':
+    return ('tpu',)
   return ()
 
 
@@ -2160,8 +2162,10 @@ def map_compute_type(c_type):
     return 'host'
   elif c_type == 'device':
     return 'dense'
+  elif c_type == 'tpu_sparsecore':
+    return 'sparse'
   raise ValueError('Invalid compute type received. Current supported values '
-                   'are `device_host` and `device`')
+                   'are `device_host`, `device` and `tpu_sparsecore`')
 
 def wrap_compute_type_in_place(ctx, op):
   if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:
diff --git a/tests/BUILD b/tests/BUILD
index 657d169ba..316e98f5b 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -271,6 +271,9 @@ jax_multiplatform_test(
         "tpu": ["requires-mem:16g"],  # Under tsan on 2x2 this test exceeds the default 12G memory limit.
     },
     tags = ["multiaccelerator"],
+    deps = [
+        "//jax:experimental",
+    ],
 )
 
 jax_multiplatform_test(
diff --git a/tests/layout_test.py b/tests/layout_test.py
index 9d26d96e2..406d06dac 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -25,6 +25,7 @@ from jax._src import config
 from jax._src.layout import Layout, DeviceLocalLayout as DLL
 from jax._src import test_util as jtu
 from jax._src.util import safe_zip
+from jax.experimental.compute_on import compute_on
 
 config.parse_flags_with_absl()
 
@@ -600,6 +601,29 @@ class LayoutTest(jtu.JaxTestCase):
         ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"):
       g(jnp.arange(8))
 
+  def test_sparsecore_compute(self):
+    if not (jtu.is_device_tpu('5', 'f') or jtu.is_device_tpu_at_least(6)):
+      self.skipTest('Does not have a sparsecore present')
+    shape = (128, 128)
+    inp = jnp.arange(math.prod(shape)).reshape(shape)
+
+    dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
+    s = SingleDeviceSharding(jax.devices()[0])
+    sparse_layout = Layout(dll, s)
+    sparsecore_arr = jax.device_put(inp, sparse_layout)
+    dense_layout = Layout(DLL(major_to_minor=(0, 1)), s)
+
+    @compute_on('tpu_sparsecore')
+    @jax.jit
+    def sparsecore_compute(x):
+      return x * x
+
+    @partial(jax.jit, out_shardings=(dense_layout, sparse_layout))
+    def f(x, y):
+      return x * 2, sparsecore_compute(y)
+
+    f(inp, sparsecore_arr)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
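
A minimal end-to-end sketch of the API this patch enables, adapted from the
test above; it is not part of the patch itself. It assumes a TPU generation
with a SparseCore present (the test gates on specific generations), and it
swaps the test's `jax._src` imports for the public `jax.experimental.layout`
and `jax.sharding` paths. `_tiling` is a private `DeviceLocalLayout` field,
used here exactly as the test uses it to request the T(8) layout.

import math
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.sharding import SingleDeviceSharding

shape = (128, 128)
inp = jnp.arange(math.prod(shape)).reshape(shape)

s = SingleDeviceSharding(jax.devices()[0])
# T(8) tiled layout for the sparsecore operand; an untiled, row-major
# layout for the dense operand.
sparse_layout = Layout(DLL(major_to_minor=(0, 1), _tiling=((8,),)), s)
dense_layout = Layout(DLL(major_to_minor=(0, 1)), s)
sparsecore_arr = jax.device_put(inp, sparse_layout)

# Everything in this function is annotated to run on the SparseCore.
@compute_on('tpu_sparsecore')
@jax.jit
def sparsecore_compute(x):
  return x * x

@partial(jax.jit, out_shardings=(dense_layout, sparse_layout))
def f(x, y):
  # `x * 2` runs on the TensorCore; `sparsecore_compute(y)` is offloaded.
  return x * 2, sparsecore_compute(y)

dense_out, sparse_out = f(inp, sparsecore_arr)

Per the mlir.py changes above, the `tpu_sparsecore` annotation maps to the
'sparse' compute type and keeps lowering on the 'tpu' platform, so the
annotated computation stays inside the same jitted program as the
TensorCore work rather than being dispatched to a separate backend.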