Merge pull request #14031 from gnecula:tf_jit_pjit
PiperOrigin-RevId: 502790279
Commit 74df2f9927
@@ -761,16 +761,18 @@ def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                           ) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
   try:
     prev_constant_cache = _thread_local_state.constant_cache
-    prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
     # Start a new cache, so that we don't share constants across tf.function
     # boundaries.
     if fresh_constant_cache:
       _thread_local_state.constant_cache = {}
+    else:
+      prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
     out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
         fun.call_wrapped(*in_vals)
   finally:
-    if prev_constant_cache is not None and not fresh_constant_cache:
+    if (not fresh_constant_cache and
+        prev_constant_cache is not None and
+        _WRAP_JAX_JIT_WITH_TF_FUNCTION):
       newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
       # Delete the newly added keys
       for k in newly_added_keys:
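Pulled out of its diff context, the pattern above is a save/replace/restore of a thread-local constant cache around a traced call, with eviction of any keys added during the call. A minimal standalone sketch of that pattern (the names `_State` and `call_with_constant_cache` are illustrative, not the jax2tf API):

    import threading

    class _State(threading.local):
      # Hypothetical stand-in for jax2tf's _thread_local_state.
      def __init__(self):
        self.constant_cache = None

    _state = _State()

    def call_with_constant_cache(fun, *args, fresh_constant_cache=False):
      # Run `fun` with either a fresh cache or the caller's cache; afterwards,
      # evict whatever `fun` added, so constants created while tracing one
      # tf.function cannot leak into a later, different trace.
      prev_cache = _state.constant_cache
      prev_keys = set()
      try:
        if fresh_constant_cache:
          _state.constant_cache = {}
        elif prev_cache is not None:
          prev_keys = set(prev_cache.keys())
        return fun(*args)
      finally:
        if not fresh_constant_cache and prev_cache is not None:
          for k in set(prev_cache.keys()) - prev_keys:
            del prev_cache[k]
        _state.constant_cache = prev_cache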
@@ -940,10 +942,14 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
     # JAX has the same problem when generating HLO.
     const_key = (id(val), jax_dtype)
     # Since we use id(val) as a cache key, we have to make sure that we keep
-    # the previous `val` alive. Otherwise, for an ndarray, it can get garbage
+    # the previous `val` alive. Otherwise, for a ndarray, it can get garbage
     # collected and reused for a different value, which would create correctness
     # issues. We keep the `val` alive by storing in the cache the pair
     # `(val, tf_val)`.
     # Only memoize non-scalars. JAX will lift all non-scalar constants as
     # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
     # early, when entering the Jaxpr, so we create the tf.const early and its
     # scope is the entire Jaxpr.
     do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
     if do_memoize:
       _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
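The `(val, tf_val)` pair stored in the cache is the subtle part: `id(val)` is only unique for the lifetime of `val`, so the cache itself must keep `val` alive. A small self-contained illustration of the same idiom (plain numpy stands in for the tf.constant conversion):

    import numpy as np

    _constant_cache = {}  # (id(val), dtype) -> (val, converted)

    def to_constant(val, dtype):
      # Memoize by identity and dtype, but store `val` inside the cache value
      # so it stays alive. If we cached only `converted`, `val` could be
      # garbage collected and a *different* array could later reuse the same
      # id(), making the cache silently return a wrong constant.
      key = (id(val), dtype)
      if np.shape(val):  # only memoize non-scalars
        _, cached = _constant_cache.get(key, (None, None))
        if cached is not None:
          return cached
      converted = np.asarray(val, dtype=dtype)  # stand-in for tf.constant(val)
      if np.shape(val):
        _constant_cache[key] = (val, converted)
      return converted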
@@ -889,8 +889,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     if config.jax2tf_default_experimental_native_lowering:
       self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
     else:
-      self.assertIn("my_test_function_jax/jit_fn_/Mul",
-                    str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()))
+      graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
+      if "my_test_function_jax/pjit_fn_/Mul" not in graph_def:
+        self.assertIn("my_test_function_jax/jit_fn_/Mul", graph_def)

   def test_bfloat16_constant(self):
     # Re: https://github.com/google/jax/issues/3942
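The test change reflects jit now lowering via pjit: the traced multiply may appear under a `pjit_fn_` name scope instead of `jit_fn_`, so the assertion accepts either. The graph-inspection idiom itself is plain TensorFlow; a minimal standalone example:

    import tensorflow as tf

    @tf.function(autograph=False)
    def run_tf(x):
      return x * x

    # Serialize the traced graph and search for an op name by substring.
    graph_def = str(run_tf.get_concrete_function(
        tf.TensorSpec([], tf.float32)).graph.as_graph_def())
    assert "Mul" in graph_def  # `x * x` lowers to a Mul node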
@@ -921,41 +922,44 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     def f(x):
       return x + const + const + const + const

-    f_tf_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
-    self.assertEqual(f_tf_nr_consts, 1)
+    f_tf_consts = self.FindLargeTfConstants(jax2tf.convert(f), const)
+    self.assertLen(f_tf_consts, 1)

   def test_shared_constants_under_cond(self):
     # Check that the constants are shared properly in converted functions
     # See https://github.com/google/jax/issues/7992.
     if config.jax2tf_default_experimental_native_lowering:
       raise unittest.SkipTest("shared constants tests not interesting for native lowering")
-    const = np.random.uniform(size=256).astype(np.float32)  # A shared constant
-    x = np.ones((256,), dtype=np.float32)
+    const_size = 512
+    const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
+    x = np.ones((const_size,), dtype=np.float32)
     def f1(x):
       # Ensure that we first see the constants in the inside jaxpr
       return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const
     def f2(x):
       return f1(x) + const  # The extra const should not cost anything
-    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), x)
-    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), x)
-    self.assertEqual(f1_nr_consts, f2_nr_consts)
+    f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), x, at_least=const_size)
+    f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), x, at_least=const_size)
+    self.assertLen(f2_consts, len(f1_consts))

   def test_shared_constants_under_scan(self):
     # See https://github.com/google/jax/issues/7992.
     if config.jax2tf_default_experimental_native_lowering:
       raise unittest.SkipTest("shared constants tests not interesting for native lowering")
-    const = np.random.uniform(size=256).astype(np.float32)  # A shared constant
-    xs = np.ones((8, 256), dtype=np.float32)
+    const_size = 512
+    const = np.random.uniform(size=const_size).astype(np.float32)  # A shared constant
+    xs = np.ones((8, const_size), dtype=np.float32)
     def f1(xs):
       res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
-                        np.zeros((256,), dtype=np.float32), xs)
+                        jnp.zeros((const_size,), dtype=np.float32), xs)
       return res

     def f2(xs):
       return f1(xs) + const  # The extra const should not be saved

-    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), xs)
-    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), xs)
-    self.assertEqual(f1_nr_consts, f2_nr_consts)
+    f1_consts = self.FindLargeTfConstants(jax2tf.convert(f1), xs, at_least=const_size)
+    f2_consts = self.FindLargeTfConstants(jax2tf.convert(f2), xs, at_least=const_size)
+    self.assertLen(f2_consts, len(f1_consts))

   def test_shared_constants_under_jit(self):
     # We do not share constants under jit.
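Returning the constants instead of a bare count makes failures diagnosable: `assertLen` on the actual list shows which constants leaked. End to end, the property these tests check can be reproduced outside the test harness roughly as follows (a sketch assuming graph lowering, i.e. not the native-lowering path that the tests skip):

    import numpy as np
    import tensorflow as tf
    from jax.experimental import jax2tf

    const = np.random.uniform(size=512).astype(np.float32)

    def f(x):
      return x + const + const  # `const` is closed over twice

    graph_def = (tf.function(jax2tf.convert(f), autograph=False)
                 .get_concrete_function(tf.TensorSpec((512,), tf.float32))
                 .graph.as_graph_def())
    large_consts = [n for n in graph_def.node
                    if n.op == "Const" and len(str(n)) >= 512]
    # With constant sharing, the 512-element array is embedded only once.
    print(len(large_consts))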
@@ -968,9 +972,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     def f(x):
       return g_jit(x) + const + const

-    f_tf_graph_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
-    # TODO(b/207464757): TF compilation is disabled
-    self.assertEqual(f_tf_graph_nr_consts, 1)
+    f_tf_graph_consts = self.FindLargeTfConstants(jax2tf.convert(f), const)
+    self.assertLen(f_tf_graph_consts, 1)

   def test_weak_types(self):
     mul = jax.jit(jnp.multiply)
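The reason constants cannot be shared across the jit boundary when jitted calls are wrapped in their own tf.function (the `_WRAP_JAX_JIT_WITH_TF_FUNCTION` case above) is a TensorFlow rule: a tensor created while tracing one FuncGraph is not usable from another. A sketch of the failure mode (TF is expected to raise here, typically an inaccessible-tensor error):

    import tensorflow as tf

    captured = []

    @tf.function
    def g():
      captured.append(tf.constant([1.0, 2.0]))  # lives in g's FuncGraph
      return captured[0]

    @tf.function
    def h():
      return captured[0] + 1.0  # refers to a tensor from a different graph

    g()
    try:
      h()
    except Exception as e:
      print(type(e).__name__)  # TF rejects cross-FuncGraph tensor reuse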
@@ -374,9 +374,9 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     return tf_function.experimental_get_compiler_ir(*args)(stage="hlo",
                                                            device_name=device_name)

-  def CountLargeTfConstants(self, tf_fun: Callable, *args,
-                            at_least=256):
-    # A hacky way to count how many "large" constants are embedded in the
+  def FindLargeTfConstants(self, tf_fun: Callable, *args,
+                           at_least=256):
+    # A hacky way to find the "large" constants that are embedded in the
     # graph. We count the number of characters in the textual representation
     # of the constant.
     f_tf_graph = tf.function(tf_fun, autograph=False).get_concrete_function(*args).graph.as_graph_def()
@@ -387,8 +387,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
     else:
       # We cannot find the constants just with string matching because their
       # representation may contain escaped "
-      large_consts = [n for n in f_tf_graph.node if n.op == "Const" and len(str(n)) >= at_least]
-      return len(large_consts)
+      large_consts = [str(n) for n in f_tf_graph.node if n.op == "Const" and len(str(n)) >= at_least]
+      return large_consts

   def CheckOpMetadata(self, jax_fun, x,
                       expected: Sequence[OpMetadataGraph],
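Lifted out of the test base class, the revised helper amounts to the following standalone function (a sketch; the real method lives on JaxToTfTestCase and has an additional branch, elided from the hunk, before the `else`):

    from typing import Callable, List
    import tensorflow as tf

    def find_large_tf_constants(tf_fun: Callable, *args,
                                at_least: int = 256) -> List[str]:
      # Return the textual form of every "large" Const node embedded in the
      # traced graph. Size is measured as the length of the node's text
      # representation, because the raw constant payload may contain escaped
      # quotes, which defeats naive string matching on the whole graph text.
      graph_def = (tf.function(tf_fun, autograph=False)
                   .get_concrete_function(*args).graph.as_graph_def())
      return [str(n) for n in graph_def.node
              if n.op == "Const" and len(str(n)) >= at_least]

Returning the matching nodes rather than their count lets a failing assertLen show exactly which constants were duplicated.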