Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Move xla_bridge.constant to jax.interpreters.xla.pyval_to_ir_constant.

This is a more descriptive name and a better location, next to the other
facilities for building XLA IR. Quite a few users of the former
xla_bridge.constant() needed nothing more than uncanonicalized array
constants, so those call sites are changed to use xla_client.ops.Constant
instead; there is no need for the fancier utility in those cases.

PiperOrigin-RevId: 404270649
This commit is contained in:
parent e96f363242
commit 1a73743610
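The change is mechanical but takes two forms, visible throughout the hunks below: callers that relied on dtype canonicalization switch to `xla.pyval_to_ir_constant`, while callers that only ever needed a raw array constant switch to `xla_client.ops.Constant`. A minimal sketch of the two patterns, assuming only an `XlaBuilder` named `c` and a NumPy-convertible value `val` (the function below is illustrative, not part of the diff, and assumes the import paths of the JAX version this commit targets):

```python
import numpy as np
from jax._src.lib import xla_client
from jax.interpreters import xla  # home of pyval_to_ir_constant after this change

xops = xla_client.ops

def build_constants(c, val):
  # Old spelling in both cases: xb.constant(c, val, ...) from xla_bridge.
  # New spelling when dtype canonicalization is wanted (e.g. Python ints
  # mapping to the canonical integer dtype):
  canonical = xla.pyval_to_ir_constant(c, val)
  # New spelling when the caller just wants the bytes it already has:
  uncanonicalized = xops.Constant(c, np.asarray(val))
  return canonical, uncanonicalized
```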
@@ -2033,7 +2033,7 @@
 "  env: Dict[Var, xe.XlaOp] = {}\n",
 "\n",
 "  def read(x: Atom) -> xe.XlaOp:\n",
-"    return env[x] if type(x) is Var else xb.constant(c, x.val, False)\n",
+"    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n",
 "\n",
 "  def write(v: Var, val: xe.XlaOp) -> None:\n",
 "    env[v] = val\n",

@@ -1593,7 +1593,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
   env: Dict[Var, xe.XlaOp] = {}

   def read(x: Atom) -> xe.XlaOp:
-    return env[x] if type(x) is Var else xb.constant(c, x.val, False)
+    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))

   def write(v: Var, val: xe.XlaOp) -> None:
     env[v] = val

@@ -1587,7 +1587,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
   env: Dict[Var, xe.XlaOp] = {}

   def read(x: Atom) -> xe.XlaOp:
-    return env[x] if type(x) is Var else xb.constant(c, x.val, False)
+    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))

   def write(v: Var, val: xe.XlaOp) -> None:
     env[v] = val

@@ -818,7 +818,7 @@ def xla_computation(fun: Callable,
     out_parts_flat = tuple(flatten_axes(
         "xla_computation out_parts", out_tree(), out_parts))
     c = xc.XlaBuilder(f"xla_computation_{fun_name}")
-    xla_consts = map(partial(xb.constant, c), consts)
+    xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
     should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
     xla_args, donated_invars = xla._xla_callable_args(
         c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)

@@ -344,12 +344,13 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
       name_stack=extend_name_stack(ctx.name_stack, 'cond'))
   pred, = xla.jaxpr_subcomp(
       cond_ctx, cond_jaxpr.jaxpr,
-      _map(partial(xb.constant, cond_c), cond_jaxpr.consts), *(x + z))
+      _map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
+      *(x + z))
   if batched:
     scalar = ShapedArray((), np.bool_)
     or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
-    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
-                       list(range(cond_jaxpr.out_avals[0].ndim)))
+    pred = xops.Reduce(cond_c, [pred], [xops.Constant(cond_c, np.array(False))],
+                       or_, list(range(cond_jaxpr.out_avals[0].ndim)))

   body_c = xla_client.XlaBuilder("body_computation")
   body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))

@@ -359,15 +360,17 @@
       name_stack=extend_name_stack(ctx.name_stack, 'body'))
   new_z = xla.jaxpr_subcomp(
       body_ctx, body_jaxpr.jaxpr,
-      _map(partial(xb.constant, body_c), body_jaxpr.consts),
+      _map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
       *(y + z))
   if batched:
     body_pred_ctx = body_ctx.replace(
         name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
     body_pred, = xla.jaxpr_subcomp(
         body_pred_ctx, cond_jaxpr.jaxpr,
-        _map(partial(xb.constant, body_c), cond_jaxpr.consts), *(x + z))
-    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
+        _map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
+        *(x + z))
+    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z,
+                 body_jaxpr.out_avals)
     assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z)  # no broadcast
   new_carry = xops.Tuple(body_c, [*x, *y, *new_z])

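Context for the two hunks above: when the `while_loop` is batched, the lowering OR-reduces the per-lane predicate so the loop keeps running while any lane is live, and `_pred_bcast_select` freezes the carries of lanes that have finished. The same semantics restated with `jax.numpy`, purely for intuition (the helper names here are made up):

```python
import jax.numpy as jnp

def any_lane_live(pred):
  # The lowering builds this OR-reduce with lax.or_p over the batch dims.
  return jnp.any(pred)

def select_next_carry(pred, new_carry, old_carry):
  # pred has the batch shape; expand it across the carry's trailing axes
  # so lanes whose predicate is already False keep their old value.
  keep = pred.reshape(pred.shape + (1,) * (new_carry.ndim - pred.ndim))
  return jnp.where(keep, new_carry, old_carry)
```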
@@ -806,8 +809,9 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
     ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
     subctx = ctx.replace(
         builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))
-    outs = xla.jaxpr_subcomp(subctx, jaxpr.jaxpr,
-                             _map(partial(xb.constant, c), jaxpr.consts), *ops)
+    outs = xla.jaxpr_subcomp(
+        subctx, jaxpr.jaxpr,
+        _map(partial(xla.pyval_to_ir_constant, c), jaxpr.consts), *ops)
     return c.build(xops.Tuple(c, outs))

   c = ctx.builder

@@ -2460,10 +2460,12 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
   x_aval, = avals_in
   dtype = x_aval.dtype
   if dtypes.issubdtype(dtype, np.unsignedinteger):
-    zero = xb.constant(c, np.array(0, dtype=dtype))
-    return [xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, x_aval.shape),
-                        xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)),
-                                       x_aval.shape))]
+    zero = xops.Constant(c, np.array(0, dtype=dtype))
+    return [xops.Select(
+        xops.Eq(x, zero),
+        xops.Broadcast(zero, x_aval.shape),
+        xops.Broadcast(xops.Constant(c, np.array(1, dtype=dtype)),
+                       x_aval.shape))]
   return [xops.Sign(x)]

sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)

@@ -2764,7 +2766,8 @@ else:
     return [xops.Mul(
         xops.Sign(x),
         xops.Pow(xops.Abs(x),
-                 xb.constant(ctx.builder, np.array(1/3, dtype=x_aval.dtype))))]
+                 xla.pyval_to_ir_constant(ctx.builder,
+                                          np.array(1/3, dtype=x_aval.dtype))))]

cbrt_p = standard_unop(_float, 'cbrt',
                       translation_rule=_cbrt_translation_rule)

@@ -2794,7 +2797,7 @@ def _integer_pow_translation_rule(ctx, avals_in, avals_out, x, *, y):
   # This should be kept in sync with the jax2tf translation rule.
   x_aval, = avals_in
   if y == 0:
-    one = xb.constant(ctx.builder, np.array(1, dtype=x_aval.dtype))
+    one = xla.pyval_to_ir_constant(ctx.builder, np.array(1, dtype=x_aval.dtype))
     return [xops.Broadcast(one, x_aval.shape)]
   is_reciprocal = y < 0
   if is_reciprocal:

@@ -4752,8 +4755,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,

   upper_bound = operand_dims[intarray(dnums.start_index_map)]
   upper_bound -= intarray(slice_sizes)[intarray(dnums.start_index_map)]
-  mask = xops.And(xops.Ge(indices, xb.constant(c, intarray(0))),
-                  xops.Le(indices, xb.constant(c, upper_bound),
+  mask = xops.And(xops.Ge(indices, xla.pyval_to_ir_constant(c, intarray(0))),
+                  xops.Le(indices, xla.pyval_to_ir_constant(c, upper_bound),
                           broadcast_dimensions=[num_batch_dims]))

   # Compute the conjunction of the mask elements across the dimensions in which

@@ -4762,8 +4765,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
   scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ())
   xops.And(xb.parameter(and_builder, 0, scalar_pred),
            xb.parameter(and_builder, 1, scalar_pred))
-  mask = xops.Reduce(c, [mask], [xb.constant(c, True)], and_builder.build(),
-                     [num_batch_dims])
+  mask = xops.Reduce(c, [mask], [xla.pyval_to_ir_constant(c, True)],
+                     and_builder.build(), [num_batch_dims])

   # Computes the output shape and the positions of the batch dimensions in the
   # output

@@ -4778,7 +4781,7 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
       xops.Gather(operand, indices, dimensions, slice_sizes,
                   indices_are_sorted=indices_are_sorted),
       xops.Broadcast(
-          xb.constant(c, np.array(fill_value, operand_aval.dtype)),
+          xla.pyval_to_ir_constant(c, np.array(fill_value, operand_aval.dtype)),
           aval_out.shape))]

def _gather_jvp_rule(g, operand, indices, *, dimension_numbers,

@@ -5023,8 +5026,9 @@ def _clamp_scatter_indices(c, indices, operand_shape, updates_shape, dnums):
   upper_bound -= intarray(slice_sizes)[intarray(dnums.scatter_dims_to_operand_dims)]
   upper_bound = np.minimum(upper_bound, np.iinfo(indices_dtype).max)
   return xops.Min(
-      xops.Max(xb.constant(c, np.array(0, dtype=indices_dtype)), indices),
-      xb.constant(c, upper_bound.astype(indices_dtype)),
+      xops.Max(xla.pyval_to_ir_constant(c, np.array(0, dtype=indices_dtype)),
+               indices),
+      xla.pyval_to_ir_constant(c, upper_bound.astype(indices_dtype)),
       broadcast_dimensions=[len(indices_shape.dimensions()) - 1])

def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,

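The clamp in `_clamp_scatter_indices` is the standard `min(max(0, idx), upper)` with `upper = operand_dim - slice_size`, additionally capped at the index dtype's maximum so the cast cannot overflow. A NumPy restatement under assumed names (`dims_map` standing in for `scatter_dims_to_operand_dims`):

```python
import numpy as np

def clamp_indices(indices, operand_dims, slice_sizes, dims_map,
                  index_dtype=np.int32):
  dims_map = np.asarray(dims_map)
  upper = np.asarray(operand_dims)[dims_map] - np.asarray(slice_sizes)[dims_map]
  upper = np.minimum(upper, np.iinfo(index_dtype).max).astype(index_dtype)
  # upper broadcasts along the trailing index-vector axis, mirroring
  # broadcast_dimensions=[last] in the XLA ops above.
  return np.minimum(np.maximum(indices, 0), upper)
```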
@@ -5037,7 +5041,7 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
   indices = _clamp_scatter_indices(c, indices, operand_aval.shape,
                                    updates_aval.shape, dimension_numbers)

-  init_value = xb.constant(c, np.array(0, operand_aval.dtype))
+  init_value = xla.pyval_to_ir_constant(c, np.array(0, operand_aval.dtype))
   update_computation = _reduction_computation(
       c, update_jaxpr, update_consts, init_value)
   return [xops.Scatter(

@@ -5690,7 +5694,7 @@ def _reduce_sum_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
   scalar = ShapedArray((), operand_aval.dtype)
   return [xops.Reduce(
       ctx.builder, [operand],
-      [xb.constant(ctx.builder, np.array(0, operand_aval.dtype))],
+      [xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype))],
       xla.primitive_subcomputation(add_p, scalar, scalar), axes)]

def _reduce_sum_transpose_rule(cotangent, operand, *, axes):

@@ -5723,7 +5727,7 @@ def _reduce_prod_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
   scalar = ShapedArray((), operand_aval.dtype)
   return [xops.Reduce(
       ctx.builder, [operand],
-      [xb.constant(ctx.builder, np.array(1, operand_aval.dtype))],
+      [xla.pyval_to_ir_constant(ctx.builder, np.array(1, operand_aval.dtype))],
       xla.primitive_subcomputation(mul_p, scalar, scalar), axes)]

def _reduce_prod_jvp_rule(primals, tangents, *, axes):

@@ -5749,7 +5753,8 @@ def _reduce_chooser_translation_rule(prim, identity, ctx, avals_in, avals_out,
   operand_aval, = avals_in
   scalar = ShapedArray((), operand_aval.dtype)
   return [xops.Reduce(ctx.builder, [operand],
-                      [xb.constant(ctx.builder, identity(operand_aval.dtype))],
+                      [xla.pyval_to_ir_constant(ctx.builder,
+                                                identity(operand_aval.dtype))],
                       xla.primitive_subcomputation(prim, scalar, scalar), axes)]

def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):

@@ -5872,7 +5877,7 @@ def _reduce_logical_translation_rule(prim, identity, ctx, avals_in, avals_out,
                                     operand, *, axes):
   scalar = ShapedArray((), np.bool_)
   return [xops.Reduce(ctx.builder, [operand],
-                      [xb.constant(ctx.builder, identity(np.bool_))],
+                      [xla.pyval_to_ir_constant(ctx.builder, identity(np.bool_))],
                       xla.primitive_subcomputation(prim, scalar, scalar), axes)]

_reduce_or_translation_rule = partial(_reduce_logical_translation_rule,

@@ -5958,7 +5963,8 @@ def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
   operand_aval, = avals_in
   scalar = ShapedArray((), operand_aval.dtype)
   return [xops.ReduceWindowWithGeneralPadding(
-      operand, xb.constant(ctx.builder, np.array(0, operand_aval.dtype)),
+      operand,
+      xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype)),
       xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
       window_strides, base_dilation, window_dilation, padding)]

@@ -6012,7 +6018,8 @@ def _reduce_window_chooser_translation_rule(
   operand_aval, = avals_in
   scalar = ShapedArray((), operand_aval.dtype)
   return [xops.ReduceWindowWithGeneralPadding(
-      operand, xb.constant(ctx.builder, identity(operand_aval.dtype)),
+      operand,
+      xla.pyval_to_ir_constant(ctx.builder, identity(operand_aval.dtype)),
       xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
       window_strides, base_dilation, window_dilation, padding)]

@@ -6162,7 +6169,7 @@ def _select_and_scatter_add_translation(
   select = xla.primitive_subcomputation(select_prim, scalar, scalar)
   scatter = xla.primitive_subcomputation(or_p if dtype == np.bool_ else add_p,
                                          scalar, scalar)
-  zero = xb.constant(c, np.array(0, dtype))
+  zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
   # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
   expand_padding = (expand_padding and
                     not all(lo == 0 and hi == 0 for (lo, hi) in padding))

@@ -6171,7 +6178,7 @@ def _select_and_scatter_add_translation(
     identity = (_get_max_identity if select_prim is ge_p
                 else _get_min_identity)
     pads = [(lo, hi, 0) for (lo, hi) in padding]
-    operand = xops.Pad(operand, xb.constant(c, identity(dtype)),
+    operand = xops.Pad(operand, xla.pyval_to_ir_constant(c, identity(dtype)),
                        xc.make_padding_config(pads))
     padding = [(0, 0) for _ in padding]
   output = xops.SelectAndScatterWithGeneralPadding(

@@ -6287,8 +6294,7 @@ def _select_and_gather_add_translation(
   assert nbits <= max_bits
   double_word_reduction = nbits * 2 <= max_bits

-  const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
-                                          canonicalize_types=False)
+  const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))

   if double_word_reduction:
     # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so

@@ -6381,8 +6387,7 @@ def _select_and_gather_add_translation_using_variadic_reducewindow(
   tangents_aval, operand_aval = avals_in
   dtype = operand_aval.dtype

-  const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
-                                          canonicalize_types=False)
+  const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))

   def reducer():
     c = xc.XlaBuilder("select_and_gather_pair_reducer")

@@ -6870,26 +6875,28 @@ def _rng_bit_generator_translation_rule(

def _convert_4xU32_to_2xU64_without_bitcast(c, key):
  u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
-  new_key = xb.constant(c, np.zeros(2, dtype=np.dtype('uint64')),
-                        canonicalize_types=False)
-  _32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
+  new_key = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64')))
+  _32 = xops.Constant(c, np.array(32, np.uint64))
  for i in [0, 2]:
    hi = xops.ConvertElementType(xops.Slice(key, [i]  , [i+1], [1]), u64_etype)
    lo = xops.ConvertElementType(xops.Slice(key, [i+1], [i+2], [1]), u64_etype)
    elt = xops.Xor(xops.ShiftLeft(hi, _32), lo)
-    new_key = xops.DynamicUpdateSlice(new_key, elt, [xb.constant(c, i // 2)])
+    new_key = xops.DynamicUpdateSlice(new_key, elt,
+                                      [xla.pyval_to_ir_constant(c, i // 2)])
  return new_key

def _convert_2xU64_to_4xU32_without_bitcast(c, key):
  u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
-  new_key = xb.constant(c, np.zeros(4, dtype=np.dtype('uint32')))
-  _32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
+  new_key = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32')))
+  _32 = xops.Constant(c, np.array(32, np.uint64))
  for i in [0, 1]:
    elt = xops.Slice(key, [i], [i+1], [1])
    hi = xops.ConvertElementType(xops.ShiftRightLogical(elt, _32), u32_etype)
    lo = xops.ConvertElementType(elt, u32_etype)
-    new_key = xops.DynamicUpdateSlice(new_key, hi, [xb.constant(c, 2 * i)])
-    new_key = xops.DynamicUpdateSlice(new_key, lo, [xb.constant(c, 2 * i + 1)])
+    new_key = xops.DynamicUpdateSlice(new_key, hi,
+                                      [xla.pyval_to_ir_constant(c, 2 * i)])
+    new_key = xops.DynamicUpdateSlice(new_key, lo,
+                                      [xla.pyval_to_ir_constant(c, 2 * i + 1)])
  return new_key

def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):

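The two conversion helpers above avoid a bitcast by doing the packing arithmetically: each u64 lane is `(hi << 32) ^ lo` over adjacent u32 lanes, and unpacking is a logical right shift plus a truncating cast. The same arithmetic in plain NumPy, for intuition only:

```python
import numpy as np

def pack_4xu32_to_2xu64(key4):
  key4 = np.asarray(key4, dtype=np.uint32)
  hi = key4[0::2].astype(np.uint64)
  lo = key4[1::2].astype(np.uint64)
  return (hi << np.uint64(32)) ^ lo  # matches ShiftLeft + Xor above

def unpack_2xu64_to_4xu32(key2):
  key2 = np.asarray(key2, dtype=np.uint64)
  hi = (key2 >> np.uint64(32)).astype(np.uint32)
  lo = key2.astype(np.uint32)  # truncating cast keeps the low 32 bits
  out = np.empty(4, dtype=np.uint32)
  out[0::2], out[1::2] = hi, lo
  return out
```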
@@ -41,7 +41,6 @@ from jax._src.lib import cusparse
from jax._src.lib import rocsolver

from jax._src.lib import xla_client
-from jax._src.lib import xla_bridge as xb
from jax._src.lib import version as jaxlib_version

xops = xla_client.ops

@@ -346,9 +345,9 @@ def _nan_like(c, operand):
   shape = c.get_shape(operand)
   dtype = shape.element_type()
   if jnp.issubdtype(dtype, np.complexfloating):
-    nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
+    nan = xops.Constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
   else:
-    nan = xb.constant(c, np.array(np.nan, dtype=dtype))
+    nan = xops.Constant(c, np.array(np.nan, dtype=dtype))
   return xops.Broadcast(nan, shape.dimensions())

def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):

@@ -696,7 +695,7 @@ def _triangular_solve_cpu_translation_rule(
     conjugate_a = False
   if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types:
     return lapack.jax_trsm(
-        c, xb.constant(c, np.array(1, dtype=dtype)),
+        c, xops.Constant(c, np.array(1, dtype=dtype)),
         a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
   else:
     # Fall back to the HLO implementation for unsupported types or batching.

@@ -32,7 +32,6 @@ from jax.interpreters import pxla
from jax.interpreters import batching
from jax._src import dtypes
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_bridge as xb
from jax._src.numpy import lax_numpy
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis

@@ -1321,9 +1320,10 @@ def _build_axis_index_lowering(c, axis_name, axis_env):
   axis_name, = axis_name
   axis_pos = list(axis_env.names).index(axis_name)
   nreplicas = axis_env.nreps // prod(axis_env.sizes)
-  div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
-                                dtype=np.uint32))
-  mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
+  div = xops.Constant(c,
+                      np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
+                               dtype=np.uint32))
+  mod = xops.Constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
   unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
   return xops.ConvertElementType(
       unsigned_index, xla.dtype_to_primitive_type(np.dtype(np.int32)))

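`_build_axis_index_lowering` recovers a logical mesh-axis index from the flat XLA replica id: divide by the number of replicas spanned by all inner axes (including any per-replica duplication factor), then take the remainder modulo this axis's size. As plain integer arithmetic (an illustrative restatement, not code from the diff):

```python
import numpy as np

def axis_index_from_replica_id(replica_id, sizes, axis_pos, nreplicas):
  # sizes: mesh axis sizes, outermost first; nreplicas: duplication factor.
  div = nreplicas * int(np.prod(sizes[axis_pos + 1:], dtype=np.int64))
  mod = sizes[axis_pos]
  return (replica_id // div) % mod
```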
@@ -22,7 +22,7 @@ XLA. There are also a handful of related casting utilities.

from functools import partial, lru_cache
import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import warnings

from absl import logging

@@ -34,7 +34,6 @@ from jax._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
-from jax._src import dtypes
import numpy as np
import threading

@@ -427,20 +426,6 @@ def host_ids(backend=None):

### utility functions

-# TODO(mattjj,frostig): try to remove this function
-def normalize_to_xla_dtypes(val):
-  """Normalize dtypes in a value."""
-  if hasattr(val, '__array__') or np.isscalar(val):
-    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
-  elif isinstance(val, (tuple, list)):
-    return tuple(normalize_to_xla_dtypes(x) for x in val)
-  raise TypeError('Can\'t convert to XLA: {}'.format(val))
-
-def _numpy_array_constant(builder, value, canonicalize_types=True):
-  if canonicalize_types:
-    value = normalize_to_xla_dtypes(value)
-  return [xops.ConstantLiteral(builder, value)]
-
def parameter(builder, num, shape, name=None, replicated=None):
  if name is None:
    name = ''

@@ -453,36 +438,6 @@ def parameter(builder, num, shape, name=None, replicated=None):
      shape.with_major_to_minor_layout_if_absent(), name,
      replicated)

-
-def constant_general(builder, py_val, canonicalize_types=True):
-  """Translate a general constant `py_val` to a constant, canonicalizing its dtype.
-
-  Args:
-    py_val: a Python value to be translated to a constant.
-
-  Returns:
-    A representation of the constant as a list of xla ops.
-  """
-  for t in type(py_val).mro():
-    handler = _constant_handlers.get(t)
-    if handler: return handler(builder, py_val, canonicalize_types)
-  if hasattr(py_val, '__jax_array__'):
-    return constant(builder, py_val.__jax_array__(), canonicalize_types)
-  raise TypeError("No constant handler for type: {}".format(type(py_val)))
-
-def constant(builder, py_val, canonicalize_types=True):
-  """Translate constant `py_val` to a constant, canonicalizing its dtype.
-
-  Args:
-    py_val: a Python value to be translated to a constant.
-
-  Returns:
-    A representation of the constant, either a ComputationDataHandle or None
-  """
-  const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
-  assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
-  return const[0]
-
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per

@@ -545,65 +500,3 @@ def set_sharding(builder, op, sharding: SpatialSharding):
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation."""
  return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
-
-
-def register_constant_handler(type_, handler_fun):
-  _constant_handlers[type_] = handler_fun
-_constant_handlers: Dict[type, Callable] = {}
-
-
-def _ndarray_constant_handler(c, val, canonicalize_types=True):
-  """Constant handler for ndarray literals, handling zero-size strides.
-
-  This function essentially calls _numpy_array_constant(val) except it has
-  special handling of arrays with any strides of size zero: for those, it
-  generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
-  to avoid staging in large literals that might arise from np.zeros or np.ones
-  or the output of lax.broadcast (which uses np.broadcast_to which in turn
-  uses size-zero strides).
-
-  Args:
-    c: an XlaBuilder
-    val: an ndarray.
-
-  Returns:
-    An XLA ComputationDataHandle / XlaOp representing the constant ndarray
-    staged into the XLA Computation.
-  """
-  # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
-  if dtypes.result_type(val) == dtypes.float0:
-    return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
-  elif np.any(np.equal(0, val.strides)) and val.size > 0:
-    zero_stride_axes, = np.where(np.equal(0, val.strides))
-    other_axes, = np.where(np.not_equal(0, val.strides))
-    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
-                              for ax in range(val.ndim))]
-    xla_val = xops.Broadcast(
-        _numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
-        np.take(val.shape, zero_stride_axes))
-    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
-    return [xops.Transpose(xla_val, permutation)]
-  else:
-    return _numpy_array_constant(c, val, canonicalize_types)
-register_constant_handler(np.ndarray, _ndarray_constant_handler)
-
-
-def _scalar_constant_handler(c, val, canonicalize_types=True):
-  return _numpy_array_constant(c, val, canonicalize_types)
-
-for scalar_type in [np.int8, np.int16, np.int32, np.int64,
-                    np.uint8, np.uint16, np.uint32, np.uint64,
-                    np.float16, np.float32, np.float64,
-                    np.bool_, np.longlong,
-                    xla_client.bfloat16]:
-  register_constant_handler(scalar_type, _scalar_constant_handler)
-
-# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
-if hasattr(np, "float128"):
-  register_constant_handler(np.float128, _scalar_constant_handler)
-
-def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
-  return _numpy_array_constant(c, dtype.type(val))
-
-for ptype, dtype in dtypes.python_scalar_dtypes.items():
-  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

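The ndarray handler deleted here (and re-added in the jax.interpreters.xla module later in this diff) keys on a NumPy detail worth spelling out: broadcasting produces views whose expanded axes have stride 0, so staging only the collapsed data and rebuilding with Broadcast/Transpose avoids materializing a large literal. A quick check of the property it relies on:

```python
import numpy as np

x = np.broadcast_to(np.arange(3.0), (1000, 3))  # a view; nothing is copied
assert x.strides == (0, 8)      # the broadcast axis has stride 0 (float64)
collapsed = x[0, :]             # 3 elements stand in for all 3000
assert np.shares_memory(x, collapsed)
```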
@@ -28,7 +28,6 @@ from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src.api import jit, vmap
-from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
import jax._src.pretty_printer as pp

@@ -358,7 +357,7 @@ def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
   rank = len(shape)
   if 0 in shape:
     zeros = xla_client.ops.Broadcast(
-        xla_bridge.constant(c, np.array(0, np.uint32)), shape)
+        xla_client.ops.Constant(c, np.array(0, np.uint32)), shape)
     return xla_client.ops.Tuple(c, [zeros, zeros])
   def _broadcast(x):
     ndims = c.get_shape(x).rank()

@@ -643,7 +643,8 @@ def _make_params(c, dim_in_avals, in_avals):
def _xla_consts(c, consts):
  unique_consts = {id(const): const for const in consts}
  xla_consts = {
-      id_: [xb.constant(c, const)] for id_, const in unique_consts.items()}
+      id_: [xla.pyval_to_ir_constant(c, const)]
+      for id_, const in unique_consts.items()}
  return [xla_consts[id(const)] for const in consts]

def djaxpr_subcomp(c, jaxpr, dim_args, args):

@@ -654,7 +655,7 @@ def djaxpr_subcomp(c, jaxpr, dim_args, args):

  def read(v):
    if type(v) is core.Literal:
-      return [xb.constant(c, xla.canonicalize_dtype(v.val))]
+      return [xla.pyval_to_ir_constant(c, xla.canonicalize_dtype(v.val))]
    else:
      return env[v]

@@ -1360,20 +1360,21 @@ def _xmap_translation_rule_replica(c, axis_env,
  return xops.Tuple(c, outs)

def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
-  zero = xb.constant(c, np.zeros((), dtype=np.int32))
+  zero = xops.Constant(c, np.zeros((), dtype=np.int32))
  linear_idxs = [zero] * len(tile_shape)
  strides = [1] * len(tile_shape)
  for name, axis in reversed(axes.items()):
    axis_index = _build_axis_index_lowering(
        c, axis_name=name, axis_env=axis_env)
-    stride_c = xb.constant(c, np.array(strides[axis], np.int32))
+    stride_c = xops.Constant(c, np.array(strides[axis], np.int32))
    if linear_idxs[axis] is zero and strides[axis] == 1:
      linear_idxs[axis] = axis_index
    else:
      linear_idxs[axis] = xops.Add(linear_idxs[axis], xops.Mul(axis_index, stride_c))
    strides[axis] *= axis_sizes[name]
  return [zero if linear_idx is zero else
-          xops.Mul(linear_idx, xb.constant(c, np.array(tile_dim_size, np.int32)))
+          xops.Mul(linear_idx,
+                   xops.Constant(c, np.array(tile_dim_size, np.int32)))
          for linear_idx, tile_dim_size in zip(linear_idxs, tile_shape)]

def _xla_tile(c, axis_env, x, in_axes, axis_sizes):

@@ -1405,7 +1406,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
      shape[axis] *= axis_sizes[name]
  base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, out_axes, axis_sizes)

-  padded = xops.Broadcast(xb.constant(c, np.array(0, x_dtype)), shape)
+  padded = xops.Broadcast(xops.Constant(c, np.array(0, x_dtype)), shape)
  padded = xops.DynamicUpdateSlice(padded, x, base_idxs)
  replica_groups_protos = xc.make_replica_groups(
      xla.axis_groups(axis_env, tuple(out_axes.keys())))

@@ -1413,7 +1414,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):

  # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
  if convert_bool:
-    nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
+    nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32)))
    out = xops.ConvertElementType(
        nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
  return out

@@ -686,12 +686,13 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):


def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
-  return xb.constant_general(c, np.asarray(val), canonicalize_types=canonicalize_types)
+  return xla.pyval_to_ir_constants(c, np.asarray(val),
+                                   canonicalize_types=canonicalize_types)


def _register_handlers_for_sharded_device_array(sda):
  shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
-  xb.register_constant_handler(sda, _sharded_device_array_constant_handler)
+  xla.register_constant_handler(sda, _sharded_device_array_constant_handler)

  core.pytype_aval_mappings[sda] = ConcreteArray
  xla.device_put_handlers[sda] = xla._device_put_array

@@ -876,7 +877,7 @@ def parallel_callable(fun: lu.WrappedFun,
  tuple_args = len(global_sharded_avals) > 100  # pass long arg lists as tuple for TPU

  c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
-  xla_consts = map(partial(xb.constant, c), consts)
+  xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
  replicated_args = [axis is None for axis in in_axes]
  xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
                                                    replicated=replicated_args,

@@ -1317,7 +1318,7 @@ def _xla_shard(c, aval, axis_env, x, in_axis):
    return x
  elif isinstance(aval, ShapedArray):
    dims = list(c.get_shape(x).dimensions())
-    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
+    zero = xops.Constant(c, np.zeros((), dtype=np.uint32))
    idxs = [zero] * (len(dims) - 1)
    idxs.insert(in_axis, _unravel_index(c, axis_env))
    dims_unsqueezed = dims.copy()

@@ -1344,9 +1345,10 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):

    xla_shape = c.get_shape(x)
    dims = list(xla_shape.dimensions())
-    padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
-                            [axis_env.sizes[-1]] + dims)
-    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
+    padded = xops.Broadcast(
+        xops.Constant(c, np.array(0, xla_shape.numpy_dtype())),
+        [axis_env.sizes[-1]] + dims)
+    zero = xops.Constant(c, np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
    padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
    replica_groups_protos = xc.make_replica_groups(

@@ -1360,7 +1362,7 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):

    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    if convert_bool:
-      nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
+      nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32)))
      out = xops.ConvertElementType(
          nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
    return out

@@ -1368,8 +1370,9 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
    raise TypeError((aval, c.get_shape(x)))

def _unravel_index(c, axis_env):
-  div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
-  mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
+  div = xops.Constant(c, np.array(axis_env.nreps // prod(axis_env.sizes),
+                                  np.uint32))
+  mod = xops.Constant(c, np.array(axis_env.sizes[-1], np.uint32))
  return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)

# ------------------- xmap -------------------

@@ -1597,7 +1600,7 @@ def lower_mesh_computation(

  # 3. Build up the HLO
  c = xc.XlaBuilder(f"xmap_{fun.__name__}")
-  xla_consts = map(partial(xb.constant, c), consts)
+  xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
  tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
  in_partitions: Optional[List]
  if spmd_lowering:

@@ -139,7 +139,7 @@ def _sharded_callable(
      fun.__name__, nparts, global_abstract_args)

  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
-  xla_consts = _map(partial(xb.constant, c), consts)
+  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  ctx = xla.TranslationContext(

@@ -70,28 +70,13 @@ XlaExecutable = xc.Executable
_on_exit = False


-def compile_or_get_cached(backend, computation, compile_options):
-  # Avoid import cycle between jax and jax.experimental
-  from jax.experimental.compilation_cache import compilation_cache as cc
-  # Persistent compilation cache only implemented on TPU.
-  # TODO(skye): add warning when initializing cache on unsupported default platform
-  if cc.is_initialized() and backend.platform == 'tpu':
-    cached_executable = cc.get_executable(computation, compile_options, backend)
-    if cached_executable is not None:
-      logging.info('Persistent compilation cache hit')
-      return cached_executable
-    else:
-      compiled = backend_compile(backend, computation, compile_options)
-      cc.put_executable(computation, compile_options, compiled, backend)
-      return compiled
-  return backend_compile(backend, computation, compile_options)
-
def identity(x): return x

_scalar_types = dtypes.python_scalar_dtypes.keys()

# unit representation
-def _make_unit_constant(c): return xb.constant_general(c, np.zeros((), dtype=np.dtype('bool')))
+def _make_unit_constant(c): return [
+    xops.Constant(c, np.zeros((), dtype=np.dtype('bool')))]
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _device_put_unit(_, device):
  backend = xb.get_device_backend(device)

@@ -128,6 +113,8 @@ def make_op_metadata(primitive: core.Primitive,

### handlers

+# Numpy dtypes -> XLA primitive types
+
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
  np.dtype('bool'): xc.PrimitiveType.PRED,
  np.dtype('int8'): xc.PrimitiveType.S8,

@@ -156,7 +143,8 @@ def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
  except KeyError as err:
    raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err

-xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))

+# JAX abstract values -> XLA shapes
+
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
  try:

@@ -171,6 +159,123 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
  ConcreteArray: _make_array_shape,
}

+
+
+# IR constants
+
+_constant_handlers: Dict[type, Callable] = {}
+
+def pyval_to_ir_constants(builder, py_val, canonicalize_types=True):
+  """Translate a general constant `py_val` to a constant, canonicalizing its dtype.
+
+  Args:
+    py_val: a Python value to be translated to a constant.
+
+  Returns:
+    A representation of the constant as a list of xla ops.
+  """
+  for t in type(py_val).mro():
+    handler = _constant_handlers.get(t)
+    if handler: return handler(builder, py_val, canonicalize_types)
+  if hasattr(py_val, '__jax_array__'):
+    return pyval_to_ir_constants(builder, py_val.__jax_array__(),
+                                 canonicalize_types)
+  raise TypeError("No constant handler for type: {}".format(type(py_val)))
+
+def pyval_to_ir_constant(builder, py_val, canonicalize_types=True):
+  """Translate constant `py_val` to a constant, canonicalizing its dtype.
+
+  Args:
+    py_val: a Python value to be translated to a constant.
+
+  Returns:
+    A representation of the constant, either a ComputationDataHandle or None
+  """
+  const = pyval_to_ir_constants(builder, py_val, canonicalize_types=canonicalize_types)
+  assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
+  return const[0]
+
+
+def register_constant_handler(type_, handler_fun):
+  _constant_handlers[type_] = handler_fun
+
+register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))
+
+
+# TODO(mattjj,frostig): try to remove this function
+def _normalize_to_xla_dtypes(val):
+  """Normalize dtypes in a value."""
+  if hasattr(val, '__array__') or np.isscalar(val):
+    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
+  elif isinstance(val, (tuple, list)):
+    return tuple(_normalize_to_xla_dtypes(x) for x in val)
+  raise TypeError('Can\'t convert to XLA: {}'.format(val))
+
+def _numpy_array_constant(builder, value, canonicalize_types=True):
+  if canonicalize_types:
+    value = _normalize_to_xla_dtypes(value)
+  return [xops.Constant(builder, value)]
+
+
+def _ndarray_constant_handler(c, val, canonicalize_types=True):
+  """Constant handler for ndarray literals, handling zero-size strides.
+
+  This function essentially calls _numpy_array_constant(val) except it has
+  special handling of arrays with any strides of size zero: for those, it
+  generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
+  to avoid staging in large literals that might arise from np.zeros or np.ones
+  or the output of lax.broadcast (which uses np.broadcast_to which in turn
+  uses size-zero strides).
+
+  Args:
+    c: an XlaBuilder
+    val: an ndarray.
+
+  Returns:
+    An XLA ComputationDataHandle / XlaOp representing the constant ndarray
+    staged into the XLA Computation.
+  """
+  # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
+  if dtypes.result_type(val) == dtypes.float0:
+    return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
+  elif np.any(np.equal(0, val.strides)) and val.size > 0:
+    zero_stride_axes, = np.where(np.equal(0, val.strides))
+    other_axes, = np.where(np.not_equal(0, val.strides))
+    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
+                              for ax in range(val.ndim))]
+    xla_val = xops.Broadcast(
+        _numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
+        np.take(val.shape, zero_stride_axes))
+    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
+    return [xops.Transpose(xla_val, permutation)]
+  else:
+    return _numpy_array_constant(c, val, canonicalize_types)
+register_constant_handler(np.ndarray, _ndarray_constant_handler)
+
+
+def _scalar_constant_handler(c, val, canonicalize_types=True):
+  return _numpy_array_constant(c, val, canonicalize_types)
+
+for scalar_type in [np.int8, np.int16, np.int32, np.int64,
+                    np.uint8, np.uint16, np.uint32, np.uint64,
+                    np.float16, np.float32, np.float64,
+                    np.bool_, np.longlong,
+                    dtypes.bfloat16]:
+  register_constant_handler(scalar_type, _scalar_constant_handler)
+
+# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
+if hasattr(np, "float128"):
+  register_constant_handler(np.float128, _scalar_constant_handler)
+
+def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
+  return _numpy_array_constant(c, dtype.type(val))
+
+for ptype, dtype in dtypes.python_scalar_dtypes.items():
+  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
+
+
+# Result handlers
+
def aval_to_result_handler(device: Optional[Device],
                           aval: core.AbstractValue) -> Callable:
  try:

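The registry that moved here keeps the old extension contract: a handler takes a builder, a value, and a canonicalize flag, and returns a *list* of XlaOps. A sketch of registering a custom container under the new names (`MyArray` and its field are hypothetical, not part of the diff):

```python
import numpy as np
from jax.interpreters import xla

class MyArray:
  """Hypothetical array wrapper, for illustration only."""
  def __init__(self, data):
    self.data = np.asarray(data)

def _my_array_constant_handler(c, val, canonicalize_types=True):
  # Delegate to the generic path so dtype canonicalization stays consistent.
  return xla.pyval_to_ir_constants(c, val.data, canonicalize_types)

xla.register_constant_handler(MyArray, _my_array_constant_handler)
```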
@@ -434,7 +539,7 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
  # assert ctx.platform is not None
  def read(v):
    if type(v) is Literal:
-      return xb.constant_general(ctx.builder, canonicalize_dtype(v.val))
+      return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(v.val))
    else:
      return env[v]

@@ -449,7 +554,8 @@
    env[v] = node

  env: Dict[core.Var, Sequence[XlaOp]] = {}
-  _partitionmap(write, [core.unitvar], _make_unit_constant(ctx.builder))
+  _partitionmap(write, [core.unitvar],
+                pyval_to_ir_constants(ctx.builder, core.unit))
  _partitionmap(write, jaxpr.constvars, consts)
  _partitionmap(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:

@@ -645,7 +751,7 @@ def _flatten_shape(s: XlaShape, index: Tuple[int, ...],
def _xla_consts(c, consts):
  unique_consts = {id(const): const for const in consts}
  xla_consts = {
-      id_: xb.constant_general(c, const) for id_, const in unique_consts.items()}
+      id_: pyval_to_ir_constants(c, const) for id_, const in unique_consts.items()}
  return [c for const in consts for c in xla_consts[id(const)]]

def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,

@@ -740,6 +846,23 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                      out_avals, kept_var_idx)


+def compile_or_get_cached(backend, computation, compile_options):
+  # Avoid import cycle between jax and jax.experimental
+  from jax.experimental.compilation_cache import compilation_cache as cc
+  # Persistent compilation cache only implemented on TPU.
+  # TODO(skye): add warning when initializing cache on unsupported default platform
+  if cc.is_initialized() and backend.platform == 'tpu':
+    cached_executable = cc.get_executable(computation, compile_options, backend)
+    if cached_executable is not None:
+      logging.info('Persistent compilation cache hit')
+      return cached_executable
+    else:
+      compiled = backend_compile(backend, computation, compile_options)
+      cc.put_executable(computation, compile_options, compiled, backend)
+      return compiled
+  return backend_compile(backend, computation, compile_options)
+
+
class XlaComputation:
  name: str
  _is_trivial: bool

@@ -1183,7 +1306,7 @@ register_translation(xla_call_p, _xla_call_translation_rule)
def zeros_like_translation_rule(c, x):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
-  zero = xb.constant(c, np.array(0, shape.element_type()))
+  zero = xops.Constant(c, np.array(0, shape.element_type()))
  return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule

@@ -1539,9 +1662,9 @@ for device_array in [_CppDeviceArray, _DeviceArray]:
  canonicalize_dtype_handlers[device_array] = identity

def _device_array_constant_handler(c, val, canonicalize_types=True):
-  return xb.constant_general(c, val.device_buffer.to_py())
-xb.register_constant_handler(_DeviceArray, _device_array_constant_handler)
-xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)
+  return pyval_to_ir_constants(c, val.device_buffer.to_py())
+register_constant_handler(_DeviceArray, _device_array_constant_handler)
+register_constant_handler(_CppDeviceArray, _device_array_constant_handler)

def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[Device]):
  x = _copy_device_array_to_device(x, device)

@@ -1593,7 +1716,7 @@ masking.defvectorized(device_put_p)
def _zeros(c, xla_shape):
  if xla_shape.is_array():
    shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
-    zero = xb.constant(c, np.array(0, dtype=dtype))
+    zero = xops.Constant(c, np.array(0, dtype=dtype))
    return xops.Broadcast(zero, shape)
  else:
    # It is a token

@@ -1608,10 +1731,10 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
     Conditional."""
  # Fake condition which always selects True branch.
  c = ctx.builder
-  rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
-                        xb.constant(c, np.array(1, dtype=np.float32)),
+  rng = xops.RngUniform(xops.Constant(c, np.array(0, dtype=np.float32)),
+                        xops.Constant(c, np.array(1, dtype=np.float32)),
                        xc.Shape.array_shape(xc.PrimitiveType.F32, []))
-  pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))
+  pred = xops.Lt(rng, xops.Constant(c, np.array(2, dtype=np.float32)))

  true_op = xops.Tuple(c, in_nodes)
  remat_subc = xc.XlaBuilder("remat_call_subcomputation")

@@ -1648,22 +1771,22 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
  dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
  out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]

-  i_init = xb.constant(c, np.array(0, dtype=np.int32))
+  i_init = xops.Constant(c, np.array(0, dtype=np.int32))
  zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
  inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)

  cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
  input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
  i = xops.GetTupleElement(input_op, 0)
-  rng = xops.RngUniform(xb.constant(cond_subc, np.array(1, dtype=np.int32)),
-                        xb.constant(cond_subc, np.array(2, dtype=np.int32)),
+  rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
+                        xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
                        xc.Shape.array_shape(xc.PrimitiveType.S32, []))
  cond_subc = cond_subc.build(xops.Lt(i, rng))

  body_subc = xc.XlaBuilder("remat_body_subcomputation")
  input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
  i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
-  i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
+  i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
  body_ctx = ctx.replace(
      builder=body_subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))

@@ -121,8 +121,8 @@ def sparse_array_device_put_handler(a, device):

def sparse_array_constant_handler(c, val, canonicalize_dtypes):
  return (
-      xb.constant(val.data, canonicalize_dtypes),
-      xb.constant(val.indices, canonicalize_dtypes)
+      xla.pyval_to_ir_constant(val.data, canonicalize_dtypes),
+      xla.pyval_to_ir_constant(val.indices, canonicalize_dtypes)
  )

core.pytype_aval_mappings[SparseArray] = lambda x: x.aval

@@ -132,7 +132,7 @@ xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
-xb.register_constant_handler(SparseArray, sparse_array_constant_handler)
+xla.register_constant_handler(SparseArray, sparse_array_constant_handler)


sp_indices_p = core.Primitive('sp_indices')

@@ -38,8 +38,10 @@ from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
-from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
+from jax._src.lib import xla_bridge
+
+xops = xla_client.ops

import numpy as np

@@ -1761,7 +1763,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
        "Consumer ID cannot be a reserved value: 0"):
      hcb._callback_handler_data.receiver.add_outfeed(
          comp, token, 0,
-          [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
+          [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))])

  def test_tap_error_different_shapes(self):
    """Try to register different shapes for the same consumer ID."""

@@ -1772,17 +1774,17 @@ class HostCallbackTapTest(jtu.JaxTestCase):
    hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
    hcb._callback_handler_data.receiver.add_outfeed(
        comp, token, 123,
-        [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
+        [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))])
    with self.assertRaisesRegex(
        RuntimeError, ".*does not match previous shape element_type.*"):
      hcb._callback_handler_data.receiver.add_outfeed(
          comp, token, 123,
-          [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))])
+          [xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))])
    with self.assertRaisesRegex(
        RuntimeError, ".*does not match previous shape element_type.*"):
      hcb._callback_handler_data.receiver.add_outfeed(
          comp, token, 123,
-          [xla_bridge.constant(comp, np.zeros((2,), dtype=np.float32))])
+          [xops.Constant(comp, np.zeros((2,), dtype=np.float32))])

  def test_tap_id_tap_removed_kwargs(self):
    def func(x, transforms, y):