mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Prune trivial convert_element_type equations from jaxprs.
Also add a test checking that tracing jnp.array performs no host-to-device (H2D) transfers.
This commit is contained in:
parent
99b755f845
commit
ba9233b9b6
@ -380,10 +380,8 @@ For the example consider the function ``func11`` below
|
||||
f = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] b
|
||||
g = add f e
|
||||
h = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
i = add g h
|
||||
in (i, b) }
|
||||
h = add g a
|
||||
in (h, b) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
@ -424,13 +422,11 @@ computation should run. For example
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
d = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
e = mul d c
|
||||
f = convert_element_type[ new_dtype=float32
|
||||
d = mul a c
|
||||
e = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] b
|
||||
g = add f e
|
||||
in (g,) }
|
||||
f = add e d
|
||||
in (f,) }
|
||||
device=None
|
||||
donated_invars=(False, False)
|
||||
name=inner ] a b
|
||||
@ -460,16 +456,14 @@ captured using the ``xla_pmap`` primitive. Consider this example
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; a b.
|
||||
let c = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
d = add b c
|
||||
e = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
let c = add b a
|
||||
d = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
f = add d e
|
||||
g = psum[ axes=('rows',)
|
||||
e = add c d
|
||||
f = psum[ axes=('rows',)
|
||||
axis_index_groups=None ] b
|
||||
h = div f g
|
||||
in (h,) }
|
||||
g = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
donated_invars=(False, False)
|
||||
global_arg_shapes=(None,)
|
||||
|
@ -47,7 +47,6 @@ from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_sha
|
||||
from jax.config import config
|
||||
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _device_put_raw
|
||||
from jax import ops
|
||||
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
|
||||
@ -2932,12 +2931,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
# large integers; see discussion in https://github.com/google/jax/pull/6047.
|
||||
object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)
|
||||
|
||||
# call _np_array a second time with canonicalized dtype
|
||||
dtype = dtypes.canonicalize_dtype(object.dtype)
|
||||
object = _np_array(object, dtype=dtype, copy=False)
|
||||
|
||||
assert type(object) not in dtypes.python_scalar_dtypes
|
||||
|
||||
if type(object) is np.ndarray:
|
||||
_inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
|
||||
lax._check_user_dtype_supported(_inferred_dtype, "array")
|
||||
out = np.array(object, copy=copy, dtype=dtype)
|
||||
out = _np_array(object, copy=copy, dtype=dtype)
|
||||
if dtype: assert _dtype(out) == dtype
|
||||
elif isinstance(object, (DeviceArray, core.Tracer)):
|
||||
if isinstance(object, DeviceArray) and copy:
|
||||
|
@ -940,6 +940,7 @@ class JaxprStackFrame:
|
||||
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
|
||||
constvars, constvals = unzip2(self.constvar_to_val.items())
|
||||
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
|
||||
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
|
||||
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
||||
out_avals = [t.aval for t in out_tracers]
|
||||
return jaxpr, out_avals, constvals
|
||||
@ -962,7 +963,28 @@ class JaxprStackFrame:
|
||||
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
|
||||
return invar_positions, const_eqns
|
||||
|
||||
def _prune_convert_element_types(jaxpr, constvals):
  """Drop ``convert_element_type`` equations that act on known constants.

  Only equations whose primitive is ``core.convert_element_type_p`` and whose
  first input is bound to a known constant value are candidates. Two cases
  are removed from the equation list:

  * the constant's type is in ``core.literalable_types`` and it has an empty
    shape: the conversion is evaluated here with NumPy and the result is
    recorded as a constant under the equation's output variable;
  * the constant's dtype already equals ``new_dtype``: the same value is
    recorded as a constant under the equation's output variable.

  Every other equation is kept unchanged and in its original order.

  Args:
    jaxpr: the jaxpr to process; its ``constvars``, ``invars``, ``outvars``
      and ``eqns`` are read.
    constvals: values paired positionally with ``jaxpr.constvars``.

  Returns:
    A pair ``(new_jaxpr, new_constvals)``. The original constants come first,
    followed by one entry per pruned equation. No constant is removed here,
    even if pruning leaves it unreferenced.
  """
  known = dict(zip(jaxpr.constvars, constvals))
  kept_eqns = []
  for eqn in jaxpr.eqns:
    if eqn.primitive is not core.convert_element_type_p:
      kept_eqns.append(eqn)
      continue

    val = known.get(eqn.invars[0])
    if val is None:
      # The operand is not a known constant, so the conversion has to stay.
      kept_eqns.append(eqn)
      continue

    target_dtype = eqn.params['new_dtype']
    if type(val) in core.literalable_types and not np.shape(val):
      # Scalar literal: constant-fold the conversion so the result can be
      # inlined as a literal instead of being staged out.
      known[eqn.outvars[0]] = np.array(val, target_dtype)
    elif dtypes.dtype(val) == target_dtype:
      # Same dtype in and out: the equation is clutter, so alias the output
      # variable to the existing constant.
      known[eqn.outvars[0]] = val
    else:
      kept_eqns.append(eqn)

  pruned_constvars, pruned_constvals = unzip2(known.items())
  pruned_jaxpr = Jaxpr(pruned_constvars, jaxpr.invars, jaxpr.outvars,
                       kept_eqns)
  return pruned_jaxpr, pruned_constvals
