mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
let XLA metadata be unset in nested dynamic scopes
Treat `None` metadata values as a special instruction not to set (or to unset, if nested) the corresponding entry. In particular, this makes it possible to unset metadata within the sub-computations of higher-order operations (e.g. branches in conditionals, loop bodies, etc.). This can be used, for example, to annotate a conditional but not all the operations in its branches. That is, the HLO for the following function `f` on a scalar float argument: ``` def cos(x): with set_xla_metadata(a=None): return jnp.cos(x) @jax.jit def f(x): with set_xla_metadata(a="b"): return jax.lax.cond(x < 0., jnp.sin, cos, x) ``` produces an attribute `a` on the conditional and on the sine, but not on the cosine.
This commit is contained in:
parent
994af3efb8
commit
1875c76bd2
@ -320,7 +320,7 @@ class JaxprEqnContextManager:
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
config.compute_on_context_manager.set_local(self.prev_compute_type)
|
||||
config.threefry_partitionable.set_local(self.prev_threefry_partitionable)
|
||||
if self.context.xla_metadata is not None:
|
||||
if self.context.xla_metadata:
|
||||
config.xla_metadata_context_manager.set_local(self.prev_xla_metadata)
|
||||
config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh)
|
||||
|
||||
|
@ -24,6 +24,8 @@ config_ext = xla_client._xla.config
|
||||
class XlaMetadata:
|
||||
__slots__ = ['val', 'hash']
|
||||
|
||||
val: dict[str, Any]
|
||||
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
self.hash = hash(tuple(sorted(self.val.items())))
|
||||
@ -35,14 +37,19 @@ class XlaMetadata:
|
||||
return other is not None and self.val == other.val
|
||||
|
||||
|
||||
def filter_nones(d: dict) -> dict:
|
||||
return {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
|
||||
def update_metadata(a, b: dict[str, Any]):
|
||||
if not b:
|
||||
return a
|
||||
if a is None or a is config_ext.unset:
|
||||
return XlaMetadata(b)
|
||||
val = a.val.copy()
|
||||
val = {}
|
||||
else:
|
||||
val = a.val.copy()
|
||||
val.update(b)
|
||||
return XlaMetadata(val)
|
||||
return XlaMetadata(filter_nones(val))
|
||||
|
||||
|
||||
def current_xla_metadata():
|
||||
|
@ -190,6 +190,39 @@ class XlaMetadataTest(jtu.JaxTestCase):
|
||||
if "stablehlo.add" in line:
|
||||
self.assertIn('mhlo.frontend_attributes = {a = "c"}', line)
|
||||
|
||||
def test_cond_annotates_branches(self):
|
||||
sin = jnp.sin
|
||||
cos = jnp.cos
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
with set_xla_metadata(a="b"):
|
||||
return jax.lax.cond(x < 0., sin, cos, x)
|
||||
|
||||
hlo_lines = f.lower(1.).as_text().split("\n")
|
||||
sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
|
||||
cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
|
||||
self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo)
|
||||
self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
|
||||
|
||||
def test_cond_annotates_branches_and_none_unsets(self):
|
||||
sin = jnp.sin
|
||||
|
||||
def cos(x):
|
||||
with set_xla_metadata(a=None):
|
||||
return jnp.cos(x)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
with set_xla_metadata(a="b"):
|
||||
return jax.lax.cond(x < 0., sin, cos, x)
|
||||
|
||||
hlo_lines = f.lower(1.).as_text().split("\n")
|
||||
sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
|
||||
cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
|
||||
self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo)
|
||||
self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
|
||||
|
||||
def test_nested_jit(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user