Allow sparsecore compute with T(8) layout via the layout API and compute_on API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').

PiperOrigin-RevId: 691225280
This commit is contained in:
Yash Katariya 2024-10-29 17:58:12 -07:00 committed by jax authors
parent 72f9a49358
commit e35e7f8e20
4 changed files with 36 additions and 4 deletions

View File

@ -46,9 +46,10 @@ def current_compute_type() -> str | None:
return compute_on_context.stack[-1] if compute_on_context.stack else None
def _check_valid(c_type: str):
if c_type not in {'device_host', 'device'}:
raise ValueError('Invalid compute type received. Current supported values '
f'are `device_host` and `device`. Got {c_type}')
if c_type not in {'device_host', 'device', 'tpu_sparsecore'}:
raise ValueError(
'Invalid compute type received. Current supported values '
f'are `device_host`, `device` and `tpu_sparsecore`. Got {c_type}')
@contextmanager
def compute_on(compute_type: str):

View File

@ -1878,6 +1878,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
return ()
if eqn_ctx.compute_type == 'device_host':
return ('cpu',)
if eqn_ctx.compute_type == 'tpu_sparsecore':
return ('tpu',)
return ()
@ -2160,8 +2162,10 @@ def map_compute_type(c_type):
return 'host'
elif c_type == 'device':
return 'dense'
elif c_type == 'tpu_sparsecore':
return 'sparse'
raise ValueError('Invalid compute type received. Current supported values '
'are `device_host` and `device`')
'are `device_host`, `device` and `tpu_sparsecore')
def wrap_compute_type_in_place(ctx, op):
if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:

View File

@ -271,6 +271,9 @@ jax_multiplatform_test(
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
},
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
)
jax_multiplatform_test(

View File

@ -25,6 +25,7 @@ from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax.experimental.compute_on import compute_on
config.parse_flags_with_absl()
@ -600,6 +601,29 @@ class LayoutTest(jtu.JaxTestCase):
ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"):
g(jnp.arange(8))
def test_sparsecore_compute(self):
if not (jtu.is_device_tpu('5', 'f') or jtu.is_device_tpu_at_least(6)):
self.skipTest('Does not have a sparsecore present')
shape = (128, 128)
inp = jnp.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
s = SingleDeviceSharding(jax.devices()[0])
sparse_layout = Layout(dll, s)
sparecore_arr = jax.device_put(inp, sparse_layout)
dense_layout = Layout(DLL(major_to_minor=(0, 1)), s)
@compute_on('tpu_sparsecore')
@jax.jit
def sparsecore_compute(x):
return x * x
@partial(jax.jit, out_shardings=(dense_layout, sparse_layout))
def f(x, y):
return x * 2, sparsecore_compute(y)
f(inp, sparecore_arr)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())