mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Allow sparsecore compute with T(8) layout via the layout API and compute_on
API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').
PiperOrigin-RevId: 691225280
This commit is contained in:
parent
72f9a49358
commit
e35e7f8e20
@ -46,9 +46,10 @@ def current_compute_type() -> str | None:
|
||||
return compute_on_context.stack[-1] if compute_on_context.stack else None
|
||||
|
||||
def _check_valid(c_type: str):
|
||||
if c_type not in {'device_host', 'device'}:
|
||||
raise ValueError('Invalid compute type received. Current supported values '
|
||||
f'are `device_host` and `device`. Got {c_type}')
|
||||
if c_type not in {'device_host', 'device', 'tpu_sparsecore'}:
|
||||
raise ValueError(
|
||||
'Invalid compute type received. Current supported values '
|
||||
f'are `device_host`, `device` and `tpu_sparsecore`. Got {c_type}')
|
||||
|
||||
@contextmanager
|
||||
def compute_on(compute_type: str):
|
||||
|
@ -1878,6 +1878,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
|
||||
return ()
|
||||
if eqn_ctx.compute_type == 'device_host':
|
||||
return ('cpu',)
|
||||
if eqn_ctx.compute_type == 'tpu_sparsecore':
|
||||
return ('tpu',)
|
||||
return ()
|
||||
|
||||
|
||||
@ -2160,8 +2162,10 @@ def map_compute_type(c_type):
|
||||
return 'host'
|
||||
elif c_type == 'device':
|
||||
return 'dense'
|
||||
elif c_type == 'tpu_sparsecore':
|
||||
return 'sparse'
|
||||
raise ValueError('Invalid compute type received. Current supported values '
|
||||
'are `device_host` and `device`')
|
||||
'are `device_host`, `device` and `tpu_sparsecore')
|
||||
|
||||
def wrap_compute_type_in_place(ctx, op):
|
||||
if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:
|
||||
|
@ -271,6 +271,9 @@ jax_multiplatform_test(
|
||||
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
|
@ -25,6 +25,7 @@ from jax._src import config
|
||||
from jax._src.layout import Layout, DeviceLocalLayout as DLL
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import safe_zip
|
||||
from jax.experimental.compute_on import compute_on
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -600,6 +601,29 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"):
|
||||
g(jnp.arange(8))
|
||||
|
||||
def test_sparsecore_compute(self):
  """Runs a jitted function that calls a `compute_on('tpu_sparsecore')` op.

  The test is skipped unless `jtu.is_device_tpu('5', 'f')` or
  `jtu.is_device_tpu_at_least(6)` holds. It only checks that the call
  completes; no output values are asserted.
  """
  has_sparsecore = jtu.is_device_tpu('5', 'f') or jtu.is_device_tpu_at_least(6)
  if not has_sparsecore:
    self.skipTest('Does not have a sparsecore present')

  shape = (128, 128)
  inp = jnp.arange(math.prod(shape)).reshape(shape)

  # Both layouts use the same major-to-minor order on the first device; only
  # the sparse one requests the (8,) tiling.
  sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
  sharding = SingleDeviceSharding(jax.devices()[0])
  sparse_layout = Layout(sparse_dll, sharding)
  sparsecore_arr = jax.device_put(inp, sparse_layout)
  dense_layout = Layout(DLL(major_to_minor=(0, 1)), sharding)

  @compute_on('tpu_sparsecore')
  @jax.jit
  def sparsecore_compute(x):
    return x * x

  @partial(jax.jit, out_shardings=(dense_layout, sparse_layout))
  def f(x, y):
    # First output is requested in the dense layout, second in the sparse one.
    return x * 2, sparsecore_compute(y)

  f(inp, sparsecore_arr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user