Move xla_bridge.constant to jax.interpreters.xla.pyval_to_ir_constant.

This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() needed nothing more than uncanonicalized array constants. Change those users to xla_client.ops.Constant instead; the fancier utility is unnecessary in these cases.

PiperOrigin-RevId: 404270649
Peter Hawkins 2021-10-19 08:40:15 -07:00 committed by jax authors
parent e96f363242
commit 1a73743610
17 changed files with 256 additions and 224 deletions
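
A minimal sketch of the before/after spellings, assuming the jaxlib XLA client API at the time of this change; the builder name "demo" is illustrative only, not taken from the commit:

import numpy as np
from jax._src.lib import xla_client
from jax.interpreters import xla

xops = xla_client.ops
c = xla_client.XlaBuilder("demo")

# Old spelling (removed by this change):
#   from jax._src.lib import xla_bridge as xb
#   const = xb.constant(c, np.float32(3.0))

# New spelling when dtype canonicalization is still wanted:
const = xla.pyval_to_ir_constant(c, np.float32(3.0))

# When an uncanonicalized array constant suffices, build it directly:
const_raw = xops.Constant(c, np.asarray(3.0, dtype=np.float32))

pyval_to_ir_constant dispatches through the registered constant handlers and canonicalizes dtypes, whereas xops.Constant embeds the given array as-is.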

View File

@ -2033,7 +2033,7 @@
" env: Dict[Var, xe.XlaOp] = {}\n",
"\n",
" def read(x: Atom) -> xe.XlaOp:\n",
" return env[x] if type(x) is Var else xb.constant(c, x.val, False)\n",
" return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n",
"\n",
" def write(v: Var, val: xe.XlaOp) -> None:\n",
" env[v] = val\n",

View File

@ -1593,7 +1593,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
return env[x] if type(x) is Var else xb.constant(c, x.val, False)
return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))
def write(v: Var, val: xe.XlaOp) -> None:
env[v] = val

View File

@ -1587,7 +1587,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
return env[x] if type(x) is Var else xb.constant(c, x.val, False)
return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))
def write(v: Var, val: xe.XlaOp) -> None:
env[v] = val

View File

@ -818,7 +818,7 @@ def xla_computation(fun: Callable,
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
c = xc.XlaBuilder(f"xla_computation_{fun_name}")
xla_consts = map(partial(xb.constant, c), consts)
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)

View File

@ -344,12 +344,13 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
pred, = xla.jaxpr_subcomp(
cond_ctx, cond_jaxpr.jaxpr,
_map(partial(xb.constant, cond_c), cond_jaxpr.consts), *(x + z))
_map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
*(x + z))
if batched:
scalar = ShapedArray((), np.bool_)
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
list(range(cond_jaxpr.out_avals[0].ndim)))
pred = xops.Reduce(cond_c, [pred], [xops.Constant(cond_c, np.array(False))],
or_, list(range(cond_jaxpr.out_avals[0].ndim)))
body_c = xla_client.XlaBuilder("body_computation")
body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
@ -359,15 +360,17 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
name_stack=extend_name_stack(ctx.name_stack, 'body'))
new_z = xla.jaxpr_subcomp(
body_ctx, body_jaxpr.jaxpr,
_map(partial(xb.constant, body_c), body_jaxpr.consts),
_map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
*(y + z))
if batched:
body_pred_ctx = body_ctx.replace(
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
body_pred, = xla.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
_map(partial(xb.constant, body_c), cond_jaxpr.consts), *(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
_map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
*(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z,
body_jaxpr.out_avals)
assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast
new_carry = xops.Tuple(body_c, [*x, *y, *new_z])
@ -806,8 +809,9 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
subctx = ctx.replace(
builder=c, name_stack=extend_name_stack(name_stack, name + '_fun'))
outs = xla.jaxpr_subcomp(subctx, jaxpr.jaxpr,
_map(partial(xb.constant, c), jaxpr.consts), *ops)
outs = xla.jaxpr_subcomp(
subctx, jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, c), jaxpr.consts), *ops)
return c.build(xops.Tuple(c, outs))
c = ctx.builder

View File

@ -2460,10 +2460,12 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
x_aval, = avals_in
dtype = x_aval.dtype
if dtypes.issubdtype(dtype, np.unsignedinteger):
zero = xb.constant(c, np.array(0, dtype=dtype))
return [xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, x_aval.shape),
xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)),
x_aval.shape))]
zero = xops.Constant(c, np.array(0, dtype=dtype))
return [xops.Select(
xops.Eq(x, zero),
xops.Broadcast(zero, x_aval.shape),
xops.Broadcast(xops.Constant(c, np.array(1, dtype=dtype)),
x_aval.shape))]
return [xops.Sign(x)]
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
@ -2764,7 +2766,8 @@ else:
return [xops.Mul(
xops.Sign(x),
xops.Pow(xops.Abs(x),
xb.constant(ctx.builder, np.array(1/3, dtype=x_aval.dtype))))]
xla.pyval_to_ir_constant(ctx.builder,
np.array(1/3, dtype=x_aval.dtype))))]
cbrt_p = standard_unop(_float, 'cbrt',
translation_rule=_cbrt_translation_rule)
@ -2794,7 +2797,7 @@ def _integer_pow_translation_rule(ctx, avals_in, avals_out, x, *, y):
# This should be kept in sync with the jax2tf translation rule.
x_aval, = avals_in
if y == 0:
one = xb.constant(ctx.builder, np.array(1, dtype=x_aval.dtype))
one = xla.pyval_to_ir_constant(ctx.builder, np.array(1, dtype=x_aval.dtype))
return [xops.Broadcast(one, x_aval.shape)]
is_reciprocal = y < 0
if is_reciprocal:
@ -4752,8 +4755,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
upper_bound = operand_dims[intarray(dnums.start_index_map)]
upper_bound -= intarray(slice_sizes)[intarray(dnums.start_index_map)]
mask = xops.And(xops.Ge(indices, xb.constant(c, intarray(0))),
xops.Le(indices, xb.constant(c, upper_bound),
mask = xops.And(xops.Ge(indices, xla.pyval_to_ir_constant(c, intarray(0))),
xops.Le(indices, xla.pyval_to_ir_constant(c, upper_bound),
broadcast_dimensions=[num_batch_dims]))
# Compute the conjunction of the mask elements across the dimensions in which
@ -4762,8 +4765,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ())
xops.And(xb.parameter(and_builder, 0, scalar_pred),
xb.parameter(and_builder, 1, scalar_pred))
mask = xops.Reduce(c, [mask], [xb.constant(c, True)], and_builder.build(),
[num_batch_dims])
mask = xops.Reduce(c, [mask], [xla.pyval_to_ir_constant(c, True)],
and_builder.build(), [num_batch_dims])
# Computes the output shape and the positions of the batch dimensions in the
# output
@ -4778,7 +4781,7 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
xops.Gather(operand, indices, dimensions, slice_sizes,
indices_are_sorted=indices_are_sorted),
xops.Broadcast(
xb.constant(c, np.array(fill_value, operand_aval.dtype)),
xla.pyval_to_ir_constant(c, np.array(fill_value, operand_aval.dtype)),
aval_out.shape))]
def _gather_jvp_rule(g, operand, indices, *, dimension_numbers,
@ -5023,8 +5026,9 @@ def _clamp_scatter_indices(c, indices, operand_shape, updates_shape, dnums):
upper_bound -= intarray(slice_sizes)[intarray(dnums.scatter_dims_to_operand_dims)]
upper_bound = np.minimum(upper_bound, np.iinfo(indices_dtype).max)
return xops.Min(
xops.Max(xb.constant(c, np.array(0, dtype=indices_dtype)), indices),
xb.constant(c, upper_bound.astype(indices_dtype)),
xops.Max(xla.pyval_to_ir_constant(c, np.array(0, dtype=indices_dtype)),
indices),
xla.pyval_to_ir_constant(c, upper_bound.astype(indices_dtype)),
broadcast_dimensions=[len(indices_shape.dimensions()) - 1])
def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
@ -5037,7 +5041,7 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
indices = _clamp_scatter_indices(c, indices, operand_aval.shape,
updates_aval.shape, dimension_numbers)
init_value = xb.constant(c, np.array(0, operand_aval.dtype))
init_value = xla.pyval_to_ir_constant(c, np.array(0, operand_aval.dtype))
update_computation = _reduction_computation(
c, update_jaxpr, update_consts, init_value)
return [xops.Scatter(
@ -5690,7 +5694,7 @@ def _reduce_sum_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
scalar = ShapedArray((), operand_aval.dtype)
return [xops.Reduce(
ctx.builder, [operand],
[xb.constant(ctx.builder, np.array(0, operand_aval.dtype))],
[xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype))],
xla.primitive_subcomputation(add_p, scalar, scalar), axes)]
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
@ -5723,7 +5727,7 @@ def _reduce_prod_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
scalar = ShapedArray((), operand_aval.dtype)
return [xops.Reduce(
ctx.builder, [operand],
[xb.constant(ctx.builder, np.array(1, operand_aval.dtype))],
[xla.pyval_to_ir_constant(ctx.builder, np.array(1, operand_aval.dtype))],
xla.primitive_subcomputation(mul_p, scalar, scalar), axes)]
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
@ -5749,7 +5753,8 @@ def _reduce_chooser_translation_rule(prim, identity, ctx, avals_in, avals_out,
operand_aval, = avals_in
scalar = ShapedArray((), operand_aval.dtype)
return [xops.Reduce(ctx.builder, [operand],
[xb.constant(ctx.builder, identity(operand_aval.dtype))],
[xla.pyval_to_ir_constant(ctx.builder,
identity(operand_aval.dtype))],
xla.primitive_subcomputation(prim, scalar, scalar), axes)]
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
@ -5872,7 +5877,7 @@ def _reduce_logical_translation_rule(prim, identity, ctx, avals_in, avals_out,
operand, *, axes):
scalar = ShapedArray((), np.bool_)
return [xops.Reduce(ctx.builder, [operand],
[xb.constant(ctx.builder, identity(np.bool_))],
[xla.pyval_to_ir_constant(ctx.builder, identity(np.bool_))],
xla.primitive_subcomputation(prim, scalar, scalar), axes)]
_reduce_or_translation_rule = partial(_reduce_logical_translation_rule,
@ -5958,7 +5963,8 @@ def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
operand_aval, = avals_in
scalar = ShapedArray((), operand_aval.dtype)
return [xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(ctx.builder, np.array(0, operand_aval.dtype)),
operand,
xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype)),
xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
window_strides, base_dilation, window_dilation, padding)]
@ -6012,7 +6018,8 @@ def _reduce_window_chooser_translation_rule(
operand_aval, = avals_in
scalar = ShapedArray((), operand_aval.dtype)
return [xops.ReduceWindowWithGeneralPadding(
operand, xb.constant(ctx.builder, identity(operand_aval.dtype)),
operand,
xla.pyval_to_ir_constant(ctx.builder, identity(operand_aval.dtype)),
xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
window_strides, base_dilation, window_dilation, padding)]
@ -6162,7 +6169,7 @@ def _select_and_scatter_add_translation(
select = xla.primitive_subcomputation(select_prim, scalar, scalar)
scatter = xla.primitive_subcomputation(or_p if dtype == np.bool_ else add_p,
scalar, scalar)
zero = xb.constant(c, np.array(0, dtype))
zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
# TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
expand_padding = (expand_padding and
not all(lo == 0 and hi == 0 for (lo, hi) in padding))
@ -6171,7 +6178,7 @@ def _select_and_scatter_add_translation(
identity = (_get_max_identity if select_prim is ge_p
else _get_min_identity)
pads = [(lo, hi, 0) for (lo, hi) in padding]
operand = xops.Pad(operand, xb.constant(c, identity(dtype)),
operand = xops.Pad(operand, xla.pyval_to_ir_constant(c, identity(dtype)),
xc.make_padding_config(pads))
padding = [(0, 0) for _ in padding]
output = xops.SelectAndScatterWithGeneralPadding(
@ -6287,8 +6294,7 @@ def _select_and_gather_add_translation(
assert nbits <= max_bits
double_word_reduction = nbits * 2 <= max_bits
const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
canonicalize_types=False)
const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))
if double_word_reduction:
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
@ -6381,8 +6387,7 @@ def _select_and_gather_add_translation_using_variadic_reducewindow(
tangents_aval, operand_aval = avals_in
dtype = operand_aval.dtype
const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
canonicalize_types=False)
const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype))
def reducer():
c = xc.XlaBuilder("select_and_gather_pair_reducer")
@ -6870,26 +6875,28 @@ def _rng_bit_generator_translation_rule(
def _convert_4xU32_to_2xU64_without_bitcast(c, key):
u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
new_key = xb.constant(c, np.zeros(2, dtype=np.dtype('uint64')),
canonicalize_types=False)
_32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
new_key = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64')))
_32 = xops.Constant(c, np.array(32, np.uint64))
for i in [0, 2]:
hi = xops.ConvertElementType(xops.Slice(key, [i] , [i+1], [1]), u64_etype)
lo = xops.ConvertElementType(xops.Slice(key, [i+1], [i+2], [1]), u64_etype)
elt = xops.Xor(xops.ShiftLeft(hi, _32), lo)
new_key = xops.DynamicUpdateSlice(new_key, elt, [xb.constant(c, i // 2)])
new_key = xops.DynamicUpdateSlice(new_key, elt,
[xla.pyval_to_ir_constant(c, i // 2)])
return new_key
def _convert_2xU64_to_4xU32_without_bitcast(c, key):
u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
new_key = xb.constant(c, np.zeros(4, dtype=np.dtype('uint32')))
_32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
new_key = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32')))
_32 = xops.Constant(c, np.array(32, np.uint64))
for i in [0, 1]:
elt = xops.Slice(key, [i], [i+1], [1])
hi = xops.ConvertElementType(xops.ShiftRightLogical(elt, _32), u32_etype)
lo = xops.ConvertElementType(elt, u32_etype)
new_key = xops.DynamicUpdateSlice(new_key, hi, [xb.constant(c, 2 * i)])
new_key = xops.DynamicUpdateSlice(new_key, lo, [xb.constant(c, 2 * i + 1)])
new_key = xops.DynamicUpdateSlice(new_key, hi,
[xla.pyval_to_ir_constant(c, 2 * i)])
new_key = xops.DynamicUpdateSlice(new_key, lo,
[xla.pyval_to_ir_constant(c, 2 * i + 1)])
return new_key
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):

View File

@ -41,7 +41,6 @@ from jax._src.lib import cusparse
from jax._src.lib import rocsolver
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge as xb
from jax._src.lib import version as jaxlib_version
xops = xla_client.ops
@ -346,9 +345,9 @@ def _nan_like(c, operand):
shape = c.get_shape(operand)
dtype = shape.element_type()
if jnp.issubdtype(dtype, np.complexfloating):
nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
nan = xops.Constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
else:
nan = xb.constant(c, np.array(np.nan, dtype=dtype))
nan = xops.Constant(c, np.array(np.nan, dtype=dtype))
return xops.Broadcast(nan, shape.dimensions())
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
@ -696,7 +695,7 @@ def _triangular_solve_cpu_translation_rule(
conjugate_a = False
if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types:
return lapack.jax_trsm(
c, xb.constant(c, np.array(1, dtype=dtype)),
c, xops.Constant(c, np.array(1, dtype=dtype)),
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
else:
# Fall back to the HLO implementation for unsupported types or batching.

View File

@ -32,7 +32,6 @@ from jax.interpreters import pxla
from jax.interpreters import batching
from jax._src import dtypes
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.numpy import lax_numpy
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
@ -1321,9 +1320,10 @@ def _build_axis_index_lowering(c, axis_name, axis_env):
axis_name, = axis_name
axis_pos = list(axis_env.names).index(axis_name)
nreplicas = axis_env.nreps // prod(axis_env.sizes)
div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
dtype=np.uint32))
mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
div = xops.Constant(c,
np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
dtype=np.uint32))
mod = xops.Constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(
unsigned_index, xla.dtype_to_primitive_type(np.dtype(np.int32)))

View File

@ -22,7 +22,7 @@ XLA. There are also a handful of related casting utilities.
from functools import partial, lru_cache
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import warnings
from absl import logging
@ -34,7 +34,6 @@ from jax._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
from jax._src import dtypes
import numpy as np
import threading
@ -427,20 +426,6 @@ def host_ids(backend=None):
### utility functions
# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
def _numpy_array_constant(builder, value, canonicalize_types=True):
if canonicalize_types:
value = normalize_to_xla_dtypes(value)
return [xops.ConstantLiteral(builder, value)]
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
@ -453,36 +438,6 @@ def parameter(builder, num, shape, name=None, replicated=None):
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
def constant_general(builder, py_val, canonicalize_types=True):
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of xla ops.
"""
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return constant(builder, py_val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
def constant(builder, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant, either a ComputationDataHandle or None
"""
const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
return const[0]
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
@ -545,65 +500,3 @@ def set_sharding(builder, op, sharding: SpatialSharding):
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun
_constant_handlers: Dict[type, Callable] = {}
def _ndarray_constant_handler(c, val, canonicalize_types=True):
"""Constant handler for ndarray literals, handling zero-size strides.
This function essentially calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
c: an XlaBuilder
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return [xops.Transpose(xla_val, permutation)]
else:
return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.bool_, np.longlong,
xla_client.bfloat16]:
register_constant_handler(scalar_type, _scalar_constant_handler)
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
register_constant_handler(np.float128, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
return _numpy_array_constant(c, dtype.type(val))
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

View File

@ -28,7 +28,6 @@ from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src.api import jit, vmap
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
import jax._src.pretty_printer as pp
@ -358,7 +357,7 @@ def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2):
rank = len(shape)
if 0 in shape:
zeros = xla_client.ops.Broadcast(
xla_bridge.constant(c, np.array(0, np.uint32)), shape)
xla_client.ops.Constant(c, np.array(0, np.uint32)), shape)
return xla_client.ops.Tuple(c, [zeros, zeros])
def _broadcast(x):
ndims = c.get_shape(x).rank()

View File

@ -643,7 +643,8 @@ def _make_params(c, dim_in_avals, in_avals):
def _xla_consts(c, consts):
unique_consts = {id(const): const for const in consts}
xla_consts = {
id_: [xb.constant(c, const)] for id_, const in unique_consts.items()}
id_: [xla.pyval_to_ir_constant(c, const)]
for id_, const in unique_consts.items()}
return [xla_consts[id(const)] for const in consts]
def djaxpr_subcomp(c, jaxpr, dim_args, args):
@ -654,7 +655,7 @@ def djaxpr_subcomp(c, jaxpr, dim_args, args):
def read(v):
if type(v) is core.Literal:
return [xb.constant(c, xla.canonicalize_dtype(v.val))]
return [xla.pyval_to_ir_constant(c, xla.canonicalize_dtype(v.val))]
else:
return env[v]

View File

@ -1360,20 +1360,21 @@ def _xmap_translation_rule_replica(c, axis_env,
return xops.Tuple(c, outs)
def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes):
zero = xb.constant(c, np.zeros((), dtype=np.int32))
zero = xops.Constant(c, np.zeros((), dtype=np.int32))
linear_idxs = [zero] * len(tile_shape)
strides = [1] * len(tile_shape)
for name, axis in reversed(axes.items()):
axis_index = _build_axis_index_lowering(
c, axis_name=name, axis_env=axis_env)
stride_c = xb.constant(c, np.array(strides[axis], np.int32))
stride_c = xops.Constant(c, np.array(strides[axis], np.int32))
if linear_idxs[axis] is zero and strides[axis] == 1:
linear_idxs[axis] = axis_index
else:
linear_idxs[axis] = xops.Add(linear_idxs[axis], xops.Mul(axis_index, stride_c))
strides[axis] *= axis_sizes[name]
return [zero if linear_idx is zero else
xops.Mul(linear_idx, xb.constant(c, np.array(tile_dim_size, np.int32)))
xops.Mul(linear_idx,
xops.Constant(c, np.array(tile_dim_size, np.int32)))
for linear_idx, tile_dim_size in zip(linear_idxs, tile_shape)]
def _xla_tile(c, axis_env, x, in_axes, axis_sizes):
@ -1405,7 +1406,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
shape[axis] *= axis_sizes[name]
base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, out_axes, axis_sizes)
padded = xops.Broadcast(xb.constant(c, np.array(0, x_dtype)), shape)
padded = xops.Broadcast(xops.Constant(c, np.array(0, x_dtype)), shape)
padded = xops.DynamicUpdateSlice(padded, x, base_idxs)
replica_groups_protos = xc.make_replica_groups(
xla.axis_groups(axis_env, tuple(out_axes.keys())))
@ -1413,7 +1414,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(
nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
return out

View File

@ -686,12 +686,13 @@ def _shard_sharded_device_array_slow_path(x, devices, indices):
def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return xb.constant_general(c, np.asarray(val), canonicalize_types=canonicalize_types)
return xla.pyval_to_ir_constants(c, np.asarray(val),
canonicalize_types=canonicalize_types)
def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xb.register_constant_handler(sda, _sharded_device_array_constant_handler)
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
core.pytype_aval_mappings[sda] = ConcreteArray
xla.device_put_handlers[sda] = xla._device_put_array
@ -876,7 +877,7 @@ def parallel_callable(fun: lu.WrappedFun,
tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU
c = xc.XlaBuilder("pmap_{}".format(fun.__name__))
xla_consts = map(partial(xb.constant, c), consts)
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
replicated_args = [axis is None for axis in in_axes]
xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
replicated=replicated_args,
@ -1317,7 +1318,7 @@ def _xla_shard(c, aval, axis_env, x, in_axis):
return x
elif isinstance(aval, ShapedArray):
dims = list(c.get_shape(x).dimensions())
zero = xb.constant(c, np.zeros((), dtype=np.uint32))
zero = xops.Constant(c, np.zeros((), dtype=np.uint32))
idxs = [zero] * (len(dims) - 1)
idxs.insert(in_axis, _unravel_index(c, axis_env))
dims_unsqueezed = dims.copy()
@ -1344,9 +1345,10 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
xla_shape = c.get_shape(x)
dims = list(xla_shape.dimensions())
padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
[axis_env.sizes[-1]] + dims)
zero = xb.constant(c, np.zeros((), dtype=np.uint32))
padded = xops.Broadcast(
xops.Constant(c, np.array(0, xla_shape.numpy_dtype())),
[axis_env.sizes[-1]] + dims)
zero = xops.Constant(c, np.zeros((), dtype=np.uint32))
idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
replica_groups_protos = xc.make_replica_groups(
@ -1360,7 +1362,7 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(
nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
return out
@ -1368,8 +1370,9 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
raise TypeError((aval, c.get_shape(x)))
def _unravel_index(c, axis_env):
div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
div = xops.Constant(c, np.array(axis_env.nreps // prod(axis_env.sizes),
np.uint32))
mod = xops.Constant(c, np.array(axis_env.sizes[-1], np.uint32))
return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
# ------------------- xmap -------------------
@ -1597,7 +1600,7 @@ def lower_mesh_computation(
# 3. Build up the HLO
c = xc.XlaBuilder(f"xmap_{fun.__name__}")
xla_consts = map(partial(xb.constant, c), consts)
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
in_partitions: Optional[List]
if spmd_lowering:

View File

@ -139,7 +139,7 @@ def _sharded_callable(
fun.__name__, nparts, global_abstract_args)
c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
xla_consts = _map(partial(xb.constant, c), consts)
xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
axis_env = xla.AxisEnv(nrep, (), ())
ctx = xla.TranslationContext(

View File

@ -70,28 +70,13 @@ XlaExecutable = xc.Executable
_on_exit = False
def compile_or_get_cached(backend, computation, compile_options):
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
# Persistent compilation cache only implemented on TPU.
# TODO(skye): add warning when initializing cache on unsupported default platform
if cc.is_initialized() and backend.platform == 'tpu':
cached_executable = cc.get_executable(computation, compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit')
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
cc.put_executable(computation, compile_options, compiled, backend)
return compiled
return backend_compile(backend, computation, compile_options)
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
# unit representation
def _make_unit_constant(c): return xb.constant_general(c, np.zeros((), dtype=np.dtype('bool')))
def _make_unit_constant(c): return [
xops.Constant(c, np.zeros((), dtype=np.dtype('bool')))]
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _device_put_unit(_, device):
backend = xb.get_device_backend(device)
@ -128,6 +113,8 @@ def make_op_metadata(primitive: core.Primitive,
### handlers
# Numpy dtypes -> XLA primitive types
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
@ -156,7 +143,8 @@ def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))
# JAX abstract values -> XLA shapes
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
try:
@ -171,6 +159,123 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
ConcreteArray: _make_array_shape,
}
# IR constants
_constant_handlers: Dict[type, Callable] = {}
def pyval_to_ir_constants(builder, py_val, canonicalize_types=True):
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of xla ops.
"""
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return pyval_to_ir_constants(builder, py_val.__jax_array__(),
canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
def pyval_to_ir_constant(builder, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant, either a ComputationDataHandle or None
"""
const = pyval_to_ir_constants(builder, py_val, canonicalize_types=canonicalize_types)
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
return const[0]
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun
register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))
# TODO(mattjj,frostig): try to remove this function
def _normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(_normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
def _numpy_array_constant(builder, value, canonicalize_types=True):
if canonicalize_types:
value = _normalize_to_xla_dtypes(value)
return [xops.Constant(builder, value)]
def _ndarray_constant_handler(c, val, canonicalize_types=True):
"""Constant handler for ndarray literals, handling zero-size strides.
This function essentially calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
c: an XlaBuilder
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return [xops.Transpose(xla_val, permutation)]
else:
return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.bool_, np.longlong,
dtypes.bfloat16]:
register_constant_handler(scalar_type, _scalar_constant_handler)
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
register_constant_handler(np.float128, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
return _numpy_array_constant(c, dtype.type(val))
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
# Result handlers
def aval_to_result_handler(device: Optional[Device],
aval: core.AbstractValue) -> Callable:
try:
@ -434,7 +539,7 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
# assert ctx.platform is not None
def read(v):
if type(v) is Literal:
return xb.constant_general(ctx.builder, canonicalize_dtype(v.val))
return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(v.val))
else:
return env[v]
@ -449,7 +554,8 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
env[v] = node
env: Dict[core.Var, Sequence[XlaOp]] = {}
_partitionmap(write, [core.unitvar], _make_unit_constant(ctx.builder))
_partitionmap(write, [core.unitvar],
pyval_to_ir_constants(ctx.builder, core.unit))
_partitionmap(write, jaxpr.constvars, consts)
_partitionmap(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
@ -645,7 +751,7 @@ def _flatten_shape(s: XlaShape, index: Tuple[int, ...],
def _xla_consts(c, consts):
unique_consts = {id(const): const for const in consts}
xla_consts = {
id_: xb.constant_general(c, const) for id_, const in unique_consts.items()}
id_: pyval_to_ir_constants(c, const) for id_, const in unique_consts.items()}
return [c for const in consts for c in xla_consts[id(const)]]
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
@ -740,6 +846,23 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
out_avals, kept_var_idx)
def compile_or_get_cached(backend, computation, compile_options):
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
# Persistent compilation cache only implemented on TPU.
# TODO(skye): add warning when initializing cache on unsupported default platform
if cc.is_initialized() and backend.platform == 'tpu':
cached_executable = cc.get_executable(computation, compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit')
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
cc.put_executable(computation, compile_options, compiled, backend)
return compiled
return backend_compile(backend, computation, compile_options)
class XlaComputation:
name: str
_is_trivial: bool
@ -1183,7 +1306,7 @@ register_translation(xla_call_p, _xla_call_translation_rule)
def zeros_like_translation_rule(c, x):
shape = c.get_shape(x)
assert not shape.is_tuple()
zero = xb.constant(c, np.array(0, shape.element_type()))
zero = xops.Constant(c, np.array(0, shape.element_type()))
return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule
@ -1539,9 +1662,9 @@ for device_array in [_CppDeviceArray, _DeviceArray]:
canonicalize_dtype_handlers[device_array] = identity
def _device_array_constant_handler(c, val, canonicalize_types=True):
return xb.constant_general(c, val.device_buffer.to_py())
xb.register_constant_handler(_DeviceArray, _device_array_constant_handler)
xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)
return pyval_to_ir_constants(c, val.device_buffer.to_py())
register_constant_handler(_DeviceArray, _device_array_constant_handler)
register_constant_handler(_CppDeviceArray, _device_array_constant_handler)
def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[Device]):
x = _copy_device_array_to_device(x, device)
@ -1593,7 +1716,7 @@ masking.defvectorized(device_put_p)
def _zeros(c, xla_shape):
if xla_shape.is_array():
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = xb.constant(c, np.array(0, dtype=dtype))
zero = xops.Constant(c, np.array(0, dtype=dtype))
return xops.Broadcast(zero, shape)
else:
# It is a token
@ -1608,10 +1731,10 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr):
Conditional."""
# Fake condition which always selects True branch.
c = ctx.builder
rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
xb.constant(c, np.array(1, dtype=np.float32)),
rng = xops.RngUniform(xops.Constant(c, np.array(0, dtype=np.float32)),
xops.Constant(c, np.array(1, dtype=np.float32)),
xc.Shape.array_shape(xc.PrimitiveType.F32, []))
pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))
pred = xops.Lt(rng, xops.Constant(c, np.array(2, dtype=np.float32)))
true_op = xops.Tuple(c, in_nodes)
remat_subc = xc.XlaBuilder("remat_call_subcomputation")
@ -1648,22 +1771,22 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr):
dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args)
out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs]
i_init = xb.constant(c, np.array(0, dtype=np.int32))
i_init = xops.Constant(c, np.array(0, dtype=np.int32))
zeros_like_outs = [_zeros(c, s) for s in out_node_shapes]
inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs)
cond_subc = xc.XlaBuilder("remat_cond_subcomputation")
input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[])
i = xops.GetTupleElement(input_op, 0)
rng = xops.RngUniform(xb.constant(cond_subc, np.array(1, dtype=np.int32)),
xb.constant(cond_subc, np.array(2, dtype=np.int32)),
rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)),
xops.Constant(cond_subc, np.array(2, dtype=np.int32)),
xc.Shape.array_shape(xc.PrimitiveType.S32, []))
cond_subc = cond_subc.build(xops.Lt(i, rng))
body_subc = xc.XlaBuilder("remat_body_subcomputation")
input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[])
i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1]
i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32)))
i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32)))
body_ctx = ctx.replace(
builder=body_subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat')))

View File

@ -121,8 +121,8 @@ def sparse_array_device_put_handler(a, device):
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
return (
xb.constant(val.data, canonicalize_dtypes),
xb.constant(val.indices, canonicalize_dtypes)
xla.pyval_to_ir_constant(val.data, canonicalize_dtypes),
xla.pyval_to_ir_constant(val.indices, canonicalize_dtypes)
)
core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
@ -132,7 +132,7 @@ xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler
xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xb.register_constant_handler(SparseArray, sparse_array_constant_handler)
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
sp_indices_p = core.Primitive('sp_indices')

View File

@ -38,8 +38,10 @@ from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge
xops = xla_client.ops
import numpy as np
@ -1761,7 +1763,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
"Consumer ID cannot be a reserved value: 0"):
hcb._callback_handler_data.receiver.add_outfeed(
comp, token, 0,
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
[xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))])
def test_tap_error_different_shapes(self):
"""Try to register different shapes for the same consumer ID."""
@ -1772,17 +1774,17 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
hcb._callback_handler_data.receiver.add_outfeed(
comp, token, 123,
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
[xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))])
with self.assertRaisesRegex(
RuntimeError, ".*does not match previous shape element_type.*"):
hcb._callback_handler_data.receiver.add_outfeed(
comp, token, 123,
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))])
[xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))])
with self.assertRaisesRegex(
RuntimeError, ".*does not match previous shape element_type.*"):
hcb._callback_handler_data.receiver.add_outfeed(
comp, token, 123,
[xla_bridge.constant(comp, np.zeros((2,), dtype=np.float32))])
[xops.Constant(comp, np.zeros((2,), dtype=np.float32))])
def test_tap_id_tap_removed_kwargs(self):
def func(x, transforms, y):