mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #6189 from LenaMartens:changelist/364621874
PiperOrigin-RevId: 365886852
This commit is contained in:
commit
634397dc59
@ -63,14 +63,18 @@ T = TypeVar('T')
|
||||
Array = Any
|
||||
|
||||
@cache()
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
transform_name: str = ""):
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals,
|
||||
transform_name=transform_name)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
@cache()
|
||||
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
|
||||
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
transform_name: str = ""):
|
||||
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals,
|
||||
transform_name)
|
||||
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
return closed_jaxpr, consts, out_tree
|
||||
|
||||
@ -266,8 +270,8 @@ def while_loop(cond_fun: Callable[[T], bool],
|
||||
def _create_jaxpr(init_val):
|
||||
init_vals, in_tree = tree_flatten((init_val,))
|
||||
init_avals = tuple(_map(_abstractify, init_vals))
|
||||
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
|
||||
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
|
||||
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals, "while_cond")
|
||||
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals, "while_loop")
|
||||
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
|
||||
msg = "cond_fun must return a boolean scalar, but got pytree {}."
|
||||
raise TypeError(msg.format(cond_tree))
|
||||
@ -1256,7 +1260,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
|
||||
carry_avals = tuple(_map(_abstractify, init_flat))
|
||||
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
|
||||
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
|
||||
out_tree_children = out_tree.children()
|
||||
if len(out_tree_children) != 2:
|
||||
msg = "scan body output must be a pair, got {}."
|
||||
|
@ -1169,9 +1169,11 @@ def _memoize(thunk):
|
||||
return memoized
|
||||
|
||||
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
|
||||
in_avals: Sequence[AbstractValue],
|
||||
transform_name: str = ""):
|
||||
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
del main, fun
|
||||
@ -1198,9 +1200,11 @@ def extend_jaxpr_stack(main, frame):
|
||||
assert frame is main.jaxpr_stack[-1]
|
||||
main.jaxpr_stack = main.jaxpr_stack[:-1]
|
||||
|
||||
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
def trace_to_jaxpr_final(fun: lu.WrappedFun,
|
||||
in_avals: Sequence[AbstractValue],
|
||||
transform_name: str = ""):
|
||||
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
del fun, main
|
||||
@ -1214,12 +1218,15 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
|
||||
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
def fun_sourceinfo(fun):
|
||||
def fun_sourceinfo(fun, transform_name: str = ""):
|
||||
if isinstance(fun, functools.partial):
|
||||
fun = fun.func
|
||||
try:
|
||||
filename = fun.__code__.co_filename
|
||||
lineno = fun.__code__.co_firstlineno
|
||||
return f"{fun.__name__} at {filename}:{lineno}"
|
||||
line_info = f"{fun.__name__} at {filename}:{lineno}"
|
||||
if transform_name:
|
||||
line_info += f', transformed by {transform_name}.'
|
||||
return line_info
|
||||
except AttributeError:
|
||||
return "<unknown>"
|
||||
|
@ -685,7 +685,7 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
|
||||
|
||||
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
out_axes = out_axes_thunk()
|
||||
|
@ -649,7 +649,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
"got device={} and backend={}".format(device, backend))
|
||||
|
||||
abstract_args, arg_devices = unzip2(arg_specs)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
|
@ -2070,6 +2070,17 @@ class APITest(jtu.JaxTestCase):
|
||||
# level, which is no longer live.
|
||||
jax.jit(jnp.add)(jnp.ones(()), count)
|
||||
|
||||
def test_escaped_tracer_transform_name(self):
|
||||
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
||||
"transformed by jit"):
|
||||
jax.jit(self.helper_save_tracer)(1)
|
||||
_ = self._saved_tracer+1
|
||||
|
||||
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
||||
"transformed by pmap"):
|
||||
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
|
||||
_ = self._saved_tracer+1
|
||||
|
||||
def test_pmap_static_kwarg_error_message(self):
|
||||
# https://github.com/google/jax/issues/3007
|
||||
def f(a, b):
|
||||
|
@ -2602,6 +2602,24 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
|
||||
|
||||
def test_unexpected_tracer_error(self):
|
||||
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
||||
"transformed by while_loop"):
|
||||
lst = []
|
||||
def side_effecting_body(val):
|
||||
lst.append(val)
|
||||
return val+1
|
||||
lax.while_loop(lambda x: x < 2, side_effecting_body, 1)
|
||||
lst[0] += 1
|
||||
|
||||
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
||||
"transformed by scan"):
|
||||
lst = []
|
||||
def side_effecting_scan(carry, val):
|
||||
lst.append(val)
|
||||
return carry, val+1
|
||||
lax.scan(side_effecting_scan, None, jnp.ones((2, 2)))
|
||||
lst[0] += 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user