Remove xla_bridge.make_computation_builder().

This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
This commit is contained in:
Peter Hawkins 2021-10-18 13:19:45 -04:00
parent 391cafb0e5
commit 714e19a794
15 changed files with 46 additions and 48 deletions

View File

@ -1992,7 +1992,7 @@
" typecheck_jaxpr(jaxpr)\n",
" consts = [x.val for x in hashable_consts]\n",
" in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n",
" c = xb.make_computation_builder('xla_call')\n",
" c = xc.XlaBuilder('xla_call')\n",
" xla_consts = _xla_consts(c, consts)\n",
" xla_params = _xla_params(c, in_avals)\n",
" outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n",
@ -2094,7 +2094,7 @@
"def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n",
" (x_aval,), (x,) = in_avals, in_vals\n",
" zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n",
" subc = xb.make_computation_builder('add')\n",
" subc = xc.XlaBuilder('add')\n",
" shape = _xla_shape(ShapedArray((), x_aval.dtype))\n",
" xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n",
" return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]\n",
@ -2279,7 +2279,7 @@
"def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n",
" del num_consts # Only used at top-level.\n",
" # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.\n",
" subc = xb.make_computation_builder('inner xla_call')\n",
" subc = xc.XlaBuilder('inner xla_call')\n",
" xla_params = _xla_params(subc, in_avals)\n",
" outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n",
" subc = subc.build(xops.Tuple(subc, outs))\n",
@ -3629,7 +3629,7 @@
" operand_shape = c.get_shape(operand)\n",
"\n",
" def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
" c = xb.make_computation_builder(name)\n",
" c = xc.XlaBuilder(name)\n",
" operand = xb.parameter(c, 0, operand_shape)\n",
" operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
" outs = jaxpr_subcomp(c, jaxpr, operands)\n",

View File

@ -1562,7 +1562,7 @@ def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable])
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
c = xb.make_computation_builder('xla_call')
c = xc.XlaBuilder('xla_call')
xla_consts = _xla_consts(c, consts)
xla_params = _xla_params(c, in_avals)
outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
@ -1644,7 +1644,7 @@ xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
(x_aval,), (x,) = in_avals, in_vals
zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
subc = xb.make_computation_builder('add')
subc = xc.XlaBuilder('add')
shape = _xla_shape(ShapedArray((), x_aval.dtype))
xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
@ -1776,7 +1776,7 @@ abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
del num_consts # Only used at top-level.
# Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
subc = xb.make_computation_builder('inner xla_call')
subc = xc.XlaBuilder('inner xla_call')
xla_params = _xla_params(subc, in_avals)
outs = jaxpr_subcomp(subc, jaxpr, xla_params)
subc = subc.build(xops.Tuple(subc, outs))
@ -2843,7 +2843,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
operand_shape = c.get_shape(operand)
def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xb.make_computation_builder(name)
c = xc.XlaBuilder(name)
operand = xb.parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)

View File

@ -1554,7 +1554,7 @@ def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable])
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
c = xb.make_computation_builder('xla_call')
c = xc.XlaBuilder('xla_call')
xla_consts = _xla_consts(c, consts)
xla_params = _xla_params(c, in_avals)
outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
@ -1640,7 +1640,7 @@ xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
(x_aval,), (x,) = in_avals, in_vals
zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
subc = xb.make_computation_builder('add')
subc = xc.XlaBuilder('add')
shape = _xla_shape(ShapedArray((), x_aval.dtype))
xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]
@ -1770,7 +1770,7 @@ abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
del num_consts # Only used at top-level.
# Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
subc = xb.make_computation_builder('inner xla_call')
subc = xc.XlaBuilder('inner xla_call')
xla_params = _xla_params(subc, in_avals)
outs = jaxpr_subcomp(subc, jaxpr, xla_params)
subc = subc.build(xops.Tuple(subc, outs))
@ -2835,7 +2835,7 @@ def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
operand_shape = c.get_shape(operand)
def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
c = xb.make_computation_builder(name)
c = xc.XlaBuilder(name)
operand = xb.parameter(c, 0, operand_shape)
operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
outs = jaxpr_subcomp(c, jaxpr, operands)

View File

