[jax2tf] Adjust jax2tf for the pjit==jit API migration.

jax2tf used to treat jit and pjit differently: jit was inlined, while
pjit resulted in a recursive call to _interpret_jaxpr. This led to
differences in how constant sharing was handled.

This PR actually makes constant sharing more aggressive. This should
be OK, because we only share non-scalar constants, which JAX has
already lifted to the top level of the Jaxpr.

commit 97b58bfae7
parent 469a8eb520
Author: George Necula
Date:   2023-01-17 12:15:01 +02:00

3 changed files with 36 additions and 27 deletions
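
To illustrate the constant-sharing behavior described in the commit message, here is a minimal sketch (not part of this commit; the sizes and names are chosen for illustration) that converts a function reusing a large non-scalar constant and counts the large Const nodes in the resulting TF graph, mirroring the updated test_shared_constants test below:

    import numpy as np
    import tensorflow as tf
    from jax.experimental import jax2tf

    const = np.random.uniform(size=512).astype(np.float32)  # non-scalar, lifted as a Jaxpr const

    def f(x):
      # The constant is used several times; after conversion it should show up
      # as a single large Const node in the TF graph.
      return x + const + const + const + const

    f_tf = tf.function(jax2tf.convert(f), autograph=False)
    graph_def = f_tf.get_concrete_function(np.ones((512,), np.float32)).graph.as_graph_def()
    large_consts = [n for n in graph_def.node
                    if n.op == "Const" and len(str(n)) >= 512]
    print(len(large_consts))  # the updated test below expects 1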


@@ -761,16 +761,18 @@ def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                           ) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
   try:
     prev_constant_cache = _thread_local_state.constant_cache
-    prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
     # Start a new cache, so that we don't share constants across tf.function
     # boundaries.
     if fresh_constant_cache:
       _thread_local_state.constant_cache = {}
+    else:
+      prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
     out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
         fun.call_wrapped(*in_vals)
   finally:
-    if prev_constant_cache is not None and not fresh_constant_cache:
+    if (not fresh_constant_cache and
+        prev_constant_cache is not None and
+        _WRAP_JAX_JIT_WITH_TF_FUNCTION):
       newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
       # Delete the newly added keys
       for k in newly_added_keys:
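
The hunk above scopes the constant cache around the wrapped call: either start a fresh cache (so constants are not shared across a tf.function boundary) or remember which keys were added so they can be dropped on exit. A simplified sketch of that pattern (illustration only; it omits the _WRAP_JAX_JIT_WITH_TF_FUNCTION gating and the thread-local state used by jax2tf):

    from contextlib import contextmanager

    @contextmanager
    def scoped_constant_cache(state, fresh_constant_cache: bool):
      # `state` stands in for jax2tf's thread-local state holding `constant_cache`.
      prev_cache = state.constant_cache
      if fresh_constant_cache:
        state.constant_cache = {}  # do not share constants across this boundary
      else:
        prev_keys = set(prev_cache) if prev_cache is not None else set()
      try:
        yield
      finally:
        if fresh_constant_cache:
          state.constant_cache = prev_cache
        elif prev_cache is not None:
          # Drop only the entries added inside this scope.
          for k in set(prev_cache) - prev_keys:
            del prev_cache[k]
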
@@ -940,10 +942,14 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
     # JAX has the same problem when generating HLO.
     const_key = (id(val), jax_dtype)
     # Since we use id(val) as a cache key, we have to make sure that we keep
-    # the previous `val` alive. Otherwise, for an ndarray, it can get garbage
+    # the previous `val` alive. Otherwise, for a ndarray, it can get garbage
     # collected and reused for a different value, which would create correctness
     # issues. We keep the `val` alive by storing in the cache the pair
     # `(val, tf_val)`.
+    # Only memoize non-scalars. JAX will lift all non-scalar constants as
+    # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
+    # early, when entering the Jaxpr, so we create the tf.const early and its
+    # scope is the entire Jaxpr.
+    do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
+    if do_memoize:
       _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
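
A rough sketch of the memoization policy described in the new comments (illustration only; the real code keys the cache on (id(val), jax_dtype), also checks memoize_constants, and stores the cache in thread-local state):

    import numpy as np
    import tensorflow as tf

    constant_cache = {}  # (id(val), dtype) -> (val, tf_val); storing `val` keeps it alive

    def cached_tf_constant(val):
      arr = np.asarray(val)
      key = (id(val), arr.dtype)
      if arr.shape:  # memoize only non-scalars; JAX lifts these to the top of the Jaxpr
        hit = constant_cache.get(key)
        if hit is not None:
          return hit[1]
      tf_val = tf.constant(arr)
      if arr.shape:
        constant_cache[key] = (val, tf_val)
      return tf_val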


@@ -889,8 +889,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     if config.jax2tf_default_experimental_native_lowering:
       self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
     else:
-      self.assertIn("my_test_function_jax/jit_fn_/Mul",
-                    str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()))
+      graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
+      if "my_test_function_jax/pjit_fn_/Mul" not in graph_def:
+        self.assertIn("my_test_function_jax/jit_fn_/Mul", graph_def)

   def test_bfloat16_constant(self):
     # Re: https://github.com/google/jax/issues/3942
@@ -921,41 +922,44 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     def f(x):
       return x + const + const + const + const
-    f_tf_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
-    self.assertEqual(f_tf_nr_consts, 1)
+    f_tf_consts = self.FindLargeTfConstants(jax2tf.convert(f), const)
+    self.assertLen(f_tf_consts, 1)

   def test_shared_constants_under_cond(self):
     # Check that the constants are shared properly in converted functions
     # See https://github.com/google/jax/issues/7992.
     if config.jax2tf_default_experimental_native_lowering:
       raise unittest.SkipTest("shared constants tests not interesting for native lowering")
-    const = np.random.uniform(size=256).astype(np.float32)  # A shared constant
-    x = np.ones((256,), dtype=np.float32)
+    const_size = 512
+    const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
+    x = np.ones((const_size,), dtype=np.float32)
     def f1(x):
+      # Ensure that we first see the constants in the inside jaxpr
       return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const
     def f2(x):
       return f1(x) + const  # The extra const should not cost anything
-    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), x)
-    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), x)
-    self.assertEqual(f1_nr_consts, f2_nr_consts)
+    f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), x, at_least=const_size)
+    f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), x, at_least=const_size)
+    self.assertLen(f2_consts, len(f1_consts))

   def test_shared_constants_under_scan(self):
     # See https://github.com/google/jax/issues/7992.
     if config.jax2tf_default_experimental_native_lowering:
       raise unittest.SkipTest("shared constants tests not interesting for native lowering")
-    const = np.random.uniform(size=256).astype(np.float32)  # A shared constant
-    xs = np.ones((8, 256), dtype=np.float32)
+    const_size = 512
+    const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
+    xs = np.ones((8, const_size), dtype=np.float32)
     def f1(xs):
       res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
-                        np.zeros((256,), dtype=np.float32), xs)
+                        jnp.zeros((const_size,), dtype=np.float32), xs)
       return res
     def f2(xs):
       return f1(xs) + const  # The extra const should not be saved
-    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), xs)
-    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), xs)
-    self.assertEqual(f1_nr_consts, f2_nr_consts)
+    f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), xs, at_least=const_size)
+    f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), xs, at_least=const_size)
+    self.assertLen(f2_consts, len(f1_consts))

   def test_shared_constants_under_jit(self):
     # We do not share constants under jit.
@@ -968,9 +972,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     def f(x):
       return g_jit(x) + const + const
-    f_tf_graph_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
-    # TODO(b/207464757): TF compilation is disabled
-    self.assertEqual(f_tf_graph_nr_consts, 1)
+    f_tf_graph_consts = self.FindLargeTfConstants(jax2tf.convert(f), const)
+    self.assertLen(f_tf_graph_consts, 1)

   def test_weak_types(self):
     mul = jax.jit(jnp.multiply)


@@ -374,9 +374,9 @@ class JaxToTfTestCase(jtu.JaxTestCase):
       return tf_function.experimental_get_compiler_ir(*args)(stage="hlo",
                                                              device_name=device_name)

-  def CountLargeTfConstants(self, tf_fun: Callable, *args,
-                            at_least=256):
-    # A hacky way to count how many "large" constants are embedded in the
+  def FindLargeTfConstants(self, tf_fun: Callable, *args,
+                           at_least=256):
+    # A hacky way to find the "large" constants that are embedded in the
     # graph. We count the number of characters in the textual representation
     # of the constant.
     f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def()
@@ -387,8 +387,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     else:
       # We cannot find the constants just with string matching because their
       # representation may contain escaped "
-      large_consts = [n for n in f_tf_graph.node if n.op == "Const" and len(str(n)) >= at_least]
-      return len(large_consts)
+      large_consts = [str(n) for n in f_tf_graph.node if n.op == "Const" and len(str(n)) >= at_least]
+      return large_consts

   def CheckOpMetadata(self, jax_fun, x,
                       expected: Sequence[OpMetadataGraph],