add tuple_args logic to xla primitive application

This commit is contained in:
Matthew Johnson 2019-12-12 05:14:57 -08:00 committed by Matthew Johnson
parent 876c9c0ede
commit fbde09f567
4 changed files with 23 additions and 15 deletions

View File

@ -162,18 +162,20 @@ def xla_primitive_callable(prim, *abstract_args, **params):
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs.destructure()))
else:
handle_result = aval_to_result_handler(aval_out)
built_c = primitive_computation(prim, *abstract_args, **params)
tuple_args = len(abstract_args) > 100
built_c = primitive_computation(prim, tuple_args, *abstract_args, **params)
compiled = built_c.Compile(compile_options=xb.get_compile_options(),
backend=xb.get_backend(backend))
return partial(_execute_compiled_primitive, prim, compiled, backend, handle_result)
return partial(_execute_compiled_primitive, prim, compiled, backend,
tuple_args, handle_result)
@cache()
def primitive_computation(prim, *avals, **params):
def primitive_computation(prim, tuple_args, *avals, **params):
c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
c.SetOpMetadata(xc.OpMetadata(op_type=prim.name, op_name=str(params)))
backend = params.pop("backend", None)
platform = xb.get_backend(backend).platform
xla_args = _xla_callable_args(c, avals, False)
xla_args = _xla_callable_args(c, avals, tuple_args)
if prim in backend_specific_translations[platform]:
rule = backend_specific_translations[platform][prim]
rule(c, *xla_args, **params) # return val set as a side-effect on c
@ -197,9 +199,15 @@ def primitive_computation(prim, *avals, **params):
"https://github.com/google/jax/issues\n")
raise RuntimeError(msg)
def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
def primitive_subcomputation(prim, *avals, **params):
  """Build the computation for `prim` with `tuple_args` fixed to False.

  Thin wrapper over `primitive_computation`: `prim`, `avals` and `params`
  are forwarded unchanged, and the `tuple_args` positional slot is always
  filled with False.
  """
  tuple_args = False  # this wrapper never requests tuple-packed arguments
  return primitive_computation(prim, tuple_args, *avals, **params)
def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
result_handler, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args if x is not token]
if tuple_args:
input_bufs = [make_tuple(input_bufs, device, backend)]
out_buf = compiled.Execute(input_bufs)
if FLAGS.jax_debug_nans:
check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)

View File

@ -3393,7 +3393,7 @@ def _reduce_sum_translation_rule(c, operand, axes, input_shape):
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return c.Reduce(operand, c.Constant(onp.array(0, dtype)),
xla.primitive_computation(add_p, scalar, scalar),
xla.primitive_subcomputation(add_p, scalar, scalar),
axes)
def _reduce_sum_transpose_rule(cotangent, input_shape, axes):
@ -3417,7 +3417,7 @@ def _reduce_prod_translation_rule(c, operand, axes):
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return c.Reduce(operand, c.Constant(onp.array(1, dtype)),
xla.primitive_computation(mul_p, scalar, scalar),
xla.primitive_subcomputation(mul_p, scalar, scalar),
axes)
def _reduce_prod_jvp_rule(tangent, operand, axes):
@ -3463,7 +3463,7 @@ def _reduce_chooser_translation_rule(prim, identity, c, operand, axes):
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return c.Reduce(operand, c.Constant(identity(dtype)),
xla.primitive_computation(prim, scalar, scalar), axes)
xla.primitive_subcomputation(prim, scalar, scalar), axes)
def _reduce_chooser_jvp_rule(g, ans, operand, axes):
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
@ -3500,7 +3500,7 @@ def _reduce_logical_shape_rule(operand, axes):
def _reduce_logical_translation_rule(prim, identity, c, operand, axes):
scalar = ShapedArray((), onp.bool_)
return c.Reduce(operand, c.Constant(identity(onp.bool_)),
xla.primitive_computation(prim, scalar, scalar), axes)
xla.primitive_subcomputation(prim, scalar, scalar), axes)
_reduce_or_translation_rule = partial(_reduce_logical_translation_rule,
or_p, _get_max_identity)
@ -3563,7 +3563,7 @@ def _reduce_window_sum_translation_rule(c, operand, window_dimensions,
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return c.ReduceWindow(operand, c.Constant(onp.array(0, dtype)),
xla.primitive_computation(add_p, scalar, scalar),
xla.primitive_subcomputation(add_p, scalar, scalar),
window_dimensions, window_strides, padding)
def _reduce_window_sum_transpose_rule(cotangent, window_dimensions,
@ -3610,7 +3610,7 @@ def _reduce_window_chooser_translation_rule(
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
return c.ReduceWindow(operand, c.Constant(identity(dtype)),
xla.primitive_computation(prim, scalar, scalar),
xla.primitive_subcomputation(prim, scalar, scalar),
window_dimensions, window_strides, padding)
def _reduce_window_chooser_jvp_rule(prim, g, operand, window_dimensions,
@ -3700,8 +3700,8 @@ def _select_and_scatter_add_translation(
padding):
dtype = c.GetShape(operand).numpy_dtype()
scalar = ShapedArray((), dtype)
select = xla.primitive_computation(select_prim, scalar, scalar)
scatter = xla.primitive_computation(add_p, scalar, scalar)
select = xla.primitive_subcomputation(select_prim, scalar, scalar)
scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
zero = c.Constant(onp.array(0, dtype))
return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, zero, scatter)

View File

@ -235,7 +235,7 @@ def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
_map(cond_c.Constant, cond_jaxpr.literals), (), *(x + z))
if batched:
scalar = ShapedArray((), onp.bool_)
or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
list(range(cond_jaxpr.out_avals[0].ndim)))

View File

@ -191,7 +191,7 @@ def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name):
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
dtype = c.GetShape(val).numpy_dtype()
scalar = ShapedArray((), dtype)
computation = xla.primitive_computation(prim, scalar, scalar, backend=backend)
computation = xla.primitive_subcomputation(prim, scalar, scalar, backend=backend)
return c.AllReduce(val, computation, replica_groups=replica_groups)
# psum translation rule has special handling for complex dtypes