@ -817,7 +817,7 @@ def xla_computation(fun: Callable,
else:
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
c = xb.make_computation_builder(f"xla_computation_{fun_name}")
c = xc.XlaBuilder(f"xla_computation_{fun_name}")
xla_consts = map(partial(xb.constant, c), consts)
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(

View File

@ -336,7 +336,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
cond_c = xb.make_computation_builder("cond_computation")
cond_c = xla_client.XlaBuilder("cond_computation")
cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
@ -351,7 +351,7 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
list(range(cond_jaxpr.out_avals[0].ndim)))
body_c = xb.make_computation_builder("body_computation")
body_c = xla_client.XlaBuilder("body_computation")
body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
@ -801,7 +801,7 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
name_stack = extend_name_stack(ctx.name_stack, "cond")
def make_computation(name, jaxpr, op_shape):
c = xb.make_computation_builder(name + '_comp')
c = xla_client.XlaBuilder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
subctx = ctx.replace(

View File

@ -4727,7 +4727,7 @@ def _gather_translation_rule(c, operand, indices, *, dimension_numbers,
# Compute the conjunction of the mask elements across the dimensions in which
# we are slicing.
and_builder = xb.make_computation_builder("and_reduction")
and_builder = xc.XlaBuilder("and_reduction")
scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ())
xops.And(xb.parameter(and_builder, 0, scalar_pred),
xb.parameter(and_builder, 1, scalar_pred))
@ -5029,7 +5029,7 @@ def _scatter_add_translation_rule(
dimension_numbers)
def _make_reducer(dtype):
subc = xla_bridge.make_computation_builder("scatter_add_reducer")
subc = xc.XlaBuilder("scatter_add_reducer")
shape = xc.Shape.array_shape(np.dtype(dtype), ())
args = [xb.parameter(subc, 0, shape), xb.parameter(subc, 1, shape)]
out = xops.Add(args[0], args[1])
@ -5549,7 +5549,7 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
init_values = [init_values]
shapes = safe_map(c.get_shape, init_values + init_values)
axis_env = xla.AxisEnv(1, (), ()) # no parallel primitives inside reductions
subc = xla_bridge.make_computation_builder("reduction_computation")
subc = xc.XlaBuilder("reduction_computation")
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
ctx = xla.TranslationContext(subc, None, axis_env, '')
@ -6300,7 +6300,7 @@ def _select_and_gather_add_translation(
etype)
def reducer():
c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
c = xc.XlaBuilder("select_and_gather_pair_reducer")
x = xb.parameter(c, 0,
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
y = xb.parameter(c, 1,
@ -6330,7 +6330,7 @@ def _select_and_gather_add_translation_using_variadic_reducewindow(
canonicalize_types=False)
def reducer():
c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
c = xc.XlaBuilder("select_and_gather_pair_reducer")
shape = xla_client.Shape.array_shape(np.dtype(dtype), ())
kx, vx, ky, vy = (xb.parameter(c, i, shape) for i in range(4))
which = (xops.Ge if select_prim is ge_p else xops.Le)(kx, ky)
@ -6487,7 +6487,7 @@ def _sort_lt_comparator(*operands, num_keys=1):
def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
types = [c.get_shape(x).xla_element_type() for x in operands]
subc = xla_bridge.make_computation_builder("sort_lt_comparator")
subc = xc.XlaBuilder("sort_lt_comparator")
params = [xb.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ()))
for i, typ in enumerate(types) for j in range(2)]
result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),

View File

@ -558,9 +558,6 @@ def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
def make_computation_builder(name):
return xla_client.XlaBuilder(name)
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun

View File

@ -743,7 +743,7 @@ dynamic_xla_call_p.multiple_results = True
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
in_dim_vals, consts, args = split_list(args, [len(jaxpr.in_dim_binders), num_consts])
dim_in_avals = [v.aval for v in jaxpr.in_dim_binders]
c = xb.make_computation_builder("dxla_call")
c = xc.XlaBuilder("dxla_call")
dim_params, params = _make_params(c, dim_in_avals, map(xla.abstractify, args))
const_params = _xla_consts(c, consts)
dim_outs, outs = djaxpr_subcomp(c, jaxpr, dim_params, const_params + params)

View File

@ -496,7 +496,7 @@ xops = xla_client._xla.ops
XlaOp = xla_client.XlaOp
XlaShape = xla_client.Shape
XlaComputationBuilder = xla_client.XlaBuilder
XlaBuilder = xla_client.XlaBuilder
XlaDevice = xla_client.Device
XlaLocalClient = xla_client.Client
DType = Any
@ -923,7 +923,7 @@ def _outside_call_impl(*args, **params):
outside_call_p.def_impl(_outside_call_impl)
def _outside_call_translation_rule(comp: XlaComputationBuilder,
def _outside_call_translation_rule(comp: XlaBuilder,
*args_op: XlaOp,
platform="tpu",
has_token,

View File

@ -237,7 +237,7 @@ def _call_tf_abstract_eval(*_,
call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
def _call_tf_translation_rule(builder: xla.XlaBuilder, *args_op,
function_flat_tf,
args_flat_sig_tf,
**_):
@ -253,7 +253,7 @@ def _code_generator_and_avals(
function_flat_tf,
args_flat_sig_tf,
code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
) -> Tuple[Optional[Callable[[xla.XlaBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
Sequence[core.ShapedArray]]:
# Returns and caches a code generator (taking a builder and the
# XlaOps for the arguments) and a sequence of result abstract shapes.
@ -384,7 +384,7 @@ def _code_generator_and_avals(
result_avals = tuple(map(canonical_res_aval, result_shapes)) # type: ignore
def code_gen(builder: xla.XlaComputationBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
def code_gen(builder: xla.XlaBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
for inp in captured_inputs]

View File

@ -875,7 +875,7 @@ def parallel_callable(fun: lu.WrappedFun,
tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
xla_consts = map(partial(xb.constant, c), consts)
replicated_args = [axis is None for axis in in_axes]
xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
@ -1594,7 +1594,7 @@ def lower_mesh_computation(
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
# 3. Build up the HLO
c = xb.make_computation_builder(f"xmap_{fun.__name__}")
c = xc.XlaBuilder(f"xmap_{fun.__name__}")
xla_consts = map(partial(xb.constant, c), consts)
tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
in_partitions: Optional[List]

View File

@ -138,7 +138,7 @@ def _sharded_callable(
"Compiling %s for %d devices with args %s.",
fun.__name__, nparts, global_abstract_args)
c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
xla_consts = _map(partial(xb.constant, c), consts)
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
axis_env = xla.AxisEnv(nrep, (), ())

View File

@ -62,7 +62,7 @@ Buffer = xe.Buffer
XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaComputationBuilder = xc.XlaBuilder
XlaBuilder = xc.XlaBuilder
XlaExecutable = xc.Executable
# This flag is set on exit; no logging should be attempted
@ -328,7 +328,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
def primitive_subcomputation(prim: core.Primitive, *avals: core.AbstractValue,
**params):
c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
f = lower_fun(prim.bind, multiple_results=prim.multiple_results)
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False)
ans = f(c, *xla_args, **params)
@ -683,7 +683,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder(f"jit_{fun.__name__}")
c = xc.XlaBuilder(f"jit_{fun.__name__}")
xla_consts = _xla_consts(c, consts)
xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args,
donated_invars=donated_invars)
@ -1026,7 +1026,7 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
del device, donated_invars, inline # Ignored.
c = ctx.builder
check_backend_matches(backend, ctx.platform)
subc = xb.make_computation_builder(f"jit_{name}")
subc = xc.XlaBuilder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(
builder=subc,
@ -1579,7 +1579,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))
true_op = xops.Tuple(c, in_nodes)
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
remat_subc = xc.XlaBuilder("remat_call_subcomputation")
input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
args = xla_destructure(remat_subc, input_op)
sub_ctx = ctx.replace(
@ -1590,7 +1590,7 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))
false_op = true_op
dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
dummy_subc = xc.XlaBuilder("remat_call_dummy_subcomputation")
xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])
out_nodes = [_zeros(dummy_subc, s) for s in out_node_shapes]
dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))
@ -1604,7 +1604,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
c = ctx.builder
# Dummy subc for getting subcomp shapes.
dummy_inputs = xops.Tuple(c, in_nodes)
dummy_subc = xb.make_computation_builder("remat_dummy_subcomputation")
dummy_subc = xc.XlaBuilder("remat_dummy_subcomputation")
dummy_input_op = xb.parameter(dummy_subc, 0, c.get_shape(dummy_inputs), replicated=[])
dummy_args = xla_destructure(dummy_subc, dummy_input_op)
dummy_ctx = ctx.replace(
@ -1617,7 +1617,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)
cond_subc = xb.make_computation_builder("remat_cond_subcomputation")
cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
i = xops.GetTupleElement(input_op, 0)
rng = xops.RngUniform(xb.constant(cond_subc, np.array(1, dtype=np.int32)),
@ -1625,7 +1625,7 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
xc.Shape.array_shape(xc.PrimitiveType.S32, []))
cond_subc = cond_subc.build(xops.Lt(i, rng))
body_subc = xb.make_computation_builder("remat_body_subcomputation")
body_subc = xc.XlaBuilder("remat_body_subcomputation")
input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
@ -1664,7 +1664,7 @@ def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name="core_call", backend=None, call_jaxpr):
check_backend_matches(backend, ctx.platform)
c = ctx.builder
subc = xb.make_computation_builder(name)
subc = xc.XlaBuilder(name)
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(builder=subc,
name_stack=extend_name_stack(ctx.name_stack, name))

View File

@ -39,6 +39,7 @@ from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
import numpy as np
@ -1753,7 +1754,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
Check that we get the proper error from the runtime."""
if not hcb._use_outfeed(jtu.device_under_test()):
raise SkipTest("test works only for outfeed")
comp = xla_bridge.make_computation_builder(self._testMethodName)
comp = xla_client.XlaBuilder(self._testMethodName)
token = hcb.xops.CreateToken(comp)
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
with self.assertRaisesRegex(RuntimeError,
@ -1766,7 +1767,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
"""Try to register different shapes for the same consumer ID."""
if not hcb._use_outfeed(jtu.device_under_test()):
raise SkipTest("test works only for outfeed")
comp = xla_bridge.make_computation_builder(self._testMethodName)
comp = xla_client.XlaBuilder(self._testMethodName)
token = hcb.xops.CreateToken(comp)
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
hcb._callback_handler_data.receiver.add_outfeed(

View File

@ -46,13 +46,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
expected_device_assignment)
def test_parameter_replication_default(self):
c = xb.make_computation_builder("test")
c = xc.XlaBuilder("test")
_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
built_c = c.Build()
assert "replication" not in built_c.as_hlo_text()
def test_parameter_replication(self):
c = xb.make_computation_builder("test")
c = xc.XlaBuilder("test")
_ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
False)
built_c = c.Build()