diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb
index 1979d3ceb..e949e8205 100644
--- a/docs/autodidax.ipynb
+++ b/docs/autodidax.ipynb
@@ -2007,7 +2007,7 @@
     "  return [xla_consts[id(cnst)] for cnst in consts]\n",
     "\n",
     "def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n",
-    "  return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
+    "  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
     "\n",
     "def _xla_shape(aval: ShapedArray) -> xe.Shape:\n",
     "  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)"
@@ -3630,7 +3630,7 @@
     "\n",
     "  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
     "    c = xc.XlaBuilder(name)\n",
-    "    operand = xb.parameter(c, 0, operand_shape)\n",
+    "    operand = xops.Parameter(c, 0, operand_shape)\n",
     "    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
     "    outs = jaxpr_subcomp(c, jaxpr, operands)\n",
     "    return c.build(xops.Tuple(c, outs))\n",
diff --git a/docs/autodidax.md b/docs/autodidax.md
index e07841e76..b6412003b 100644
--- a/docs/autodidax.md
+++ b/docs/autodidax.md
@@ -1577,7 +1577,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
   return [xla_consts[id(cnst)] for cnst in consts]

 def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
-  return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
+  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

 def _xla_shape(aval: ShapedArray) -> xe.Shape:
   return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
@@ -2844,7 +2844,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):

   def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
     c = xc.XlaBuilder(name)
-    operand = xb.parameter(c, 0, operand_shape)
+    operand = xops.Parameter(c, 0, operand_shape)
     operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
     outs = jaxpr_subcomp(c, jaxpr, operands)
     return c.build(xops.Tuple(c, outs))
diff --git a/docs/autodidax.py b/docs/autodidax.py
index 155d52d37..88d25acb5 100644
--- a/docs/autodidax.py
+++ b/docs/autodidax.py
@@ -1569,7 +1569,7 @@ def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:
   return [xla_consts[id(cnst)] for cnst in consts]

 def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:
-  return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
+  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

 def _xla_shape(aval: ShapedArray) -> xe.Shape:
   return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
@@ -2836,7 +2836,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):

   def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
     c = xc.XlaBuilder(name)
-    operand = xb.parameter(c, 0, operand_shape)
+    operand = xops.Parameter(c, 0, operand_shape)
     operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
     outs = jaxpr_subcomp(c, jaxpr, operands)
     return c.build(xops.Tuple(c, outs))
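Note: the autodidax text drops its xb.parameter wrapper in favor of calling
xops.Parameter directly. A minimal, self-contained sketch of that pattern
(illustrative only; it assumes a jaxlib where xla_client.ops and
dtype_to_etype are available, which autodidax already relies on):

    import numpy as np
    from jax._src.lib import xla_client as xc

    xops = xc.ops
    c = xc.XlaBuilder("example")
    shape = xc.Shape.array_shape(xc.dtype_to_etype(np.dtype('float32')), (3,))
    x = xops.Parameter(c, 0, shape)    # parameter number 0, explicit shape
    out = xops.Add(x, x)               # trivial body: f(x) = x + x
    print(c.build(out).as_hlo_text())  # HLO text of the built computation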
diff --git a/jax/_src/api.py b/jax/_src/api.py
index d4f1fc471..8c59ebd78 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -860,7 +860,7 @@ def xla_computation(fun: Callable,
       out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
       build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
       if out_parts is not None:
-        out_tuple = xb.with_sharding(c, out_parts_flat, build_out_tuple)
+        out_tuple = xla.with_sharding(c, out_parts_flat, build_out_tuple)
       else:
         out_tuple = build_out_tuple()
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 3d4c5b858..c121b4f77 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -218,42 +218,16 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   tuple_args = len(abstract_args) > 100
   axis_env = xla.AxisEnv(nreps, (), ())
   name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
-  module: Any
+  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+  module: Union[str, xc.XlaComputation]
   if config.jax_enable_mlir:
-    # TODO(b/203122001): implement buffer donation.
-    assert not any(donated_invars), donated_invars
     module = mlir.lower_jaxpr_to_module(
-        core.ClosedJaxpr(jaxpr, consts), backend.platform, axis_env, name_stack)
+        closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
   else:
-    # XLA HLO lowering path
-    c = xc.XlaBuilder(f"jit_{fun.__name__}")
-    xla_consts = xla._xla_consts(c, consts)
-    xla_args, donated_invars = xla._xla_callable_args(
-        c, abstract_args, tuple_args, donated_invars=donated_invars)
-    platform = backend.platform
-    ctx = xla.TranslationContext(c, platform, axis_env, name_stack)
-    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-    # Replace tokens with a dummy array value, because the runtime cannot
-    # handle token arguments.
-    out_aval_lens = [len(xla.aval_to_xla_shapes(a)) for a in out_avals]
-    out_nodes = util.flatten(
-        [[xla._make_token_return_value(c)] if a is core.abstract_token
-         else v
-         for a, v in zip(out_avals, util.unflatten(out_nodes, out_aval_lens))])
-
-    # There is a non-zero cost to building an output tuple, particularly on TPU.
-    # Avoid it if the output arity is 1.
-    output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
-    if platform in ("gpu", "tpu"):
-      donated_invars = xla.set_up_aliases(
-          c, xla_args, c.GetShape(output), donated_invars, tuple_args)
-    if any(donated_invars):
-      # TODO(tomhennigan): At call time we should mark these buffers as deleted.
-      unused_donations = [str(c.GetShape(a))
-                          for a, d in zip(xla_args, donated_invars) if d]
-      warnings.warn("Some donated buffers were not usable: {}".format(
-          ", ".join(unused_donations)))
-    module = c.build(output)
+    module = xla.lower_jaxpr_to_xla_module(
+        f"jit_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+        name_stack, tuple_args, donated_invars, replicated_args=None,
+        arg_partitions=None, out_partitions=None)
   return XlaComputation(
       name, module, False, donated_invars, nreps, device, backend, tuple_args,
       abstract_args, out_avals, kept_var_idx)
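Note: after this change both lowering paths consume the same ClosedJaxpr. For
orientation, a ClosedJaxpr like the one handed to either branch can be built
and inspected directly (illustrative):

    import jax
    import jax.numpy as jnp

    closed_jaxpr = jax.make_jaxpr(lambda x: (x * 2.0).sum())(jnp.arange(4.0))
    print(closed_jaxpr)           # the jaxpr both backends consume
    print(closed_jaxpr.in_avals)  # abstract values that become XLA parameters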
diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index f32aa7d67..781bcd4f2 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -48,7 +48,6 @@ from jax.interpreters import batching
 from jax.interpreters import masking
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
@@ -342,7 +341,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
   init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

   cond_c = xla_client.XlaBuilder("cond_computation")
-  cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
+  cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
   cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
   x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
   cond_ctx = ctx.replace(builder=cond_c,
@@ -359,7 +358,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
                       or_, list(range(cond_jaxpr.out_avals[0].ndim)))

   body_c = xla_client.XlaBuilder("body_computation")
-  body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
+  body_carry = xla.parameter(body_c, 0, c.get_shape(init_carry))
   body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
   x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
   body_ctx = ctx.replace(builder=body_c,
@@ -931,7 +930,7 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
   name_stack = extend_name_stack(ctx.name_stack, "cond")
   def make_computation(name, jaxpr, op_shape):
     c = xla_client.XlaBuilder(name + '_comp')
-    op = xb.parameter(c, 0, op_shape)
+    op = xla.parameter(c, 0, op_shape)
     ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
     subctx = ctx.replace(
         builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))
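Note: the control-flow rules above build each cond branch and while-loop
cond/body as a separate subcomputation whose single parameter is a tuple of
operands. A minimal sketch of that pattern with the relocated xla.parameter
helper (illustrative; the real rules thread a TranslationContext through
jaxpr_subcomp rather than emitting ops by hand):

    import numpy as np
    from jax._src.lib import xla_client
    from jax.interpreters import xla

    xops = xla_client.ops
    op_shape = xla_client.Shape.tuple_shape(
        [xla_client.Shape.array_shape(np.dtype(np.float32), ())])
    c = xla_client.XlaBuilder("true_comp")
    op = xla.parameter(c, 0, op_shape)                 # one tuple-typed parameter
    x = xops.GetTupleElement(op, 0)
    true_comp = c.build(xops.Tuple(c, [xops.Neg(x)]))  # outputs are re-tupled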
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 8960f79fb..125125670 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3224,7 +3224,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
   axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
   subc = xc.XlaBuilder("reduction_computation")
   assert len(consts) == 0, "Reduction computations cannot have constants"
-  args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
+  args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
   ctx = xla.TranslationContext(subc, platform, axis_env, '')
   out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
   if singleton:
@@ -3684,7 +3684,7 @@ def _sort_translation_rule(ctx, avals_in, avals_out, *operands, dimension,
   c = ctx.builder
   types = [c.get_shape(x).xla_element_type() for x in operands]
   subc = xc.XlaBuilder("sort_lt_comparator")
-  params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
+  params = [xla.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
             for i, typ in enumerate(types) for j in range(2)]
   result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
                          backend=ctx.platform,
@@ -3918,7 +3918,7 @@ def _infeed_translation_rule(ctx, avals_in, avals_out, token, *, shapes,
   build_infeed = partial(xops.InfeedWithToken, token,
                          xla_client.Shape.tuple_shape(shape))
   if partitions:
-    xs_and_token = xb.with_sharding(c, partitions, build_infeed)
+    xs_and_token = xla.with_sharding(c, partitions, build_infeed)
   else:
     # Note that infeed will default to replication if inside a sharded
     # computation and no sharding is specified.
@@ -3986,8 +3986,8 @@ def _outfeed_translation_rule(ctx, avals_in, avals_out, token, *xs, partitions):
   c = ctx.builder
   t = xops.Tuple(c, xs)
   if partitions is not None:
-    return [xb.with_sharding(c, partitions, xops.OutfeedWithToken,
-                             t, token, c.get_shape(t))]
+    return [xla.with_sharding(c, partitions, xops.OutfeedWithToken,
+                              t, token, c.get_shape(t))]
   else:
     return [xops.OutfeedWithToken(t, token, c.get_shape(t))]
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index f59d284ee..46a090634 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1487,7 +1487,7 @@ def _scatter_add_translation_rule(
   def _make_reducer(dtype):
     subc = xc.XlaBuilder("scatter_add_reducer")
     shape = xc.Shape.array_shape(np.dtype(dtype), ())
-    args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)]
+    args = [xla.parameter(subc, 0, shape), xla.parameter(subc, 1, shape)]
     out = xops.Add(args[0], args[1])
     return subc.build(out)
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 79440d944..e01d31db5 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -806,9 +806,9 @@ def _select_and_gather_add_translation(

   def reducer():
     c = xc.XlaBuilder("select_and_gather_pair_reducer")
-    x = xb.parameter(c, 0,
+    x = xla.parameter(c, 0,
       xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
-    y = xb.parameter(c, 1,
+    y = xla.parameter(c, 1,
       xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
     assert select_prim is lax.ge_p or select_prim is lax.le_p
     which = xops.Ge if select_prim is lax.ge_p else xops.Le
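Note: the reducer subcomputations in these rules share one pattern: a fresh
builder, scalar parameters created with xla.parameter, and a single combining
op. A standalone sketch of the scatter-add reducer (illustrative):

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    xops = xc.ops
    subc = xc.XlaBuilder("scatter_add_reducer")
    shape = xc.Shape.array_shape(np.dtype(np.float32), ())  # scalar f32
    args = [xla.parameter(subc, 0, shape), xla.parameter(subc, 1, shape)]
    reducer = subc.build(xops.Add(args[0], args[1]))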
diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py
index 8834d5bce..5011932cf 100644
--- a/jax/_src/lib/xla_bridge.py
+++ b/jax/_src/lib/xla_bridge.py
@@ -23,7 +23,7 @@ XLA. There are also a handful of related casting utilities.
 from functools import partial, lru_cache
 import os
 import threading
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 import warnings

 from absl import logging
@@ -435,81 +435,3 @@ def host_ids(backend=None):
       "instead. jax.host_ids will eventually be removed; please update your "
       "code.")
   return list(range(process_count(backend)))
-
-
-### utility functions
-
-def parameter(builder, num, shape, name=None, replicated=None):
-  if name is None:
-    name = ''
-  if replicated is None:
-    replicated = []
-  elif isinstance(replicated, bool):
-    replicated = [replicated] * shape.leaf_count()
-
-  return xops.Parameter(builder, num,
-                        shape.with_major_to_minor_layout_if_absent(), name,
-                        replicated)
-
-# HLO instructions optionally can be annotated to say how the output should be
-# spatially partitioned (represented in XLA as OpSharding protos, see
-# _sharding_to_proto). For array outputs, the annotation is either an int per
-# dimension specifying the number of ways that dimension divided (i.e. the total
-# number of shards is the product), or None to indicate the array should be
-# replicated. Tuple outputs are represented as tuples thereof. XLA supports
-# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
-# checkers don't support recursive types), so we only represent one level of
-# nesting in this type definition.
-SpatialSharding = Union[Tuple[int, ...],
-                        None,
-                        Tuple[Union[Tuple[int, ...], None], ...]]
-
-def _sharding_to_proto(sharding: SpatialSharding):
-  """Converts a SpatialSharding to an OpSharding.
-
-  See
-  https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
-  for details on the OpSharding proto.
-  """
-  proto = xla_client.OpSharding()
-  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
-    assert all(s is None or isinstance(s, tuple) for s in sharding)
-    return tuple_sharding_proto(list(map(_sharding_to_proto, sharding)))  # type: ignore
-
-  if sharding is None:
-    proto.type = xla_client.OpSharding.Type.REPLICATED
-  else:
-    proto.type = xla_client.OpSharding.Type.OTHER
-    proto.tile_assignment_dimensions = list(sharding)
-    proto.tile_assignment_devices = list(range(np.product(sharding)))
-  return proto
-
-def tuple_sharding_proto(elems):
-  proto = xla_client.OpSharding()
-  assert all(isinstance(e, type(proto)) for e in elems)
-  proto.type = xla_client.OpSharding.Type.TUPLE
-  proto.tuple_shardings = elems
-  return proto
-
-def set_sharding_proto(builder, op, sharding_proto):
-  """Uses CustomCall to annotate a value as sharded."""
-  # "Sharding" is a built-in custom call target that acts like an identity
-  # function, and is used to attach an OpSharding to.
-  return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
-                             builder, b"Sharding", [op], builder.get_shape(op))
-
-def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
-  """Builds op_fn(*args, **kwargs) with sharding annotation."""
-  builder.set_sharding(sharding_proto)
-  try:
-    return op_fn(*args, **kwargs)
-  finally:
-    builder.clear_sharding()
-
-def set_sharding(builder, op, sharding: SpatialSharding):
-  """Uses CustomCall to annotate a value as sharded."""
-  return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
-
-def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
-  """Builds op_fn(*args, **kwargs) with sharding annotation."""
-  return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
diff --git a/jax/experimental/ann.py b/jax/experimental/ann.py
index 419c0eadb..474dead82 100644
--- a/jax/experimental/ann.py
+++ b/jax/experimental/ann.py
@@ -63,7 +63,6 @@ from typing import (Any, Tuple)
 import numpy as np

 from jax import lax, core
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src import ad_util, dtypes

@@ -167,10 +166,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
 def _comparator_builder(operand, op_type, is_max_k):
   c = xc.XlaBuilder(
       'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
-  p0 = xb.parameter(c, 0, xc.Shape.scalar_shape(op_type))
-  p1 = xb.parameter(c, 1, xc.Shape.scalar_shape(op_type))
-  xb.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
-  xb.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
+  p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
+  p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type))
+  xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
+  xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
   if is_max_k:
     cmp_result = xc.ops.Gt(p0, p1)
   else:
diff --git a/jax/experimental/djax.py b/jax/experimental/djax.py
index 6ca12b08f..fe88991d8 100644
--- a/jax/experimental/djax.py
+++ b/jax/experimental/djax.py
@@ -640,7 +640,7 @@ xla.canonicalize_dtype_handlers[BoundedInt] = _bdint_canoncalize_dtype

 def _make_params(c, dim_in_avals, in_avals):
   n = it.count()
-  make = lambda a: [xb.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
+  make = lambda a: [xla.parameter(c, next(n), s) for s in xla.aval_to_xla_shapes(a)]
   return map(make, dim_in_avals), map(make, in_avals)

 def _xla_consts(c, consts):
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 9803c27f2..f5d8c92db 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -946,7 +946,7 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
     token_sharding_proto = xla_client.OpSharding()
     token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
-    infeed_sharding_proto = xb.tuple_sharding_proto(
+    infeed_sharding_proto = xla.tuple_sharding_proto(
         [array_sharding_proto] * len(non_empty_flat_results_aval) +
         [token_sharding_proto])
@@ -959,8 +959,8 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
     build_infeed = functools.partial(xops.InfeedWithToken,
                                      after_outfeed_itoken,
                                      xla_client.Shape.tuple_shape(shape))
-    outs_and_token = xb.with_sharding_proto(comp, infeed_sharding_proto,
-                                            build_infeed)
+    outs_and_token = xla.with_sharding_proto(comp, infeed_sharding_proto,
+                                             build_infeed)
     outs = xops.GetTupleElement(outs_and_token, 0)
     next_itoken = xops.GetTupleElement(outs_and_token, 1)
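Note: host_callback now takes tuple_sharding_proto and with_sharding_proto
from jax.interpreters.xla. A sketch of what the tuple proto it builds looks
like (illustrative; two replicated elements stand in for the result array and
the token):

    from jax._src.lib import xla_client
    from jax.interpreters import xla

    replicated = xla_client.OpSharding()
    replicated.type = xla_client.OpSharding.Type.REPLICATED
    infeed_sharding = xla.tuple_sharding_proto([replicated, replicated])
    assert infeed_sharding.type == xla_client.OpSharding.Type.TUPLE
    # xla.with_sharding_proto(comp, infeed_sharding, build_infeed) would then
    # set this proto on the builder, build the infeed op, and clear it again.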
     non_empty_results = [
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index a67186bb1..5795a3de4 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -1464,7 +1464,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
   global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
   sharded_global_in_nodes = [
-      xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
+      xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
       if aval_axes else node
       for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
@@ -1478,7 +1478,7 @@ def _xmap_translation_rule_spmd(c, axis_env,
       *sharded_global_in_nodes)

   sharded_global_out_nodes = [
-      xb.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
+      xla.set_sharding_proto(c, node, global_sharding_spec(aval, aval_axes).sharding_proto())
       if aval_axes else node
       for node, aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
   ]
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index f1cb94bde..a806954da 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -37,7 +37,6 @@ from jax.interpreters import xla
 from jax.interpreters import batching
 from jax.interpreters import partial_eval as pe
 from jax.interpreters.sharded_jit import PartitionSpec
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
 from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
@@ -514,8 +513,8 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
   for i, (n, axis_resources) in enumerate(safe_zip(in_nodes, in_axis_resources)):
     # N.B. inlined calls shouldn't have shardings set directly on the inputs or
     # outputs (set_sharding_proto adds an identity operation).
-    arg = xb.parameter(subc, i, c.GetShape(n))
-    args.append(xb.set_sharding_proto(subc, arg,
+    arg = xla.parameter(subc, i, c.GetShape(n))
+    args.append(xla.set_sharding_proto(subc, arg,
                 get_sharding_proto(c, n, axis_resources, mesh)))

   # TODO: Think about how to avoid duplicating constants with the outer jaxpr
@@ -525,7 +524,7 @@ def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
   out_nodes = xla.jaxpr_subcomp(
       ctx, jaxpr.jaxpr, xla._xla_consts(subc, jaxpr.consts), *args)
   out_nodes = [
-      xb.set_sharding_proto(subc, out,
+      xla.set_sharding_proto(subc, out,
                             get_sharding_proto(subc, out, axis_resources, mesh))
      for out, axis_resources in safe_zip(out_nodes, out_axis_resources)
   ]
@@ -815,7 +814,7 @@ ad.deflinear2(sharding_constraint_p,
 def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
                                           axis_resources, resource_env):
   mesh = resource_env.physical_mesh
-  return [xb.set_sharding_proto(
+  return [xla.set_sharding_proto(
       ctx.builder, x_node,
       get_sharding_proto(ctx.builder, x_node, axis_resources, mesh))]
 xla.register_translation(sharding_constraint_p,
                          _sharding_constraint_translation_rule)
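Note: with donated_invars now threaded into the MLIR path (see mlir.py below),
the "donated buffers were not usable" warning fires on both lowering paths. A
minimal reproducer, assuming a backend without input/output aliasing such as
CPU (illustrative):

    import warnings
    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: x.sum(), donate_argnums=(0,))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      f(jnp.arange(4.0))  # compiles, donation cannot be honored on CPU
    print(w[-1].message if w else "no warning")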
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 6bbf69171..537c189fb 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -23,6 +23,7 @@ import typing
 from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, Union,
                     Tuple)
 from typing_extensions import Protocol
+import warnings

 from jax import core
 from jax import linear_util as lu
@@ -319,11 +320,22 @@ def flatten_lowering_ir_args(
   return util.flatten(map(wrap_singleton_ir_values, xs))

 def lower_jaxpr_to_module(jaxpr: core.ClosedJaxpr, platform: str,
-                          axis_env: xla.AxisEnv, name_stack: str) -> str:
+                          axis_env: xla.AxisEnv, name_stack: str,
+                          donated_invars: Sequence[bool]) -> str:
   """Lowers a top-level jaxpr to an MHLO module.

   Handles the quirks of the argument/return value passing conventions of the
   runtime."""
+  if platform in ("gpu", "tpu"):
+    # TODO(b/203122001): implement buffer donation.
+    assert not any(donated_invars), donated_invars
+  if any(donated_invars):
+    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
+    unused_donations = [str(a) for a, d in zip(jaxpr.in_avals, donated_invars)
+                        if d]
+    warnings.warn("Some donated buffers were not usable: {}".format(
+        ", ".join(unused_donations)))
+
   ctx = LoweringContext(platform, axis_env, name_stack)
   if platform == "iree":
     ctx = ctx.replace(tuple_results=False)
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 1df788d9f..c10d677c5 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -995,31 +995,23 @@ def lower_parallel_callable(
   axis_env = xla.AxisEnv(
       replicas.num_global_replicas, (axis_name,), (global_axis_size,))
-
-  c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
-  xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
+  name_stack = extend_name_stack(wrap_name(name, 'pmap'))
+  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   replicated_args = [axis is None for axis in in_axes]
-  xla_args, donated_invars = xla._xla_callable_args(
-      c, shards.global_sharded_avals, tuple_args(shards),
-      replicated=replicated_args,
-      partitions=parts.arg_parts,
-      donated_invars=donated_invars)
+  module: Union[str, xc.XlaComputation]
   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
-    ctx = xla.TranslationContext(c, backend.platform, axis_env,
-                                 extend_name_stack(wrap_name(name, 'pmap')))
-    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-  build_out_tuple = partial(xops.Tuple, c, out_nodes)
-  if parts.out_parts is not None:
-    out_tuple = xb.with_sharding(c, parts.out_parts, build_out_tuple)
-  else:
-    out_tuple = build_out_tuple()
+    if config.jax_enable_mlir:
+      # TODO(phawkins): handle replicated_args.
+      # TODO(phawkins): handle sharding.
+      module = mlir.lower_jaxpr_to_module(
+          closed_jaxpr, backend.platform, axis_env, name_stack, donated_invars)
+    else:
+      module = xla.lower_jaxpr_to_xla_module(
+          f"pmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+          name_stack, tuple_args(shards), donated_invars, replicated_args,
+          parts.arg_parts, parts.out_parts)
-  if backend.platform in ("gpu", "tpu"):
-    donated_invars = xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple),
-                                        donated_invars, tuple_args(shards))
-  built = c.Build(out_tuple)
-
-  return PmapComputation(built, pci, replicas, parts, shards)
+  return PmapComputation(module, pci, replicas, parts, shards)

 class PmapComputation:
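Note: replicated_args, now forwarded to lower_jaxpr_to_xla_module, marks the
pmap inputs that are broadcast to every replica (an in_axes entry of None)
rather than split along a mapped axis. A worked example of the line retained
above:

    in_axes = (0, None, 0)
    replicated_args = [axis is None for axis in in_axes]
    assert replicated_args == [False, True, False]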
@@ -1917,8 +1909,6 @@ def lower_mesh_computation(
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

   # 3. Build up the HLO
-  c = xc.XlaBuilder(f"xmap_{fun.__name__}")
-  xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
   tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
   in_partitions: Optional[List]
   if spmd_lowering:
@@ -1929,40 +1919,29 @@ def lower_mesh_computation(
                      for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
     out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
                       for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
+    out_partitions_t = xla.tuple_sharding_proto(out_partitions)
     partitions_proto = True
     axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())  # All named axes have been vmapped
   else:
     replicated_args = [not axis for axis in in_axes]
     in_partitions = None
+    out_partitions_t = None
     partitions_proto = False
     axis_env = xla.AxisEnv(nreps=mesh.size,
                            names=tuple(global_axis_sizes.keys()),
                            sizes=tuple(global_axis_sizes.values()))
-  xla_args, donated_invars = xla._xla_callable_args(
-      c, in_jaxpr_avals, tuple_args,
-      replicated=replicated_args,
-      partitions=in_partitions,
-      partitions_proto=partitions_proto,
-      donated_invars=donated_invars)
+  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+  name_stack = extend_name_stack(wrap_name(transformed_name, 'xmap'))
   with core.extend_axis_env_nd(mesh.shape.items()):
-    ctx = xla.TranslationContext(
-        c, backend.platform, axis_env,
-        extend_name_stack(wrap_name(transformed_name, 'xmap')))
-    out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-  if spmd_lowering:
-    out_partitions_t = xb.tuple_sharding_proto(out_partitions)
-    out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
-  else:
-    out_tuple = xops.Tuple(c, out_nodes)
+    # TODO(phawkins): add MLIR lowering.
+    module = xla.lower_jaxpr_to_xla_module(
+        f"xmap_{fun.__name__}", closed_jaxpr, backend.platform, axis_env,
+        name_stack, tuple_args, donated_invars, replicated_args,
+        in_partitions, out_partitions_t,
+        partitions_are_protos=partitions_proto)
-  if backend.platform in ("gpu", "tpu"):
-    xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple), donated_invars,
-                       tuple_args)
-  # TODO: Warn about unused donations?
-
-  built = c.Build(out_tuple)
   return MeshComputation(
-      built, donated_invars, mesh, local_in_untiled_avals,
+      module, donated_invars, mesh, local_in_untiled_avals,
       local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
       in_axes, out_axes, spmd_lowering, tuple_args)
diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py
index ffc74450d..e690e9dfb 100644
--- a/jax/interpreters/sharded_jit.py
+++ b/jax/interpreters/sharded_jit.py
@@ -147,7 +147,7 @@ def _sharded_callable(
   ctx = xla.TranslationContext(
       c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
   out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
-  out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
+  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
   built = c.Build(out_tuple)

   if nparts <= xb.local_device_count():
@@ -191,10 +191,10 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,

   args = []
   for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
-    # We use xb.set_sharding instead of xb.with_sharding because inlined calls
+    # We use xla.set_sharding instead of xla.with_sharding because inlined calls
     # shouldn't have shardings set directly on the inputs or outputs.
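Note: the comment in _sharded_jit_translation_rule distinguishes the two
annotation helpers: with_sharding builds a new op while a sharding is set on
the builder, whereas set_sharding wraps an existing op in an identity
"Sharding" custom call. A sketch of both (illustrative shardings; no devices
are needed just to build the annotated HLO):

    import numpy as np
    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    c = xc.XlaBuilder("sharding_example")
    shape = xc.Shape.array_shape(np.dtype(np.float32), (8,))
    p = xla.with_sharding(c, (2,), xla.parameter, c, 0, shape)  # annotate at build time
    q = xla.set_sharding(c, p, (2,))  # annotate after the fact via CustomCall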
-    arg = xb.parameter(subc, i, c.GetShape(n))
-    args.append(xb.set_sharding(subc, arg, sharding))
+    arg = xla.parameter(subc, i, c.GetShape(n))
+    args.append(xla.set_sharding(subc, arg, sharding))

   ctx = xla.TranslationContext(
       subc, backend, axis_env,
@@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(c, axis_env, in_nodes, name_stack,
   out_nodes = xla.jaxpr_subcomp(ctx, call_jaxpr, (), *args)
   out_parts = out_parts_thunk()
   assert len(out_parts) == len(out_nodes)
-  out_nodes = [xb.set_sharding(subc, out, sharding)
+  out_nodes = [xla.set_sharding(subc, out, sharding)
                for out, sharding in safe_zip(out_nodes, out_parts)]
   subc = subc.build(xops.Tuple(subc, out_nodes))
@@ -218,7 +218,7 @@ def _execute_spatially_partitioned(compiled, in_handler, out_handler, *args):
 def _xla_sharded_args(c, avals, in_parts):
   xla_args = []
   for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
-    param = xb.with_sharding(c, sharding, xb.parameter, c, i,
+    param = xla.with_sharding(c, sharding, xla.parameter, c, i,
                              *xla.aval_to_xla_shapes(aval))
     xla_args.append(param)
   return xla_args
@@ -413,7 +413,7 @@ def _sharding_constraint_impl(x, partitions):
 def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
                                           partitions):
-  return [xb.set_sharding(ctx.builder, x_node, partitions)]
+  return [xla.set_sharding(ctx.builder, x_node, partitions)]

 sharding_constraint_p = core.Primitive("sharding_constraint")
 sharding_constraint_p.def_impl(_sharding_constraint_impl)
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 649842ef9..1e74ef812 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -22,9 +22,10 @@ from functools import partial
 import itertools as it
 import operator
 import re
-from typing import (Any, Callable, Deque, Dict, List, Optional, Sequence, Set,
-                    Type, Tuple, NamedTuple)
+from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
+                    Sequence, Set, Type, Tuple, Union)
 from typing_extensions import Protocol
+import warnings

 import numpy as np

@@ -43,7 +44,6 @@ import jax._src.pretty_printer as pp
 from jax._src import util
 from jax._src.util import (prod, extend_name_stack, wrap_name, safe_zip,
                            safe_map, partition_list)
-from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import ad
@@ -113,6 +113,84 @@ def make_op_metadata(primitive: core.Primitive,
       source_file=_get_canonical_source_file(frame) if frame else None,
       source_line=frame.line_num if frame else None)

+# Utilities
+
+def parameter(builder, num, shape, name=None, replicated=None):
+  if name is None:
+    name = ''
+  if replicated is None:
+    replicated = []
+  elif isinstance(replicated, bool):
+    replicated = [replicated] * shape.leaf_count()
+
+  return xops.Parameter(builder, num,
+                        shape.with_major_to_minor_layout_if_absent(), name,
+                        replicated)
+
+# HLO instructions optionally can be annotated to say how the output should be
+# spatially partitioned (represented in XLA as OpSharding protos, see
+# _sharding_to_proto). For array outputs, the annotation is either an int per
+# dimension specifying the number of ways that dimension divided (i.e. the total
+# number of shards is the product), or None to indicate the array should be
+# replicated. Tuple outputs are represented as tuples thereof. XLA supports
+# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
+# checkers don't support recursive types), so we only represent one level of
+# nesting in this type definition.
+SpatialSharding = Union[Tuple[int, ...],
+                        None,
+                        Tuple[Union[Tuple[int, ...], None], ...]]
+
+def _sharding_to_proto(sharding: SpatialSharding):
+  """Converts a SpatialSharding to an OpSharding.
+
+  See
+  https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
+  for details on the OpSharding proto.
+  """
+  proto = xc.OpSharding()
+  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
+    assert all(s is None or isinstance(s, tuple) for s in sharding)
+    return tuple_sharding_proto(list(map(_sharding_to_proto, sharding)))  # type: ignore
+
+  if sharding is None:
+    proto.type = xc.OpSharding.Type.REPLICATED
+  else:
+    proto.type = xc.OpSharding.Type.OTHER
+    proto.tile_assignment_dimensions = list(sharding)
+    proto.tile_assignment_devices = list(range(np.product(sharding)))
+  return proto
+
+def tuple_sharding_proto(elems):
+  proto = xc.OpSharding()
+  assert all(isinstance(e, type(proto)) for e in elems)
+  proto.type = xc.OpSharding.Type.TUPLE
+  proto.tuple_shardings = elems
+  return proto
+
+def set_sharding_proto(builder, op, sharding_proto):
+  """Uses CustomCall to annotate a value as sharded."""
+  # "Sharding" is a built-in custom call target that acts like an identity
+  # function, and is used to attach an OpSharding to.
+  return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
+                             builder, b"Sharding", [op], builder.get_shape(op))
+
+def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
+  """Builds op_fn(*args, **kwargs) with sharding annotation."""
+  builder.set_sharding(sharding_proto)
+  try:
+    return op_fn(*args, **kwargs)
+  finally:
+    builder.clear_sharding()
+
+def set_sharding(builder, op, sharding: SpatialSharding):
+  """Uses CustomCall to annotate a value as sharded."""
+  return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
+
+def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
+  """Builds op_fn(*args, **kwargs) with sharding annotation."""
+  return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
+
+
 ### handlers

 # Numpy dtypes -> XLA primitive types
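Note: a worked example of the SpatialSharding encoding documented in the
comment above (calling the private helper purely for illustration):

    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    proto = xla._sharding_to_proto((2, 1))   # split dim 0 two ways, dim 1 unsplit
    assert proto.type == xc.OpSharding.Type.OTHER
    assert list(proto.tile_assignment_dimensions) == [2, 1]
    assert list(proto.tile_assignment_devices) == [0, 1]
    assert xla._sharding_to_proto(None).type == xc.OpSharding.Type.REPLICATED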
@@ -400,7 +478,7 @@ def _xla_callable_args(
   if partitions is None:
     tuple_parts = None
   elif partitions_proto:
-    tuple_parts = xb.tuple_sharding_proto(partitions)
+    tuple_parts = tuple_sharding_proto(partitions)
   else:
     tuple_parts = tuple(partitions)
   tuple_shape = xc.Shape.tuple_shape(
@@ -419,15 +497,15 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions,
   is_token = xla_shape.is_token()
   if filter_tokens and is_token:
     xla_shape = _token_param_shape()
-  make_param = partial(xb.parameter, builder, param_num, xla_shape,
+  make_param = partial(parameter, builder, param_num, xla_shape,
                        replicated=replicated)
-  with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
+  with_sharding_fn = with_sharding_proto if parts_proto else with_sharding
   if partitions is None:
     out = make_param()
   elif partitions is _replicated_param:
-    out = with_sharding(builder, None, make_param)
+    out = with_sharding_fn(builder, None, make_param)
   else:
-    out = with_sharding(builder, partitions, make_param)
+    out = with_sharding_fn(builder, partitions, make_param)
   if filter_tokens and is_token:
     out = xops.CreateToken(builder)
   return out
@@ -583,9 +661,9 @@ def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
   Given the following computation:

   >>> c = xc.XlaBuilder("example")
-  >>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
-  >>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
-  >>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
+  >>> p0 = parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
+  >>> p1 = parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
+  >>> p2 = parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
   >>> o = xops.Tuple(c, [p0, p1, p2])

   We can query the arrays in the output tuple:
@@ -659,6 +737,55 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):

+def lower_jaxpr_to_xla_module(
+    fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
+    name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
+    replicated_args: Optional[Sequence[bool]],
+    arg_partitions: Optional[Any],
+    out_partitions: Optional[Any],
+    partitions_are_protos: bool = False
+    ) -> xc.XlaComputation:
+  """Lowers a closed jaxpr to a top-level XLA module."""
+  c = xc.XlaBuilder(fn_name)
+  xla_consts = _xla_consts(c, jaxpr.consts)
+  xla_args, donated_invars = _xla_callable_args(
+      c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
+      replicated=replicated_args, partitions=arg_partitions,
+      partitions_proto=partitions_are_protos)
+  ctx = TranslationContext(c, platform, axis_env, name_stack)
+  out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
+  # Replace tokens with a dummy array value, because the runtime cannot
+  # handle token arguments.
+  out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
+  out_nodes = util.flatten(
+      [[_make_token_return_value(c)] if a is core.abstract_token
+       else v for a, v in zip(jaxpr.out_avals,
+                              util.unflatten(out_nodes, out_aval_lens))])
+
+  # There is a non-zero cost to building an output tuple, particularly on TPU.
+  # Avoid it if the output arity is 1.
+  if out_partitions is None:
+    output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
+  else:
+    build_out_tuple = partial(xops.Tuple, c, out_nodes)
+    if partitions_are_protos:
+      output = with_sharding_proto(c, out_partitions, build_out_tuple)
+    else:
+      output = with_sharding(c, out_partitions, build_out_tuple)
+
+  if platform in ("gpu", "tpu"):
+    donated_invars = set_up_aliases(
+        c, xla_args, c.GetShape(output), donated_invars, tuple_args)
+  if any(donated_invars):
+    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
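Note: a sketch of how the new entry point is invoked for the plain jit case,
mirroring the call site in lower_xla_callable above (values are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.interpreters import xla

    closed_jaxpr = jax.make_jaxpr(lambda x: x * 2.0)(jnp.float32(1.0))
    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())
    module = xla.lower_jaxpr_to_xla_module(
        "jit_example", closed_jaxpr, "cpu", axis_env,
        name_stack="jit_example/", tuple_args=False, donated_invars=[False],
        replicated_args=None, arg_partitions=None, out_partitions=None)
    print(module.as_hlo_text())  # the resulting xc.XlaComputation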
+    unused_donations = [str(c.GetShape(a))
+                        for a, d in zip(xla_args, donated_invars) if d]
+    warnings.warn("Some donated buffers were not usable: {}".format(
+        ", ".join(unused_donations)))
+  return c.build(output)
+
+
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
 xla_call = xla_call_p.bind

@@ -698,7 +825,7 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
   c = ctx.builder
   check_backend_matches(backend, ctx.platform)
   subc = xc.XlaBuilder(f"jit_{name}")
-  args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
+  args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
   sub_ctx = ctx.replace(
       builder=subc,
       name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
@@ -934,7 +1061,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
   true_op = xops.Tuple(c, in_nodes)
   remat_subc = xc.XlaBuilder("remat_call_subcomputation")
-  input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
+  input_op = parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
   args = xla_destructure(remat_subc, input_op)
   sub_ctx = ctx.replace(
       builder=remat_subc,
@@ -945,7 +1072,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):

   false_op = true_op
   dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
-  xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
+  parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
   out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
   dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

@@ -959,7 +1086,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
   # Dummy subc for getting subcomp shapes.
   dummy_inputs = xops.Tuple(c, in_nodes)
   dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
-  dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
+  dummy_input_op = parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
   dummy_args = xla_destructure(dummy_subc, dummy_input_op)
   dummy_ctx = ctx.replace(
       builder=dummy_subc,
@@ -972,7 +1099,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
   inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

   cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
-  input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
+  input_op = parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
   i = xops.GetTupleElement(input_op, 0)
   rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
                         xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
@@ -980,7 +1107,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
   cond_subc = cond_subc.build(xops.Lt(i, rng))

   body_subc = xc.XlaBuilder("remat_body_subcomputation")
-  input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
+  input_op = parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
   i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
   i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
   body_ctx = ctx.replace(
@@ -1019,7 +1146,7 @@ def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
   check_backend_matches(backend, ctx.platform)
   c = ctx.builder
   subc = xc.XlaBuilder(name)
-  args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
+  args = [parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
   sub_ctx = ctx.replace(builder=subc,
                         name_stack=extend_name_stack(ctx.name_stack, name))
   out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
diff --git a/tests/api_test.py b/tests/api_test.py
index 9459e15df..a143ee93b 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -244,8 +244,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
   # Jit and Donate arguments

   def test_jit_donate_argnums_warning_raised(self):
-    if jax.config.jax_enable_mlir:
-      raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
     x = jnp.array([1.0, 2.0], jnp.float32)
     y = jnp.array([1, 2], jnp.int32)
     f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
@@ -256,7 +254,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       self.assertLen(w, 1)
       self.assertTrue(issubclass(w[-1].category, UserWarning))
       self.assertIn(
-          "Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
+          "Some donated buffers were not usable:",
           str(w[-1].message))

   @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py
index 31b54545c..2fe0d0c80 100644
--- a/tests/custom_object_test.py
+++ b/tests/custom_object_test.py
@@ -26,6 +26,7 @@ from jax.interpreters import xla
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_bridge, xla_client
 xops = xla_client.ops
+xc = xla_client
 xb = xla_bridge

 from jax.config import config
@@ -113,14 +114,14 @@ def sparse_array_result_handler(device, aval):

 def sparse_array_shape_handler(a):
   return (
-    xla.xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
-    xla.xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
+    xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
+    xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
   )

 def sparse_array_device_put_handler(a, device):
   return (
-    xla.xb.get_device_backend(device).buffer_from_pyval(a.data, device),
-    xla.xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
+    xb.get_device_backend(device).buffer_from_pyval(a.data, device),
+    xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
   )

 def sparse_array_constant_handler(c, val, canonicalize_dtypes):
diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py
index 2218a1011..2937180d1 100644
--- a/tests/xla_bridge_test.py
+++ b/tests/xla_bridge_test.py
@@ -19,6 +19,7 @@ from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax.interpreters import xla
 from jax._src.config import config

 config.parse_flags_with_absl()
@@ -47,13 +48,13 @@ class XlaBridgeTest(jtu.JaxTestCase):

   def test_parameter_replication_default(self):
     c = xc.XlaBuilder("test")
-    _ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
+    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
     built_c = c.Build()
     assert "replication" not in built_c.as_hlo_text()

   def test_parameter_replication(self):
     c = xc.XlaBuilder("test")
-    _ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
+    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
                      False)
     built_c = c.Build()
     assert "parameter_replication={false}" in built_c.as_hlo_text()
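Note: the updated tests exercise exactly what xla.parameter adds over raw
xops.Parameter: a default major-to-minor layout and expansion of a bool
`replicated` argument to one entry per leaf buffer. Echoing the test
(illustrative):

    from jax._src.lib import xla_client as xc
    from jax.interpreters import xla

    c = xc.XlaBuilder("test")
    _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
                      False)  # bool -> [False] * shape.leaf_count()
    assert "parameter_replication={false}" in c.Build().as_hlo_text()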