|
||||
|
||||
def _inline_literals(jaxpr, constvals):
|
||||
# This function also ensures variables are labeled in a canonical ordering,
|
||||
# prunes unused constants, and inserts `dropvar` symbols.
|
||||
consts = dict(zip(jaxpr.constvars, constvals))
|
||||
newvar = core.gensym()
|
||||
newvars = {}
|
||||
@ -976,20 +998,16 @@ def _inline_literals(jaxpr, constvals):
|
||||
return None
|
||||
|
||||
used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
|
||||
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
|
||||
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
|
||||
new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
|
||||
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
|
||||
if v in used and not lit(v)]
|
||||
new_invars = [var(v) for v in jaxpr.invars]
|
||||
new_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
invars = [lit(v) or var(v) for v in eqn.invars]
|
||||
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
|
||||
# constant-fold dtype conversion of literals to be inlined
|
||||
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
|
||||
else:
|
||||
# might do DCE here, but we won't until we're more careful about effects
|
||||
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
|
||||
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
|
||||
eqn.source_info))
|
||||
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
|
||||
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
|
||||
eqn.source_info))
|
||||
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
|
||||
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
|
||||
return new_jaxpr, new_constvals
|
||||
|
@ -2269,8 +2269,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f()
|
||||
|
||||
def test_xla_computation_zeros_doesnt_device_put(self):
  """Check that lowering ``jnp.zeros(3)`` records no device_put.

  ``api.xla_computation`` is applied to a zero-argument function returning
  ``jnp.zeros(3)`` inside ``jtu.count_device_put()``, and the first entry of
  the yielded counter is asserted to be 0.
  """
  # No unconditional unittest.SkipTest at the top: raising it first would
  # make everything below unreachable and the test would verify nothing.
  with jtu.count_device_put() as count:
    api.xla_computation(lambda: jnp.zeros(3))()
  self.assertEqual(count[0], 0)
|
||||
@ -2666,6 +2664,10 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
|
||||
modes=['rev'])
|
||||
|
||||
def test_jnp_array_doesnt_device_put(self):
  """Check that tracing ``jnp.array(3)`` records no device_put.

  ``api.make_jaxpr`` traces a zero-argument function returning
  ``jnp.array(3)`` inside ``jtu.count_device_put()``, and the first entry of
  the yielded counter is asserted to be 0.
  """
  def make_scalar():
    return jnp.array(3)

  with jtu.count_device_put() as device_put_calls:
    api.make_jaxpr(make_scalar)()
  self.assertEqual(device_put_calls[0], 0)
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user