mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Jax caches should depend on axis env.
This commit is contained in:
parent
90af8e8135
commit
6c5d204d7e
@ -21,7 +21,7 @@ import itertools
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, List, Callable, NamedTuple, Iterator, Optional
|
||||
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
@ -374,7 +374,12 @@ class Config:
|
||||
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately."""
|
||||
return (self.x64_enabled, self.jax_numpy_rank_promotion,
|
||||
tls = jax_jit.thread_local_state()
|
||||
axis_env_state = ()
|
||||
context = tls.extra_jit_context
|
||||
if context and context.axis_env_state is not None:
|
||||
axis_env_state = context.axis_env_state
|
||||
return (axis_env_state, self.x64_enabled, self.jax_numpy_rank_promotion,
|
||||
self.jax_default_matmul_precision, self.jax_dynamic_shapes,
|
||||
self.jax_numpy_dtype_promotion, self.jax_default_device)
|
||||
|
||||
@ -483,6 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
||||
"""
|
||||
dynamic_trace_state: Optional[Any] = None
|
||||
axis_env_state: Optional[Hashable] = None
|
||||
numpy_rank_promotion: Optional[str] = None
|
||||
numpy_dtype_promotion: Optional[str] = None
|
||||
default_matmul_precision: Optional[Any] = None
|
||||
|
@ -4695,6 +4695,7 @@ def _compress_method(a, condition, axis=None, out=None):
|
||||
return compress(condition, a, axis, out)
|
||||
|
||||
|
||||
@core.stash_axis_env()
|
||||
@partial(jit, static_argnums=(1,2,3))
|
||||
def _multi_slice(arr,
|
||||
start_indices: Tuple[Tuple[int, ...]],
|
||||
|
43
jax/core.py
43
jax/core.py
@ -2051,21 +2051,51 @@ aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
|
||||
@contextmanager
|
||||
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
|
||||
frame = AxisEnvFrame(axis_name, size, tag)
|
||||
thread_local_state.trace_state.axis_env.append(frame)
|
||||
ts = thread_local_state.trace_state
|
||||
ts.axis_env.append(frame)
|
||||
jax_config.update_thread_local_jit_state(
|
||||
axis_env_state=tuple(f for f in ts.axis_env
|
||||
if f.name is not no_axis_name))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local_state.trace_state.axis_env.pop()
|
||||
ts.axis_env.pop()
|
||||
jax_config.update_thread_local_jit_state(
|
||||
axis_env_state=tuple(f for f in ts.axis_env
|
||||
if f.name is not no_axis_name))
|
||||
|
||||
@contextmanager
|
||||
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
|
||||
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
|
||||
thread_local_state.trace_state.axis_env.extend(frames)
|
||||
ts = thread_local_state.trace_state
|
||||
ts.axis_env.extend(frames)
|
||||
jax_config.update_thread_local_jit_state(
|
||||
axis_env_state=tuple(f for f in ts.axis_env
|
||||
if f.name is not no_axis_name))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for _ in frames:
|
||||
thread_local_state.trace_state.axis_env.pop()
|
||||
for _ in frames: ts.axis_env.pop()
|
||||
jax_config.update_thread_local_jit_state(
|
||||
axis_env_state=tuple(f for f in ts.axis_env
|
||||
if f.name is not no_axis_name))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_axis_env():
|
||||
"Promise that a function or with-suite does not depend implicitly on axis env"
|
||||
# If the promise is broken, then a NameError about an unbound axis name will
|
||||
# be raised.
|
||||
ts = thread_local_state.trace_state
|
||||
prev_axis_env, ts.axis_env = ts.axis_env, []
|
||||
jax_config.update_thread_local_jit_state(axis_env_state=())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
ts.axis_env = prev_axis_env
|
||||
jax_config.update_thread_local_jit_state(
|
||||
axis_env_state=tuple(f for f in ts.axis_env
|
||||
if f.name is not no_axis_name))
|
||||
|
||||
|
||||
# When a mapped function is given no axis name, we generate a name object based
|
||||
@ -2601,7 +2631,8 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
|
||||
if k == 'branches': return False
|
||||
if isinstance(v, (Jaxpr, ClosedJaxpr)): return False
|
||||
if (isinstance(v, tuple) and
|
||||
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)): return False
|
||||
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def str_eqn_compact(primitive_name: str, params: Dict) -> str:
|
||||
|
@ -1106,6 +1106,28 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
finally:
|
||||
xla.xla_call_p.def_impl(jit_impl)
|
||||
|
||||
def test_caches_depend_on_axis_env(self):
|
||||
# https://github.com/google/jax/issues/9187
|
||||
f = lambda: lax.psum(1, "i")
|
||||
g = jax.jit(f)
|
||||
expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)()
|
||||
ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)()
|
||||
self.assertEqual(ans, expected)
|
||||
|
||||
# This second call to g could erroneously get a cache hit.
|
||||
expected = jax.vmap(f, axis_name="i", axis_size=3, out_axes=None)()
|
||||
ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)()
|
||||
self.assertEqual(ans, expected)
|
||||
|
||||
def test_caches_dont_depend_on_unnamed_axis_env(self):
|
||||
# https://github.com/google/jax/issues/9187
|
||||
f = jax.jit(lambda: jnp.sin(1))
|
||||
expected = f()
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
ans = jax.vmap(f, axis_size=2, out_axes=None)()
|
||||
self.assertEqual(count[0], 0) # no compiles
|
||||
self.assertArraysAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
@ -2974,6 +2974,15 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
|
||||
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
|
||||
|
||||
def test_caches_depend_on_axis_env(self):
|
||||
# https://github.com/google/jax/issues/9187
|
||||
scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
|
||||
f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
|
||||
ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
|
||||
self.assertEqual(ans, 2)
|
||||
ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
|
||||
self.assertEqual(ans, 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -1949,8 +1949,9 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(count[0], 2) # one for fwd, one for bwd
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
_ = jax.vjp(f, x)
|
||||
_, f_bwd2 = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
_ = f_bwd2(x)
|
||||
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
|
||||
|
||||
def testSizeOverflow(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user