let XLA metadata be unset in nested dynamic scopes

Treat `None` metadata values as a special instruction not to set (or
to unset, if nested) the corresponding entry.

In particular, this makes it possible to unset metadata within the
sub-computations of higher-order operations (e.g. branches in
conditionals, loop bodies, etc.). This can be used, for example, to
annotate a conditional but not all the operations in its
branches. That is, the HLO for the following function `f` on a scalar
float argument:

```
def cos(x):
  with set_xla_metadata(a=None):
    return jnp.cos(x)

@jax.jit
def f(x):
  with set_xla_metadata(a="b"):
    return jax.lax.cond(x < 0., jnp.sin, cos, x)
```

produces an attribute `a` on the conditional and on the sine, but not
on the cosine.
This commit is contained in:
Roy Frostig 2025-04-01 19:10:39 -07:00
parent 994af3efb8
commit 1875c76bd2
3 changed files with 44 additions and 4 deletions

View File

@ -320,7 +320,7 @@ class JaxprEqnContextManager:
def __exit__(self, exc_type, exc_value, traceback):
config.compute_on_context_manager.set_local(self.prev_compute_type)
config.threefry_partitionable.set_local(self.prev_threefry_partitionable)
if self.context.xla_metadata is not None:
if self.context.xla_metadata:
config.xla_metadata_context_manager.set_local(self.prev_xla_metadata)
config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh)

View File

@ -24,6 +24,8 @@ config_ext = xla_client._xla.config
class XlaMetadata:
__slots__ = ['val', 'hash']
val: dict[str, Any]
def __init__(self, val):
self.val = val
self.hash = hash(tuple(sorted(self.val.items())))
@ -35,14 +37,19 @@ class XlaMetadata:
return other is not None and self.val == other.val
def filter_nones(d: dict) -> dict:
return {k: v for k, v in d.items() if v is not None}
def update_metadata(a, b: dict[str, Any]):
if not b:
return a
if a is None or a is config_ext.unset:
return XlaMetadata(b)
val = a.val.copy()
val = {}
else:
val = a.val.copy()
val.update(b)
return XlaMetadata(val)
return XlaMetadata(filter_nones(val))
def current_xla_metadata():

View File

@ -190,6 +190,39 @@ class XlaMetadataTest(jtu.JaxTestCase):
if "stablehlo.add" in line:
self.assertIn('mhlo.frontend_attributes = {a = "c"}', line)
def test_cond_annotates_branches(self):
sin = jnp.sin
cos = jnp.cos
@jax.jit
def f(x):
with set_xla_metadata(a="b"):
return jax.lax.cond(x < 0., sin, cos, x)
hlo_lines = f.lower(1.).as_text().split("\n")
sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo)
self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
def test_cond_annotates_branches_and_none_unsets(self):
sin = jnp.sin
def cos(x):
with set_xla_metadata(a=None):
return jnp.cos(x)
@jax.jit
def f(x):
with set_xla_metadata(a="b"):
return jax.lax.cond(x < 0., sin, cos, x)
hlo_lines = f.lower(1.).as_text().split("\n")
sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line]
cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line]
self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo)
self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo)
def test_nested_jit(self):
@jax.jit
def f(x, y):