mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add tuple_args logic to xla primitive application
This commit is contained in:
parent
876c9c0ede
commit
fbde09f567
@ -162,18 +162,20 @@ def xla_primitive_callable(prim, *abstract_args, **params):
|
||||
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs.destructure()))
|
||||
else:
|
||||
handle_result = aval_to_result_handler(aval_out)
|
||||
built_c = primitive_computation(prim, *abstract_args, **params)
|
||||
tuple_args = len(abstract_args) > 100
|
||||
built_c = primitive_computation(prim, tuple_args, *abstract_args, **params)
|
||||
compiled = built_c.Compile(compile_options=xb.get_compile_options(),
|
||||
backend=xb.get_backend(backend))
|
||||
return partial(_execute_compiled_primitive, prim, compiled, backend, handle_result)
|
||||
return partial(_execute_compiled_primitive, prim, compiled, backend,
|
||||
tuple_args, handle_result)
|
||||
|
||||
@cache()
|
||||
def primitive_computation(prim, *avals, **params):
|
||||
def primitive_computation(prim, tuple_args, *avals, **params):
|
||||
c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
|
||||
c.SetOpMetadata(xc.OpMetadata(op_type=prim.name, op_name=str(params)))
|
||||
backend = params.pop("backend", None)
|
||||
platform = xb.get_backend(backend).platform
|
||||
xla_args = _xla_callable_args(c, avals, False)
|
||||
xla_args = _xla_callable_args(c, avals, tuple_args)
|
||||
if prim in backend_specific_translations[platform]:
|
||||
rule = backend_specific_translations[platform][prim]
|
||||
rule(c, *xla_args, **params) # return val set as a side-effect on c
|
||||
@ -197,9 +199,15 @@ def primitive_computation(prim, *avals, **params):
|
||||
"https://github.com/google/jax/issues\n")
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
|
||||
def primitive_subcomputation(prim, *avals, **params):
    """Build an XLA computation for `prim` suitable for use as a subcomputation.

    Subcomputations embedded in a larger computation (e.g. the reducer passed
    to `Reduce`/`ReduceWindow`/`AllReduce`) must not use tuple-packed
    parameters, so this always calls `primitive_computation` with
    `tuple_args=False`.

    Args:
      prim: the primitive whose translation rule defines the computation.
      *avals: abstract values describing the computation's parameters.
      **params: primitive parameters forwarded to the translation rule.

    Returns:
      Whatever `primitive_computation` returns (a built XLA computation).
    """
    # Fixed tuple_args=False: callers embed this as a scalar reducer, never as
    # a top-level executable with many arguments.
    return primitive_computation(prim, False, *avals, **params)
|
||||
|
||||
def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
|
||||
result_handler, *args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = [device_put(x, device) for x in args if x is not token]
|
||||
if tuple_args:
|
||||
input_bufs = [make_tuple(input_bufs, device, backend)]
|
||||
out_buf = compiled.Execute(input_bufs)
|
||||
if FLAGS.jax_debug_nans:
|
||||
check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
|
||||
|
@ -3393,7 +3393,7 @@ def _reduce_sum_translation_rule(c, operand, axes, input_shape):
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
return c.Reduce(operand, c.Constant(onp.array(0, dtype)),
|
||||
xla.primitive_computation(add_p, scalar, scalar),
|
||||
xla.primitive_subcomputation(add_p, scalar, scalar),
|
||||
axes)
|
||||
|
||||
def _reduce_sum_transpose_rule(cotangent, input_shape, axes):
|
||||
@ -3417,7 +3417,7 @@ def _reduce_prod_translation_rule(c, operand, axes):
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
return c.Reduce(operand, c.Constant(onp.array(1, dtype)),
|
||||
xla.primitive_computation(mul_p, scalar, scalar),
|
||||
xla.primitive_subcomputation(mul_p, scalar, scalar),
|
||||
axes)
|
||||
|
||||
def _reduce_prod_jvp_rule(tangent, operand, axes):
|
||||
@ -3463,7 +3463,7 @@ def _reduce_chooser_translation_rule(prim, identity, c, operand, axes):
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
return c.Reduce(operand, c.Constant(identity(dtype)),
|
||||
xla.primitive_computation(prim, scalar, scalar), axes)
|
||||
xla.primitive_subcomputation(prim, scalar, scalar), axes)
|
||||
|
||||
def _reduce_chooser_jvp_rule(g, ans, operand, axes):
|
||||
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
|
||||
@ -3500,7 +3500,7 @@ def _reduce_logical_shape_rule(operand, axes):
|
||||
def _reduce_logical_translation_rule(prim, identity, c, operand, axes):
|
||||
scalar = ShapedArray((), onp.bool_)
|
||||
return c.Reduce(operand, c.Constant(identity(onp.bool_)),
|
||||
xla.primitive_computation(prim, scalar, scalar), axes)
|
||||
xla.primitive_subcomputation(prim, scalar, scalar), axes)
|
||||
|
||||
_reduce_or_translation_rule = partial(_reduce_logical_translation_rule,
|
||||
or_p, _get_max_identity)
|
||||
@ -3563,7 +3563,7 @@ def _reduce_window_sum_translation_rule(c, operand, window_dimensions,
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
return c.ReduceWindow(operand, c.Constant(onp.array(0, dtype)),
|
||||
xla.primitive_computation(add_p, scalar, scalar),
|
||||
xla.primitive_subcomputation(add_p, scalar, scalar),
|
||||
window_dimensions, window_strides, padding)
|
||||
|
||||
def _reduce_window_sum_transpose_rule(cotangent, window_dimensions,
|
||||
@ -3610,7 +3610,7 @@ def _reduce_window_chooser_translation_rule(
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
return c.ReduceWindow(operand, c.Constant(identity(dtype)),
|
||||
xla.primitive_computation(prim, scalar, scalar),
|
||||
xla.primitive_subcomputation(prim, scalar, scalar),
|
||||
window_dimensions, window_strides, padding)
|
||||
|
||||
def _reduce_window_chooser_jvp_rule(prim, g, operand, window_dimensions,
|
||||
@ -3700,8 +3700,8 @@ def _select_and_scatter_add_translation(
|
||||
padding):
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
select = xla.primitive_computation(select_prim, scalar, scalar)
|
||||
scatter = xla.primitive_computation(add_p, scalar, scalar)
|
||||
select = xla.primitive_subcomputation(select_prim, scalar, scalar)
|
||||
scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
|
||||
zero = c.Constant(onp.array(0, dtype))
|
||||
return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
|
||||
padding, source, zero, scatter)
|
||||
|
@ -235,7 +235,7 @@ def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
|
||||
_map(cond_c.Constant, cond_jaxpr.literals), (), *(x + z))
|
||||
if batched:
|
||||
scalar = ShapedArray((), onp.bool_)
|
||||
or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
|
||||
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
|
||||
pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
|
||||
list(range(cond_jaxpr.out_avals[0].ndim)))
|
||||
|
||||
|
@ -191,7 +191,7 @@ def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name):
|
||||
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
|
||||
dtype = c.GetShape(val).numpy_dtype()
|
||||
scalar = ShapedArray((), dtype)
|
||||
computation = xla.primitive_computation(prim, scalar, scalar, backend=backend)
|
||||
computation = xla.primitive_subcomputation(prim, scalar, scalar, backend=backend)
|
||||
return c.AllReduce(val, computation, replica_groups=replica_groups)
|
||||
|
||||
# psum translation rule has special handling for complex dtypes
|
||||
|
Loading…
x
Reference in New Issue
Block a user