prune trivial convert_element_types from jaxprs

also add a test for not performing H2D transfers while tracing jnp.array
This commit is contained in:
Matthew Johnson 2021-04-09 21:29:42 -07:00
parent 99b755f845
commit ba9233b9b6
4 changed files with 49 additions and 32 deletions

View File

@ -380,10 +380,8 @@ For the example consider the function ``func11`` below
f = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
h = convert_element_type[ new_dtype=float32
weak_type=False ] a
i = add g h
in (i, b) }
h = add g a
in (h, b) }
length=16
linear=(False, False, False, False)
num_carry=1
@ -424,13 +422,11 @@ computation should run. For example
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = convert_element_type[ new_dtype=float32
weak_type=False ] a
e = mul d c
f = convert_element_type[ new_dtype=float32
d = mul a c
e = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
in (g,) }
f = add e d
in (f,) }
device=None
donated_invars=(False, False)
name=inner ] a b
@ -460,16 +456,14 @@ captured using the ``xla_pmap`` primitive. Consider this example
axis_size=1
backend=None
call_jaxpr={ lambda ; a b.
let c = convert_element_type[ new_dtype=float32
weak_type=False ] a
d = add b c
e = broadcast_in_dim[ broadcast_dimensions=( )
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
f = add d e
g = psum[ axes=('rows',)
e = add c d
f = psum[ axes=('rows',)
axis_index_groups=None ] b
h = div f g
in (h,) }
g = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)

View File

@ -47,7 +47,6 @@ from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_sha
from jax.config import config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
@ -2932,12 +2931,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# large integers; see discussion in https://github.com/google/jax/pull/6047.
object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)
# call _np_array a second time with canonicalized dtype
dtype = dtypes.canonicalize_dtype(object.dtype)
object = _np_array(object, dtype=dtype, copy=False)
assert type(object) not in dtypes.python_scalar_dtypes
if type(object) is np.ndarray:
_inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
lax._check_user_dtype_supported(_inferred_dtype, "array")
out = np.array(object, copy=copy, dtype=dtype)
out = _np_array(object, copy=copy, dtype=dtype)
if dtype: assert _dtype(out) == dtype
elif isinstance(object, (DeviceArray, core.Tracer)):
if isinstance(object, DeviceArray) and copy:

View File

@ -940,6 +940,7 @@ class JaxprStackFrame:
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
@ -962,7 +963,28 @@ class JaxprStackFrame:
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns
def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
new_eqns = []
for eqn in jaxpr.eqns:
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
consts[eqn.outvars[0]] = c
continue
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
return new_jaxpr, new_constvals
def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
consts = dict(zip(jaxpr.constvars, constvals))
newvar = core.gensym()
newvars = {}
@ -976,20 +998,16 @@ def _inline_literals(jaxpr, constvals):
return None
used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
if v in used and not lit(v)]
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
else:
# might do DCE here, but we won't until we're more careful about effects
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals

View File

@ -2269,8 +2269,6 @@ class APITest(jtu.JaxTestCase):
f()
def test_xla_computation_zeros_doesnt_device_put(self):
raise unittest.SkipTest("broken test") # TODO(mattjj): fix
with jtu.count_device_put() as count:
api.xla_computation(lambda: jnp.zeros(3))()
self.assertEqual(count[0], 0)
@ -2666,6 +2664,10 @@ class APITest(jtu.JaxTestCase):
jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
modes=['rev'])
def test_jnp_array_doesnt_device_put(self):
with jtu.count_device_put() as count:
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)
class RematTest(jtu.JaxTestCase):