revise handling of 'backend' values

Matthew Johnson, 2019-12-18 11:18:33 -08:00 (committed by Matthew Johnson)
parent 286ec51f61
commit 8bd1a46ce7
4 changed files with 41 additions and 54 deletions

jax/interpreters/xla.py

@@ -172,6 +172,7 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   else:
     all_devices = it.chain(xb.devices(), xb.devices('cpu'))
     device = device and next(d for d in all_devices if (type(d), d.id) == device)
+  backend = xb.get_device_backend(device)
   aval_out = prim.abstract_eval(*avals, **params)
   if prim.multiple_results:
     handlers = tuple(map(aval_to_result_handler, aval_out))
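Note: xb.get_device_backend resolves the backend once, up front, from the concrete device (falling back to the default backend when device is None), and the rest of the function reuses that single value for building, compiling, and executing. Its behavior is approximately the following sketch, not the verbatim implementation:

    def get_device_backend(device=None):
      # Resolve a Backend from a concrete Device, or return the default
      # backend when no device was specified.
      platform = device.platform if device is not None else None
      return get_backend(platform)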
@@ -179,18 +180,17 @@ def xla_primitive_callable(prim, *arg_specs, **params):
   else:
     handle_result = aval_to_result_handler(aval_out)
   tuple_args = len(avals) > 100
-  built_c = primitive_computation(prim, tuple_args, *avals, **params)
+  built_c = primitive_computation(prim, backend, tuple_args, *avals, **params)
   options = xb.get_compile_options(device_assignment=(device.id,) if device else None)
-  compiled = built_c.Compile(compile_options=options,
-                             backend=xb.get_device_backend(device))
-  return partial(_execute_compiled_primitive, prim, compiled, tuple_args,
-                 handle_result)
+  compiled = built_c.Compile(compile_options=options, backend=backend)
+  return partial(_execute_compiled_primitive, prim, compiled, backend,
+                 tuple_args, handle_result)
 
 @cache()
-def primitive_computation(prim, tuple_args, *avals, **params):
+def primitive_computation(prim, backend, tuple_args, *avals, **params):
   c = xb.make_computation_builder("primitive_computation_{}".format(prim.name))
   c.SetOpMetadata(xc.OpMetadata(op_type=prim.name, op_name=str(params)))
-  platform = xb.get_backend(None).platform
+  platform = xb.get_backend(backend).platform
   xla_args = _xla_callable_args(c, avals, tuple_args)
   if prim in backend_specific_translations[platform]:
     rule = backend_specific_translations[platform][prim]
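Because primitive_computation is memoized with @cache(), threading backend through as a positional argument also makes it part of the memoization key: the same primitive lowered for two different backends now yields two distinct cached computations instead of colliding. A minimal sketch of that effect, using functools.lru_cache as a stand-in for the cache() decorator (build_computation is an illustrative name, not the real function):

    from functools import lru_cache

    @lru_cache()
    def build_computation(prim_name, backend):
      # The argument tuple is the cache key, so including `backend`
      # keeps per-backend computations separate.
      return (prim_name, backend)

    build_computation('add', 'cpu')  # builds and caches
    build_computation('add', 'gpu')  # different key, builds again
    build_computation('add', 'cpu')  # cache hit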
@@ -198,12 +198,9 @@ def primitive_computation(prim, tuple_args, *avals, **params):
   elif prim in translations:
     rule = translations[prim]
     rule(c, *xla_args, **params)  # return val set as a side-effect on c
-  elif prim in reduction_translations:
-    rule = reduction_translations[prim]
-    rule(c, *xla_args, **params)  # return val set as a side-effect on c
   elif prim in initial_style_translations:
     rule = initial_style_translations[prim]
-    rule(c, AxisEnv(), *xla_args, **params)  # side-effect on c
+    rule(c, AxisEnv(), *xla_args, backend=backend, **params)  # side-effect on c
   else:
     raise NotImplementedError("XLA translation rule for {} not found".format(prim))
   c.ClearOpMetadata()
@@ -216,14 +213,14 @@ def primitive_computation(prim, tuple_args, *avals, **params):
     raise RuntimeError(msg)
 
 def primitive_subcomputation(prim, *avals, **params):
-  return primitive_computation(prim, False, *avals, **params)
+  return primitive_computation(prim, None, False, *avals, **params)
 
-def _execute_compiled_primitive(prim, compiled, tuple_args,
+def _execute_compiled_primitive(prim, compiled, backend, tuple_args,
                                 result_handler, *args):
   device, = compiled.local_devices()
   input_bufs = [device_put(x, device) for x in args if x is not token]
   if tuple_args:
-    input_bufs = [make_tuple(input_bufs, device, None)]
+    input_bufs = [make_tuple(input_bufs, device, backend)]
   out_buf = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
@@ -296,17 +293,13 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, freevars, *args):
       ans = rule(c, *in_nodes, **eqn.params)
     elif eqn.primitive in translations:
       ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
-    elif eqn.primitive in reduction_translations:
-      new_params = check_backend_params(eqn.params, backend)
-      ans = reduction_translations[eqn.primitive](c, *in_nodes, backend=backend, **new_params)
     elif eqn.primitive in initial_style_translations:
       new_params = check_backend_params(eqn.params, backend)
       rule = initial_style_translations[eqn.primitive]
       ans = rule(c, axis_env, *in_nodes, backend=backend, **new_params)
     elif eqn.primitive in parallel_translations:
-      new_params = check_backend_params(eqn.params, backend)
-      replica_groups = axis_groups(axis_env, new_params['axis_name'])
-      new_params = {k: new_params[k] for k in new_params if k != 'axis_name'}
+      replica_groups = axis_groups(axis_env, eqn.params['axis_name'])
+      new_params = {k: v for k, v in eqn.params.items() if k != 'axis_name'}
       rule = parallel_translations[eqn.primitive]
       ans = rule(c, *in_nodes, replica_groups=replica_groups, **new_params)
     elif eqn.primitive in call_translations:
@@ -431,7 +424,7 @@ def eqn_collectives(eqn):
 
 def _xla_call_impl(fun, *args, **params):
   device = params['device']
-  backend = params.get('backend', None)
+  backend = params['backend']
   compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
   try:
     return compiled_fun(*args)
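Reading params['backend'] instead of params.get('backend', None) makes the key mandatory: a caller that forgets to thread a backend through now fails with a KeyError instead of silently compiling for a default. A tiny illustration of the difference:

    params = {'device': None}  # 'backend' accidentally omitted

    backend = params.get('backend', None)  # old: silently None
    try:
      backend = params['backend']          # new: fails loudly
    except KeyError:
      print("callers must supply 'backend' explicitly")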
@@ -563,7 +556,7 @@ xla_call_p.def_custom_bind(xla_call)
 xla_call_p.def_impl(_xla_call_impl)
 
 def _xla_call_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes,
-                               in_nodes, device=None, backend=None):
+                               in_nodes, backend, device=None):
   del device  # Ignored.
   subc = xb.make_computation_builder("jaxpr_subcomputation")  # TODO(mattjj): name
   consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
@@ -578,7 +571,6 @@ ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
 ### translation tables
 
 translations = {}
-reduction_translations = {}
 parallel_translations = {}
 initial_style_translations = {}
 call_translations = {}
@@ -853,7 +845,7 @@ ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
 
 def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
-                            backend=None, device=None, concrete=None):
+                            backend, device=None, concrete=None):
   # This looks a lot like _xla_call_translation_rule, except for a widget we use
   # to foil CSE.
   del device, concrete  # Unused.

jax/lax/lax.py

@@ -1488,14 +1488,6 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
   return prim
 
-def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
-  prim = Primitive(name)
-  prim.def_impl(partial(xla.apply_primitive, prim))
-  prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
-  xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
-  return prim
-
 def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
   assert all(isinstance(arg, UnshapedArray) for arg in args), args
   least_specialized = _max(
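With xla.reduction_translations gone, nothing distinguishes the deleted helper from standard_primitive except the table it writes to, so the reduction-style primitives below are re-registered through the ordinary helper. For contrast, the surviving registration path looks approximately like this (reconstructed from the deleted helper above, which differed only in the final assignment):

    def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
      prim = Primitive(name)
      prim.def_impl(partial(xla.apply_primitive, prim))
      prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
      # Only difference: rules land in xla.translations rather than the
      # removed xla.reduction_translations table.
      xla.translations[prim] = translation_rule or partial(standard_translate, name)
      return prim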
@@ -3100,11 +3092,11 @@ def _scatter_shape_rule(operand, scatter_indices, updates, **kwargs):
 
 def _scatter_translation_rule(c, operand, scatter_indices, updates,
                               update_jaxpr, update_consts, dimension_numbers,
-                              updates_shape, backend=None):
+                              updates_shape):
   dtype = c.GetShape(operand).numpy_dtype()
   init_value = c.Constant(onp.array(0, dtype))
   update_computation = _reduction_computation(
-      c, update_jaxpr, backend, update_consts, init_value)
+      c, update_jaxpr, update_consts, init_value)
   indices_shape = c.GetShape(scatter_indices)
   return c.Scatter(operand, scatter_indices, updates, update_computation,
                    _scatter_dimensions_proto(indices_shape, dimension_numbers))
@@ -3203,7 +3195,7 @@ def _scatter_batching_rule(
       scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
   return scatter_op(operand, scatter_indices, updates, dnums), 0
 
-scatter_add_p = standard_reduction_primitive(
+scatter_add_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
     _scatter_translation_rule)
 ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
@@ -3212,14 +3204,14 @@ batching.primitive_batchers[scatter_add_p] = (
     partial(_scatter_batching_rule, scatter_add))
 
 # TODO(jlebar): Add derivatives.
-scatter_min_p = standard_reduction_primitive(
+scatter_min_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
     _scatter_translation_rule)
 batching.primitive_batchers[scatter_min_p] = (
     partial(_scatter_batching_rule, scatter_min))
 
 # TODO(jlebar): Add derivatives.
-scatter_max_p = standard_reduction_primitive(
+scatter_max_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
     _scatter_translation_rule)
 batching.primitive_batchers[scatter_max_p] = (
@@ -3319,7 +3311,7 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
   return val_out, tangent_out
 
-scatter_p = standard_reduction_primitive(
+scatter_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
    _scatter_translation_rule)
 ad.primitive_jvps[scatter_p] = _scatter_jvp
@@ -3330,9 +3322,8 @@ batching.primitive_batchers[scatter_p] = (
 
 def _reduce_shape_rule(operand, init_value, computation, jaxpr, consts, dimensions):
   return tuple(onp.delete(operand.shape, dimensions))
 
-def _reduce_translation_rule(c, operand, init_value, computation, jaxpr, consts, dimensions,
-                             backend=None):
-  xla_computation = _reduction_computation(c, jaxpr, backend, consts, init_value)
+def _reduce_translation_rule(c, operand, init_value, computation, jaxpr, consts, dimensions):
+  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
   return c.Reduce(operand, init_value, xla_computation, dimensions)
 
@@ -3346,13 +3337,13 @@ def _reduce_batch_rule(batched_args, batch_dims, computation, jaxpr, consts, dim
   else:
     raise NotImplementedError  # loop and stack
 
-def _reduction_computation(c, jaxpr, backend, consts, init_value):
+def _reduction_computation(c, jaxpr, consts, init_value):
   shape = c.GetShape(init_value)
   axis_env = xla.AxisEnv()  # no parallel primitives inside reductions
   subc = xla_bridge.make_computation_builder("reduction_computation")
   consts = [subc.ParameterWithShape(const) for const in consts]
   args = [subc.ParameterWithShape(shape), subc.ParameterWithShape(shape)]
-  out, = xla.jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, (), *args)
+  out, = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, (), *args)
   return subc.Build(out)
 
 def _masking_defreducer(prim, identity):
@@ -3374,7 +3365,7 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
   masked_val = select(mask, padded_val, identity(padded_shape, padded_val.dtype))
   return prim.bind(masked_val, axes=axes, input_shape=padded_shape)
 
-reduce_p = standard_reduction_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
+reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
                               _reduce_translation_rule)
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
@@ -3419,8 +3410,7 @@ def _reduce_prod_translation_rule(c, operand, axes):
   dtype = c.GetShape(operand).numpy_dtype()
   scalar = ShapedArray((), dtype)
   return c.Reduce(operand, c.Constant(onp.array(1, dtype)),
-                  xla.primitive_subcomputation(mul_p, scalar, scalar),
-                  axes)
+                  xla.primitive_subcomputation(mul_p, scalar, scalar), axes)
 
 def _reduce_prod_jvp_rule(tangent, operand, axes):
   input_shape = onp.array(operand.shape)
@@ -3528,8 +3518,8 @@ def _reduce_window_shape_rule(operand, init_value, jaxpr, consts,
                               window_strides, padding)
 
 def _reduce_window_translation_rule(c, operand, init_value, jaxpr, consts,
-                                    window_dimensions, window_strides, padding, backend=None):
-  xla_computation = _reduction_computation(c, jaxpr, backend, consts, init_value)
+                                    window_dimensions, window_strides, padding):
+  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
   return c.ReduceWindow(operand, init_value, xla_computation, window_dimensions,
                         window_strides, padding)
@@ -3550,7 +3540,7 @@ def _generic_reduce_window_batch_rule(
       window_dimensions, window_strides, padding)
 
-reduce_window_p = standard_reduction_primitive(
+reduce_window_p = standard_primitive(
     _reduce_window_shape_rule, _input_dtype, 'reduce_window',
     _reduce_window_translation_rule)
 batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
@@ -3683,13 +3673,13 @@ def _select_and_scatter_shape_rule(
 
 def _select_and_scatter_translation(
   c, operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
-  scatter_consts, window_dimensions, window_strides, padding, backend=None):
-  select = _reduction_computation(c, select_jaxpr, backend, select_consts, init_value)
-  scatter = _reduction_computation(c, scatter_jaxpr, backend, scatter_consts, init_value)
+  scatter_consts, window_dimensions, window_strides, padding):
+  select = _reduction_computation(c, select_jaxpr, select_consts, init_value)
+  scatter = _reduction_computation(c, scatter_jaxpr, scatter_consts, init_value)
   return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
                             padding, source, init_value, scatter)
 
-select_and_scatter_p = standard_reduction_primitive(
+select_and_scatter_p = standard_primitive(
     _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
     _select_and_scatter_translation)

jax/lax/lax_control_flow.py

@@ -213,7 +213,7 @@ def _while_loop_abstract_eval(*args, **kwargs):
   return kwargs["body_jaxpr"].out_avals
 
 def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
-  backend = kwargs.pop('backend', None)
+  backend = kwargs.pop('backend')
   cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
       kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
   cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])

jax/lib/xla_bridge.py

@@ -147,6 +147,11 @@ _backend_lock = threading.Lock()
 
 @util.memoize
 def get_backend(platform=None):
+  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
+  # 'backend' values are handled
+  if isinstance(platform, xla_client.Backend):
+    return platform
+
   with _backend_lock:
     backend = _backends.get(FLAGS.jax_xla_backend)
     if backend is None:
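The new early return makes get_backend polymorphic in its input: it accepts either a platform name or an already-constructed backend object, returning the latter untouched (hence the TODO about cleaning this up). A self-contained sketch of the pattern, with a stub class standing in for xla_client.Backend:

    class Backend:  # stub for xla_client.Backend
      def __init__(self, platform):
        self.platform = platform

    _backends = {'cpu': Backend('cpu'), 'gpu': Backend('gpu')}

    def get_backend(platform=None):
      # Input polymorphism: a live Backend passes straight through, so
      # callers can hand around either names or backend objects.
      if isinstance(platform, Backend):
        return platform
      return _backends[platform or 'cpu']

    b = get_backend('gpu')
    assert get_backend(b) is b  # idempotent on Backend instances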