Jax caches should depend on axis env.

This commit is contained in:
Parker Schuh 2022-06-28 13:50:49 -07:00
parent 90af8e8135
commit 6c5d204d7e
6 changed files with 79 additions and 9 deletions

View File

@ -21,7 +21,7 @@ import itertools
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Iterator, Optional
from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
import warnings
from absl import logging
@ -374,7 +374,12 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
tls = jax_jit.thread_local_state()
axis_env_state = ()
context = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
return (axis_env_state, self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision, self.jax_dynamic_shapes,
self.jax_numpy_dtype_promotion, self.jax_default_device)
@ -483,6 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Optional[Any] = None
axis_env_state: Optional[Hashable] = None
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None

View File

@ -4695,6 +4695,7 @@ def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)
@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
start_indices: Tuple[Tuple[int, ...]],

View File

@ -2051,21 +2051,51 @@ aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
frame = AxisEnvFrame(axis_name, size, tag)
thread_local_state.trace_state.axis_env.append(frame)
ts = thread_local_state.trace_state
ts.axis_env.append(frame)
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
thread_local_state.trace_state.axis_env.pop()
ts.axis_env.pop()
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
thread_local_state.trace_state.axis_env.extend(frames)
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
try:
yield
finally:
for _ in frames:
thread_local_state.trace_state.axis_env.pop()
for _ in frames: ts.axis_env.pop()
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
@contextmanager
def stash_axis_env():
"Promise that a function or with-suite does not depend implicitly on axis env"
# If the promise is broken, then a NameError about an unbound axis name will
# be raised.
ts = thread_local_state.trace_state
prev_axis_env, ts.axis_env = ts.axis_env, []
jax_config.update_thread_local_jit_state(axis_env_state=())
try:
yield
finally:
ts.axis_env = prev_axis_env
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(f for f in ts.axis_env
if f.name is not no_axis_name))
# When a mapped function is given no axis name, we generate a name object based
@ -2601,7 +2631,8 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
if k == 'branches': return False
if isinstance(v, (Jaxpr, ClosedJaxpr)): return False
if (isinstance(v, tuple) and
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)): return False
any(isinstance(e, (Jaxpr, ClosedJaxpr)) for e in v)):
return False
return True
def str_eqn_compact(primitive_name: str, params: Dict) -> str:

View File

@ -1106,6 +1106,28 @@ class CPPJitTest(jtu.BufferDonationTestCase):
finally:
xla.xla_call_p.def_impl(jit_impl)
def test_caches_depend_on_axis_env(self):
# https://github.com/google/jax/issues/9187
f = lambda: lax.psum(1, "i")
g = jax.jit(f)
expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)()
ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)()
self.assertEqual(ans, expected)
# This second call to g could erroneously get a cache hit.
expected = jax.vmap(f, axis_name="i", axis_size=3, out_axes=None)()
ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)()
self.assertEqual(ans, expected)
def test_caches_dont_depend_on_unnamed_axis_env(self):
# https://github.com/google/jax/issues/9187
f = jax.jit(lambda: jnp.sin(1))
expected = f()
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = jax.vmap(f, axis_size=2, out_axes=None)()
self.assertEqual(count[0], 0) # no compiles
self.assertArraysAllClose(ans, expected, check_dtypes=True)
class PythonJitTest(CPPJitTest):

View File

@ -2974,6 +2974,15 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
def test_caches_depend_on_axis_env(self):
# https://github.com/google/jax/issues/9187
scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
self.assertEqual(ans, 2)
ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
self.assertEqual(ans, 3)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -1949,8 +1949,9 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(count[0], 2) # one for fwd, one for bwd
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_ = jax.vjp(f, x)
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
def testSizeOverflow(self):