Mirror of https://github.com/ROCm/jax.git
Remove xla_bridge.make_computation_builder().
This was a vestigial wrapper around xla_client.XlaBuilder whose purpose is long gone. Also rename uses of XlaComputationBuilder to XlaBuilder; XlaComputationBuilder is an older name that has already been removed in most places.
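The deleted helper was a pure one-line passthrough (its removal appears in the xla_bridge hunk below), so every call site migrates mechanically. A minimal before/after sketch, assuming the `xb`/`xc` import aliases used throughout the tree at the time; the computation name is a hypothetical placeholder:

    # Sketch only: the `xb`/`xc` aliases and the computation name are assumptions,
    # not code from this commit.
    from jax._src.lib import xla_bridge as xb
    from jax._src.lib import xla_client as xc

    # Before this commit: call the vestigial wrapper, which simply forwarded
    # its argument to the real constructor:
    #   def make_computation_builder(name):
    #     return xla_client.XlaBuilder(name)
    c = xb.make_computation_builder("example_computation")

    # After this commit: construct the builder directly from xla_client.
    c = xc.XlaBuilder("example_computation")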
Parent: 391cafb0e5
Commit: 714e19a794
@@ -1992,7 +1992,7 @@
 " typecheck_jaxpr(jaxpr)\n",
 " consts = [x.val for x in hashable_consts]\n",
 " in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n",
-" c = xb.make_computation_builder('xla_call')\n",
+" c = xc.XlaBuilder('xla_call')\n",
 " xla_consts = _xla_consts(c, consts)\n",
 " xla_params = _xla_params(c, in_avals)\n",
 " outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n",
@@ -2094,7 +2094,7 @@
 "def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n",
 " (x_aval,), (x,) = in_avals, in_vals\n",
 " zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n",
-" subc = xb.make_computation_builder('add')\n",
+" subc = xc.XlaBuilder('add')\n",
 " shape = _xla_shape(ShapedArray((), x_aval.dtype))\n",
 " xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n",
 " return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]\n",
@@ -2279,7 +2279,7 @@
 "def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n",
 " del num_consts # Only used at top-level.\n",
 " # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.\n",
-" subc = xb.make_computation_builder('inner xla_call')\n",
+" subc = xc.XlaBuilder('inner xla_call')\n",
 " xla_params = _xla_params(subc, in_avals)\n",
 " outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n",
 " subc = subc.build(xops.Tuple(subc, outs))\n",
@@ -3629,7 +3629,7 @@
 " operand_shape = c.get_shape(operand)\n",
 "\n",
 " def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
-" c = xb.make_computation_builder(name)\n",
+" c = xc.XlaBuilder(name)\n",
 " operand = xb.parameter(c, 0, operand_shape)\n",
 " operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
 " outs = jaxpr_subcomp(c, jaxpr, operands)\n",
@@ -1562,7 +1562,7 @@ def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable])
 typecheck_jaxpr(jaxpr)
 consts = [x.val for x in hashable_consts]
 in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
-c = xb.make_computation_builder('xla_call')
+c = xc.XlaBuilder('xla_call')
 xla_consts = _xla_consts(c, consts)
 xla_params = _xla_params(c, in_avals)
 outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
@@ -1644,7 +1644,7 @@ xla_translations[less_p] = partial(direct_translation, xops.Lt)
 def reduce_sum_translation(c, in_avals, in_vals, *, axis):
 (x_aval,), (x,) = in_avals, in_vals
 zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
-subc = xb.make_computation_builder('add')
+subc = xc.XlaBuilder('add')
 shape = _xla_shape(ShapedArray((), x_aval.dtype))
 xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
 return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
@@ -1776,7 +1776,7 @@ abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
 def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
 del num_consts # Only used at top-level.
 # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
-subc = xb.make_computation_builder('inner xla_call')
+subc = xc.XlaBuilder('inner xla_call')
 xla_params = _xla_params(subc, in_avals)
 outs = jaxpr_subcomp(subc, jaxpr, xla_params)
 subc = subc.build(xops.Tuple(subc, outs))
@@ -2843,7 +2843,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
 operand_shape = c.get_shape(operand)
 
 def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
-c = xb.make_computation_builder(name)
+c = xc.XlaBuilder(name)
 operand = xb.parameter(c, 0, operand_shape)
 operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
 outs = jaxpr_subcomp(c, jaxpr, operands)
@@ -1554,7 +1554,7 @@ def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable])
 typecheck_jaxpr(jaxpr)
 consts = [x.val for x in hashable_consts]
 in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
-c = xb.make_computation_builder('xla_call')
+c = xc.XlaBuilder('xla_call')
 xla_consts = _xla_consts(c, consts)
 xla_params = _xla_params(c, in_avals)
 outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
@@ -1640,7 +1640,7 @@ xla_translations[less_p] = partial(direct_translation, xops.Lt)
 def reduce_sum_translation(c, in_avals, in_vals, *, axis):
 (x_aval,), (x,) = in_avals, in_vals
 zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
-subc = xb.make_computation_builder('add')
+subc = xc.XlaBuilder('add')
 shape = _xla_shape(ShapedArray((), x_aval.dtype))
 xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
 return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
@@ -1770,7 +1770,7 @@ abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
 def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
 del num_consts # Only used at top-level.
 # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
-subc = xb.make_computation_builder('inner xla_call')
+subc = xc.XlaBuilder('inner xla_call')
 xla_params = _xla_params(subc, in_avals)
 outs = jaxpr_subcomp(subc, jaxpr, xla_params)
 subc = subc.build(xops.Tuple(subc, outs))
@@ -2835,7 +2835,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
 operand_shape = c.get_shape(operand)
 
 def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
-c = xb.make_computation_builder(name)
+c = xc.XlaBuilder(name)
 operand = xb.parameter(c, 0, operand_shape)
 operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
 outs = jaxpr_subcomp(c, jaxpr, operands)
@@ -817,7 +817,7 @@ def xla_computation(fun: Callable,
 else:
 out_parts_flat = tuple(flatten_axes(
 "xla_computation out_parts", out_tree(), out_parts))
-c = xb.make_computation_builder(f"xla_computation_{fun_name}")
+c = xc.XlaBuilder(f"xla_computation_{fun_name}")
 xla_consts = map(partial(xb.constant, c), consts)
 should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
 xla_args, donated_invars = xla._xla_callable_args(
@@ -336,7 +336,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
 
 init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
 
-cond_c = xb.make_computation_builder("cond_computation")
+cond_c = xla_client.XlaBuilder("cond_computation")
 cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
 cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
 x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
@@ -351,7 +351,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
 pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
 list(range(cond_jaxpr.out_avals[0].ndim)))
 
-body_c = xb.make_computation_builder("body_computation")
+body_c = xla_client.XlaBuilder("body_computation")
 body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
 body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
 x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
@@ -801,7 +801,7 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
 
 name_stack = extend_name_stack(ctx.name_stack, "cond")
 def make_computation(name, jaxpr, op_shape):
-c = xb.make_computation_builder(name + '_comp')
+c = xla_client.XlaBuilder(name + '_comp')
 op = xb.parameter(c, 0, op_shape)
 ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
 subctx = ctx.replace(
@@ -4727,7 +4727,7 @@ def _gather_translation_rule(c, operand, indices, *, dimension_numbers,
 
 # Compute the conjunction of the mask elements across the dimensions in which
 # we are slicing.
-and_builder = xb.make_computation_builder("and_reduction")
+and_builder = xc.XlaBuilder("and_reduction")
 scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ())
 xops.And(xb.parameter(and_builder, 0, scalar_pred),
 xb.parameter(and_builder, 1, scalar_pred))
@@ -5029,7 +5029,7 @@ def _scatter_add_translation_rule(
 dimension_numbers)
 
 def _make_reducer(dtype):
-subc = xla_bridge.make_computation_builder("scatter_add_reducer")
+subc = xc.XlaBuilder("scatter_add_reducer")
 shape = xc.Shape.array_shape(np.dtype(dtype), ())
 args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)]
 out = xops.Add(args[0], args[1])
@@ -5549,7 +5549,7 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
 init_values = [init_values]
 shapes = safe_map(c.get_shape, init_values + init_values)
 axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions
-subc = xla_bridge.make_computation_builder("reduction_computation")
+subc = xc.XlaBuilder("reduction_computation")
 assert len(consts) == 0, "Reduction computations cannot have constants"
 args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
 ctx = xla.TranslationContext(subc, None, axis_env, '')
@@ -6300,7 +6300,7 @@ def _select_and_gather_add_translation(
 etype)
 
 def reducer():
-c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
+c = xc.XlaBuilder("select_and_gather_pair_reducer")
 x = xb.parameter(c, 0,
 xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
 y = xb.parameter(c, 1,
@@ -6330,7 +6330,7 @@ def _select_and_gather_add_translation_using_variadic_reducewindow(
 canonicalize_types=False)
 
 def reducer():
-c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
+c = xc.XlaBuilder("select_and_gather_pair_reducer")
 shape = xla_client.Shape.array_shape(np.dtype(dtype), ())
 kx, vx, ky, vy = (xb.parameter(c, i, shape) for i in range(4))
 which = (xops.Ge if select_prim is ge_p else xops.Le)(kx, ky)
@@ -6487,7 +6487,7 @@ def _sort_lt_comparator(*operands, num_keys=1):
 
 def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
 types = [c.get_shape(x).xla_element_type() for x in operands]
-subc = xla_bridge.make_computation_builder("sort_lt_comparator")
+subc = xc.XlaBuilder("sort_lt_comparator")
 params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
 for i, typ in enumerate(types) for j in range(2)]
 result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
@@ -558,9 +558,6 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
 """Builds op_fn(*args, **kwargs) with sharding annotation."""
 return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
 
-def make_computation_builder(name):
-return xla_client.XlaBuilder(name)
-
 
 def register_constant_handler(type_, handler_fun):
 _constant_handlers[type_] = handler_fun
@@ -743,7 +743,7 @@ dynamic_xla_call_p.multiple_results = True
 def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
 in_dim_vals, consts, args = split_list(args, [len(jaxpr.in_dim_binders), num_consts])
 dim_in_avals = [v.aval for v in jaxpr.in_dim_binders]
-c = xb.make_computation_builder("dxla_call")
+c = xc.XlaBuilder("dxla_call")
 dim_params, params = _make_params(c, dim_in_avals, map(xla.abstractify, args))
 const_params = _xla_consts(c, consts)
 dim_outs, outs = djaxpr_subcomp(c, jaxpr, dim_params, const_params + params)
@@ -496,7 +496,7 @@ xops = xla_client._xla.ops
 
 XlaOp = xla_client.XlaOp
 XlaShape = xla_client.Shape
-XlaComputationBuilder = xla_client.XlaBuilder
+XlaBuilder = xla_client.XlaBuilder
 XlaDevice = xla_client.Device
 XlaLocalClient = xla_client.Client
 DType = Any
@@ -923,7 +923,7 @@ def _outside_call_impl(*args, **params):
 outside_call_p.def_impl(_outside_call_impl)
 
 
-def _outside_call_translation_rule(comp: XlaComputationBuilder,
+def _outside_call_translation_rule(comp: XlaBuilder,
 *args_op: XlaOp,
 platform="tpu",
 has_token,
@@ -237,7 +237,7 @@ def _call_tf_abstract_eval(*_,
 call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
 
 
-def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
+def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
 function_flat_tf,
 args_flat_sig_tf,
 **_):
@@ -253,7 +253,7 @@ def _code_generator_and_avals(
 function_flat_tf,
 args_flat_sig_tf,
 code_gen_optional=False
-) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
+) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
 Sequence[core.ShapedArray]]:
 # Returns and caches a code generator (taking a builder and the
 # XlaOps for the arguments) and a sequence of result abstract shapes.
@@ -384,7 +384,7 @@ def _code_generator_and_avals(
 
 result_avals = tuple(map(canonical_res_aval, result_shapes)) # type: ignore
 
-def code_gen(builder: xla.XlaComputationBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
+def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
 captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
 for inp in captured_inputs]
 
@@ -875,7 +875,7 @@ def parallel_callable(fun: lu.WrappedFun,
 
 tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU
 
-c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
+c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
 xla_consts = map(partial(xb.constant, c), consts)
 replicated_args = [axis is None for axis in in_axes]
 xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
@@ -1594,7 +1594,7 @@ def lower_mesh_computation(
 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
 
 # 3. Build up the HLO
-c = xb.make_computation_builder(f"xmap_{fun.__name__}")
+c = xc.XlaBuilder(f"xmap_{fun.__name__}")
 xla_consts = map(partial(xb.constant, c), consts)
 tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
 in_partitions: Optional[List]
@@ -138,7 +138,7 @@ def _sharded_callable(
 "Compiling %s for %d devices with args %s.",
 fun.__name__, nparts, global_abstract_args)
 
-c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
+c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
 xla_consts = _map(partial(xb.constant, c), consts)
 xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
 axis_env = xla.AxisEnv(nrep, (), ())
@@ -62,7 +62,7 @@ Buffer = xe.Buffer
 
 XlaOp = xc.XlaOp
 XlaShape = xc.Shape
-XlaComputationBuilder = xc.XlaBuilder
+XlaBuilder = xc.XlaBuilder
 XlaExecutable = xc.Executable
 
 # This flag is set on exit; no logging should be attempted
@@ -328,7 +328,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
 
 def primitive_subcomputation(prim: core.Primitive, *avals: core.AbstractValue,
 **params):
-c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
+c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
 f = lower_fun(prim.bind, multiple_results=prim.multiple_results)
 xla_args, _ = _xla_callable_args(c, avals, tuple_args=False)
 ans = f(c, *xla_args, **params)
@@ -683,7 +683,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
 
 tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
 
-c = xb.make_computation_builder(f"jit_{fun.__name__}")
+c = xc.XlaBuilder(f"jit_{fun.__name__}")
 xla_consts = _xla_consts(c, consts)
 xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
 donated_invars=donated_invars)
@@ -1026,7 +1026,7 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
 del device, donated_invars, inline # Ignored.
 c = ctx.builder
 check_backend_matches(backend, ctx.platform)
-subc = xb.make_computation_builder(f"jit_{name}")
+subc = xc.XlaBuilder(f"jit_{name}")
 args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
 sub_ctx = ctx.replace(
 builder=subc,
@@ -1579,7 +1579,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
 pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))
 
 true_op = xops.Tuple(c, in_nodes)
-remat_subc = xb.make_computation_builder("remat_call_subcomputation")
+remat_subc = xc.XlaBuilder("remat_call_subcomputation")
 input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
 args = xla_destructure(remat_subc, input_op)
 sub_ctx = ctx.replace(
@@ -1590,7 +1590,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
 remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))
 
 false_op = true_op
-dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
+dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
 xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
 out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
 dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))
@@ -1604,7 +1604,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 c = ctx.builder
 # Dummy subc for getting subcomp shapes.
 dummy_inputs = xops.Tuple(c, in_nodes)
-dummy_subc = xb.make_computation_builder("remat_dummy_subcomputation")
+dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
 dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
 dummy_args = xla_destructure(dummy_subc, dummy_input_op)
 dummy_ctx = ctx.replace(
@@ -1617,7 +1617,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
 inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)
 
-cond_subc = xb.make_computation_builder("remat_cond_subcomputation")
+cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
 input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
 i = xops.GetTupleElement(input_op, 0)
 rng = xops.RngUniform(xb.constant(cond_subc, np.array(1, dtype=np.int32)),
@@ -1625,7 +1625,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
 xc.Shape.array_shape(xc.PrimitiveType.S32, []))
 cond_subc = cond_subc.build(xops.Lt(i, rng))
 
-body_subc = xb.make_computation_builder("remat_body_subcomputation")
+body_subc = xc.XlaBuilder("remat_body_subcomputation")
 input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
 i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
 i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
@@ -1664,7 +1664,7 @@ def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
 name="core_call", backend=None, call_jaxpr):
 check_backend_matches(backend, ctx.platform)
 c = ctx.builder
-subc = xb.make_computation_builder(name)
+subc = xc.XlaBuilder(name)
 args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
 sub_ctx = ctx.replace(builder=subc,
 name_stack=extend_name_stack(ctx.name_stack, name))
@@ -39,6 +39,7 @@ from jax import numpy as jnp
 from jax._src import test_util as jtu
 from jax import tree_util
 from jax._src.lib import xla_bridge
+from jax._src.lib import xla_client
 
 import numpy as np
 
@@ -1753,7 +1754,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
 Check that we get the proper error from the runtime."""
 if not hcb._use_outfeed(jtu.device_under_test()):
 raise SkipTest("test works only for outfeed")
-comp = xla_bridge.make_computation_builder(self._testMethodName)
+comp = xla_client.XlaBuilder(self._testMethodName)
 token = hcb.xops.CreateToken(comp)
 hcb._initialize_outfeed_receiver() # Needed if this is the sole test
 with self.assertRaisesRegex(RuntimeError,
@@ -1766,7 +1767,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
 """Try to register different shapes for the same consumer ID."""
 if not hcb._use_outfeed(jtu.device_under_test()):
 raise SkipTest("test works only for outfeed")
-comp = xla_bridge.make_computation_builder(self._testMethodName)
+comp = xla_client.XlaBuilder(self._testMethodName)
 token = hcb.xops.CreateToken(comp)
 hcb._initialize_outfeed_receiver() # Needed if this is the sole test
 hcb._callback_handler_data.receiver.add_outfeed(
@@ -46,13 +46,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
 expected_device_assignment)
 
 def test_parameter_replication_default(self):
-c = xb.make_computation_builder("test")
+c = xc.XlaBuilder("test")
 _ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
 built_c = c.Build()
 assert "replication" not in built_c.as_hlo_text()
 
 def test_parameter_replication(self):
-c = xb.make_computation_builder("test")
+c = xc.XlaBuilder("test")
 _ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
 False)
 built_c = c.Build()