mirror of https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

revise handling of 'backend' values

parent 286ec51f61
commit 8bd1a46ce7
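
Summary: the hunks below — which, judging by the function names, come from jax/interpreters/xla.py, jax/lax/lax.py, jax/lax/lax_control_flow.py, and jax/lib/xla_bridge.py (the extraction dropped the file headers) — share one theme: resolve the backend once, near the top of the call path, and thread the resolved value down explicitly instead of re-deriving it, or defaulting it to None, at each layer. A minimal sketch of that convention, using hypothetical stand-ins rather than jax's real internals:

    # Sketch only: "resolve once, thread down", with toy stand-in names.
    class Backend:
        def __init__(self, platform):
            self.platform = platform

    _backends = {"cpu": Backend("cpu"), "gpu": Backend("gpu")}

    def resolve_backend(name=None):
        # Stand-in for the registry lookup done once at the top of the call path.
        return _backends[name or "cpu"]

    def build_computation(prim_name, backend):
        # The platform is read off the backend argument; previously code at
        # this depth re-derived it from a global default lookup.
        return "<%s computation for %s>" % (prim_name, backend.platform)

    def primitive_callable(prim_name, device=None):
        backend = resolve_backend(device)        # resolved exactly once
        built = build_computation(prim_name, backend)
        return lambda *args: (built, args)       # stand-in for compile + execute
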
@@ -172,6 +172,7 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   else:
     all_devices = it.chain(xb.devices(), xb.devices('cpu'))
     device = device and next(d for d in all_devices if (type(d), d.id) == device)
+  backend = xb.get_device_backend(device)
   aval_out = prim.abstract_eval(*avals, **params)
   if prim.multiple_results:
     handlers = tuple(map(aval_to_result_handler, aval_out))
@@ -179,18 +180,17 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   else:
     handle_result = aval_to_result_handler(aval_out)
   tuple_args = len(avals) > 100
-  built_c = primitive_computation(prim, tuple_args, *avals, **params)
+  built_c = primitive_computation(prim, backend, tuple_args, *avals, **params)
   options = xb.get_compile_options(device_assignment=(device.id,) if device else None)
-  compiled = built_c.Compile(compile_options=options,
-                             backend=xb.get_device_backend(device))
-  return partial(_execute_compiled_primitive, prim, compiled, tuple_args,
-                 handle_result)
+  compiled = built_c.Compile(compile_options=options, backend=backend)
+  return partial(_execute_compiled_primitive, prim, compiled, backend,
+                 tuple_args, handle_result)

 @cache()
-def primitive_computation(prim, tuple_args, *avals, **params):
+def primitive_computation(prim, backend, tuple_args, *avals, **params):
   c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
   c.SetOpMetadata(xc.OpMetadata(op_type=prim.name, op_name=str(params)))
-  platform = xb.get_backend(None).platform
+  platform = xb.get_backend(backend).platform
   xla_args = _xla_callable_args(c, avals, tuple_args)
   if prim in backend_specific_translations[platform]:
     rule = backend_specific_translations[platform][prim]
@@ -198,12 +198,9 @@ def primitive_computation(prim, tuple_args, *avals, **params):
   elif prim in translations:
     rule = translations[prim]
     rule(c, *xla_args, **params)  # return val set as a side-effect on c
-  elif prim in reduction_translations:
-    rule = reduction_translations[prim]
-    rule(c, *xla_args, **params)  # return val set as a side-effect on c
   elif prim in initial_style_translations:
     rule = initial_style_translations[prim]
-    rule(c, AxisEnv(), *xla_args, **params)  # side-effect on c
+    rule(c, AxisEnv(), *xla_args, backend=backend, **params)  # side-effect on c
   else:
     raise NotImplementedError("XLA translation rule for {} not found".format(prim))
   c.ClearOpMetadata()
@@ -216,14 +213,14 @@ def primitive_computation(prim, tuple_args, *avals, **params):
     raise RuntimeError(msg)

 def primitive_subcomputation(prim, *avals, **params):
-  return primitive_computation(prim, False, *avals, **params)
+  return primitive_computation(prim, None, False, *avals, **params)

-def _execute_compiled_primitive(prim, compiled, tuple_args,
+def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
                                 result_handler, *args):
   device, = compiled.local_devices()
   input_bufs = [device_put(x, device) for x in args if x is not token]
   if tuple_args:
-    input_bufs = [make_tuple(input_bufs, device, None)]
+    input_bufs = [make_tuple(input_bufs, device, backend)]
   out_buf = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
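
Two calling-convention details above are worth spelling out: primitive_subcomputation has no device context, so it forwards None (the default backend) into the new backend slot, and _execute_compiled_primitive now receives the backend so that tuple argument buffers can be built against the same client that compiled the computation. A toy sketch of the executor side, with stand-in code rather than jax's internals:

    # Toy sketch: why the executor needs `backend`. When there are many
    # arguments they are packed into one tuple buffer, and that buffer must
    # be created via the same backend that compiled the computation.
    def make_tuple(bufs, device, backend):
        client = backend if backend is not None else "default-client"
        return ("tuple", client, device, tuple(bufs))

    def execute_compiled_primitive(compiled, backend, tuple_args, *args):
        device = "device:0"                      # stand-in for compiled.local_devices()
        input_bufs = list(args)
        if tuple_args:                           # mirrors the len(avals) > 100 case
            input_bufs = [make_tuple(input_bufs, device, backend)]
        return ("execute", compiled, tuple(input_bufs))

    print(execute_compiled_primitive("<add>", "gpu-client", True, 1, 2, 3))
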
@@ -296,17 +293,13 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, freevars, *args):
       ans = rule(c, *in_nodes, **eqn.params)
     elif eqn.primitive in translations:
       ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
-    elif eqn.primitive in reduction_translations:
-      new_params = check_backend_params(eqn.params, backend)
-      ans = reduction_translations[eqn.primitive](c, *in_nodes, backend=backend, **new_params)
     elif eqn.primitive in initial_style_translations:
       new_params = check_backend_params(eqn.params, backend)
       rule = initial_style_translations[eqn.primitive]
       ans = rule(c, axis_env, *in_nodes, backend=backend, **new_params)
     elif eqn.primitive in parallel_translations:
-      new_params = check_backend_params(eqn.params, backend)
-      replica_groups = axis_groups(axis_env, new_params['axis_name'])
-      new_params = {k: new_params[k] for k in new_params if k != 'axis_name'}
+      replica_groups = axis_groups(axis_env, eqn.params['axis_name'])
+      new_params = {k: v for k, v in eqn.params.items() if k != 'axis_name'}
       rule = parallel_translations[eqn.primitive]
       ans = rule(c, *in_nodes, replica_groups=replica_groups, **new_params)
     elif eqn.primitive in call_translations:
@@ -431,7 +424,7 @@ def eqn_collectives(eqn):

 def _xla_call_impl(fun, *args, **params):
   device = params['device']
-  backend = params.get('backend', None)
+  backend = params['backend']
   compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
   try:
     return compiled_fun(*args)
@@ -563,7 +556,7 @@ xla_call_p.def_custom_bind(xla_call)
 xla_call_p.def_impl(_xla_call_impl)

 def _xla_call_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes,
-                               in_nodes, device=None, backend=None):
+                               in_nodes, backend, device=None):
   del device  # Ignored.
   subc = xb.make_computation_builder("jaxpr_subcomputation")  # TODO(mattjj): name
   consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
@@ -578,7 +571,6 @@ ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
 ### translation tables

 translations = {}
-reduction_translations = {}
 parallel_translations = {}
 initial_style_translations = {}
 call_translations = {}
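
With reduction_translations deleted from the tables above, a rule that needs the backend is registered as initial-style and receives it as an explicit backend= keyword; plain rules are untouched. A hedged sketch of the resulting dispatch, with toy tables standing in for the real ones:

    # Toy dispatch mirroring jaxpr_subcomp's shape after this change: no
    # reduction_translations branch; only initial-style rules see backend=.
    translations = {}
    initial_style_translations = {}

    def translate(prim, c, in_nodes, backend, axis_env, **params):
        if prim in initial_style_translations:
            rule = initial_style_translations[prim]
            return rule(c, axis_env, *in_nodes, backend=backend, **params)
        if prim in translations:
            return translations[prim](c, *in_nodes, **params)
        raise NotImplementedError("XLA translation rule for %s not found" % prim)

    translations["add"] = lambda c, *xs, **kw: ("add", xs)
    print(translate("add", None, (1, 2), None, None))
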
@@ -853,7 +845,7 @@ ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])


 def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
-                            backend=None, device=None, concrete=None):
+                            backend, device=None, concrete=None):
   # This looks a lot like _xla_call_translation_rule, except for a widget we use
   # to foil CSE.
   del device, concrete  # Unused.
@@ -1488,14 +1488,6 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
   return prim


-def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
-  prim = Primitive(name)
-  prim.def_impl(partial(xla.apply_primitive, prim))
-  prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
-  xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
-  return prim
-
-
 def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
   assert all(isinstance(arg, UnshapedArray) for arg in args), args
   least_specialized = _max(
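
Because reduction-style translation rules no longer take a backend parameter (they call _reduction_computation without one), the dedicated standard_reduction_primitive constructor above becomes redundant, and the scatter, reduce, reduce_window, and select_and_scatter registrations below all switch to plain standard_primitive. A toy sketch of the consolidated constructor, with stand-ins for jax's Primitive machinery:

    # Toy sketch: one constructor now serves both plain and reduction-style
    # primitives, registering the translation rule in the single table.
    class Primitive:
        def __init__(self, name):
            self.name = name

    translations = {}

    def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
        prim = Primitive(name)
        translations[prim] = translation_rule or ("<default translate %s>" % name)
        return prim

    scatter_add_p = standard_primitive(None, None, "scatter-add")
    reduce_window_p = standard_primitive(None, None, "reduce_window")
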
@@ -3100,11 +3092,11 @@ def _scatter_shape_rule(operand, scatter_indices, updates, **kwargs):

 def _scatter_translation_rule(c, operand, scatter_indices, updates,
                               update_jaxpr, update_consts, dimension_numbers,
-                              updates_shape, backend=None):
+                              updates_shape):
   dtype = c.GetShape(operand).numpy_dtype()
   init_value = c.Constant(onp.array(0, dtype))
   update_computation = _reduction_computation(
-      c, update_jaxpr, backend, update_consts, init_value)
+      c, update_jaxpr, update_consts, init_value)
   indices_shape = c.GetShape(scatter_indices)
   return c.Scatter(operand, scatter_indices, updates, update_computation,
                    _scatter_dimensions_proto(indices_shape, dimension_numbers))
@@ -3203,7 +3195,7 @@ def _scatter_batching_rule(
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
   return scatter_op(operand, scatter_indices, updates, dnums), 0

-scatter_add_p = standard_reduction_primitive(
+scatter_add_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
     _scatter_translation_rule)
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
@@ -3212,14 +3204,14 @@ batching.primitive_batchers[scatter_add_p] = (
     partial(_scatter_batching_rule, scatter_add))

 # TODO(jlebar): Add derivatives.
-scatter_min_p = standard_reduction_primitive(
+scatter_min_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
     _scatter_translation_rule)
 batching.primitive_batchers[scatter_min_p] = (
     partial(_scatter_batching_rule, scatter_min))

 # TODO(jlebar): Add derivatives.
-scatter_max_p = standard_reduction_primitive(
+scatter_max_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
     _scatter_translation_rule)
 batching.primitive_batchers[scatter_max_p] = (
@@ -3319,7 +3311,7 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   return val_out, tangent_out


-scatter_p = standard_reduction_primitive(
+scatter_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
     _scatter_translation_rule)
 ad.primitive_jvps[scatter_p] = _scatter_jvp
@@ -3330,9 +3322,8 @@ batching.primitive_batchers[scatter_p] = (
 def _reduce_shape_rule(operand, init_value, computation, jaxpr, consts, dimensions):
   return tuple(onp.delete(operand.shape, dimensions))

-def _reduce_translation_rule(c, operand, init_value, computation, jaxpr, consts, dimensions,
-                             backend=None):
-  xla_computation = _reduction_computation(c, jaxpr, backend, consts, init_value)
+def _reduce_translation_rule(c, operand, init_value, computation, jaxpr, consts, dimensions):
+  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
   return c.Reduce(operand, init_value, xla_computation, dimensions)

 def _reduce_batch_rule(batched_args, batch_dims, computation, jaxpr, consts, dimensions):
@@ -3346,13 +3337,13 @@ def _reduce_batch_rule(batched_args, batch_dims, computation, jaxpr, consts, dimensions):
   else:
     raise NotImplementedError  # loop and stack

-def _reduction_computation(c, jaxpr, backend, consts, init_value):
+def _reduction_computation(c, jaxpr, consts, init_value):
   shape = c.GetShape(init_value)
   axis_env = xla.AxisEnv()  # no parallel primitives inside reductions
   subc = xla_bridge.make_computation_builder("reduction_computation")
   consts = [subc.ParameterWithShape(const) for const in consts]
   args = [subc.ParameterWithShape(shape), subc.ParameterWithShape(shape)]
-  out, = xla.jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, (), *args)
+  out, = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, (), *args)
   return subc.Build(out)

 def _masking_defreducer(prim, identity):
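
Passing None for the backend in the jaxpr_subcomp call above leans on the invariant stated in the comment: a reduction body contains no parallel primitives, so nothing downstream ever consults the backend. A toy check expressing that assumption (hypothetical primitive names, not jax code):

    # Hypothetical sanity check for the invariant behind backend=None here:
    # no parallel primitive may appear inside a reduction subcomputation.
    PARALLEL_PRIMS = {"psum", "pmax", "ppermute"}

    def check_no_parallel_prims(prim_names):
        bad = PARALLEL_PRIMS.intersection(prim_names)
        assert not bad, "parallel primitives inside a reduction: %r" % bad

    check_no_parallel_prims({"add", "mul"})  # fine, so backend=None is safe
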
@@ -3374,7 +3365,7 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
   masked_val = select(mask, padded_val, identity(padded_shape, padded_val.dtype))
   return prim.bind(masked_val, axes=axes, input_shape=padded_shape)

-reduce_p = standard_reduction_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
+reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
                               _reduce_translation_rule)
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule

@@ -3419,8 +3410,7 @@ def _reduce_prod_translation_rule(c, operand, axes):
   dtype = c.GetShape(operand).numpy_dtype()
   scalar = ShapedArray((), dtype)
   return c.Reduce(operand, c.Constant(onp.array(1, dtype)),
-                  xla.primitive_subcomputation(mul_p, scalar, scalar),
-                  axes)
+                  xla.primitive_subcomputation(mul_p, scalar, scalar), axes)

 def _reduce_prod_jvp_rule(tangent, operand, axes):
   input_shape = onp.array(operand.shape)
@@ -3528,8 +3518,8 @@ def _reduce_window_shape_rule(operand, init_value, jaxpr, consts,
                               window_strides, padding)

 def _reduce_window_translation_rule(c, operand, init_value, jaxpr, consts,
-                                    window_dimensions, window_strides, padding, backend=None):
-  xla_computation = _reduction_computation(c, jaxpr, backend, consts, init_value)
+                                    window_dimensions, window_strides, padding):
+  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
   return c.ReduceWindow(operand, init_value, xla_computation, window_dimensions,
                         window_strides, padding)

@@ -3550,7 +3540,7 @@ def _generic_reduce_window_batch_rule(
       window_dimensions, window_strides, padding)


-reduce_window_p = standard_reduction_primitive(
+reduce_window_p = standard_primitive(
     _reduce_window_shape_rule, _input_dtype, 'reduce_window',
     _reduce_window_translation_rule)
 batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
@@ -3683,13 +3673,13 @@ def _select_and_scatter_shape_rule(

 def _select_and_scatter_translation(
     c, operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
-    scatter_consts, window_dimensions, window_strides, padding, backend=None):
-  select = _reduction_computation(c, select_jaxpr, backend, select_consts, init_value)
-  scatter = _reduction_computation(c, scatter_jaxpr, backend, scatter_consts, init_value)
+    scatter_consts, window_dimensions, window_strides, padding):
+  select = _reduction_computation(c, select_jaxpr, select_consts, init_value)
+  scatter = _reduction_computation(c, scatter_jaxpr, scatter_consts, init_value)
   return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                             padding, source, init_value, scatter)

-select_and_scatter_p = standard_reduction_primitive(
+select_and_scatter_p = standard_primitive(
     _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
     _select_and_scatter_translation)
@@ -213,7 +213,7 @@ def _while_loop_abstract_eval(*args, **kwargs):
   return kwargs["body_jaxpr"].out_avals

 def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
-  backend = kwargs.pop('backend', None)
+  backend = kwargs.pop('backend')
   cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
       kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
   cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
@@ -147,6 +147,11 @@ _backend_lock = threading.Lock()

 @util.memoize
 def get_backend(platform=None):
+  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
+  # 'backend' values are handled
+  if isinstance(platform, xla_client.Backend):
+    return platform
+
   with _backend_lock:
     backend = _backends.get(FLAGS.jax_xla_backend)
     if backend is None:
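
Per the TODO above, get_backend temporarily accepts either a platform name or an already-constructed backend object, so callers holding a resolved backend can pass it straight through. A toy usage sketch (the real check is isinstance(platform, xla_client.Backend)):

    # Toy sketch of the input polymorphism: pass-through for resolved backends.
    class Backend:
        def __init__(self, platform):
            self.platform = platform

    def get_backend(platform=None):
        if isinstance(platform, Backend):   # already resolved: return as-is
            return platform
        return Backend(platform or "cpu")   # stand-in for the registry lookup

    b = get_backend("gpu")
    assert get_backend(b) is b              # idempotent on Backend instances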