diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 3e43dac65..1979d3ceb 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2033,7 +2033,7 @@ " env: Dict[Var, xe.XlaOp] = {}\n", "\n", " def read(x: Atom) -> xe.XlaOp:\n", - " return env[x] if type(x) is Var else xb.constant(c, x.val, False)\n", + " return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n", "\n", " def write(v: Var, val: xe.XlaOp) -> None:\n", " env[v] = val\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 084ec9889..e07841e76 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1593,7 +1593,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] env: Dict[Var, xe.XlaOp] = {} def read(x: Atom) -> xe.XlaOp: - return env[x] if type(x) is Var else xb.constant(c, x.val, False) + return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val)) def write(v: Var, val: xe.XlaOp) -> None: env[v] = val diff --git a/docs/autodidax.py b/docs/autodidax.py index e94863d89..155d52d37 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1587,7 +1587,7 @@ def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] env: Dict[Var, xe.XlaOp] = {} def read(x: Atom) -> xe.XlaOp: - return env[x] if type(x) is Var else xb.constant(c, x.val, False) + return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val)) def write(v: Var, val: xe.XlaOp) -> None: env[v] = val diff --git a/jax/_src/api.py b/jax/_src/api.py index a2cab44f9..01dc243c1 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -818,7 +818,7 @@ def xla_computation(fun: Callable, out_parts_flat = tuple(flatten_axes( "xla_computation out_parts", out_tree(), out_parts)) c = xc.XlaBuilder(f"xla_computation_{fun_name}") - xla_consts = map(partial(xb.constant, c), consts) + xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts) should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100) xla_args, donated_invars = xla._xla_callable_args( c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars) diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 7898739b7..bd27a78b3 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -344,12 +344,13 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr, name_stack=extend_name_stack(ctx.name_stack, 'cond')) pred, = xla.jaxpr_subcomp( cond_ctx, cond_jaxpr.jaxpr, - _map(partial(xb.constant, cond_c), cond_jaxpr.consts), *(x + z)) + _map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts), + *(x + z)) if batched: scalar = ShapedArray((), np.bool_) or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar) - pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_, - list(range(cond_jaxpr.out_avals[0].ndim))) + pred = xops.Reduce(cond_c, [pred], [xops.Constant(cond_c, np.array(False))], + or_, list(range(cond_jaxpr.out_avals[0].ndim))) body_c = xla_client.XlaBuilder("body_computation") body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry)) @@ -359,15 +360,17 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr, name_stack=extend_name_stack(ctx.name_stack, 'body')) new_z = xla.jaxpr_subcomp( body_ctx, body_jaxpr.jaxpr, - _map(partial(xb.constant, body_c), body_jaxpr.consts), + _map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts), *(y + z)) if batched: body_pred_ctx = body_ctx.replace( name_stack=extend_name_stack(ctx.name_stack, 'body_pred')) 
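# A minimal sketch (not part of the patch) of the two constant-building forms
# this change standardizes on, assuming a jaxlib new enough to expose
# xla_client.ops.Constant; the builder name `c` below is illustrative.
import numpy as np
from jax._src.lib import xla_client as xc
from jax.interpreters import xla

xops = xc.ops
c = xc.XlaBuilder("constant_demo")

# Arbitrary Python values: dispatch through the constant-handler registry and
# canonicalize dtypes (a Python int becomes JAX's default integer dtype).
op_canonical = xla.pyval_to_ir_constant(c, 3)

# ndarrays whose dtype must be preserved exactly: embed directly, replacing
# the old xb.constant(..., canonicalize_types=False) spelling.
op_exact = xops.Constant(c, np.array(False))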
body_pred, = xla.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, - _map(partial(xb.constant, body_c), cond_jaxpr.consts), *(x + z)) - new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals) + _map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts), + *(x + z)) + new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, + body_jaxpr.out_avals) assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast new_carry = xops.Tuple(body_c, [*x, *y, *new_z]) @@ -806,8 +809,9 @@ def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches, ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))] subctx = ctx.replace( builder=c, name_stack=extend_name_stack(name_stack, name + '_fun')) - outs = xla.jaxpr_subcomp(subctx, jaxpr.jaxpr, - _map(partial(xb.constant, c), jaxpr.consts), *ops) + outs = xla.jaxpr_subcomp( + subctx, jaxpr.jaxpr, + _map(partial(xla.pyval_to_ir_constant, c), jaxpr.consts), *ops) return c.build(xops.Tuple(c, outs)) c = ctx.builder diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2eee69109..88d7d445e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2460,10 +2460,12 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x): x_aval, = avals_in dtype = x_aval.dtype if dtypes.issubdtype(dtype, np.unsignedinteger): - zero = xb.constant(c, np.array(0, dtype=dtype)) - return [xops.Select(xops.Eq(x, zero), xops.Broadcast(zero, x_aval.shape), - xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)), - x_aval.shape))] + zero = xops.Constant(c, np.array(0, dtype=dtype)) + return [xops.Select( + xops.Eq(x, zero), + xops.Broadcast(zero, x_aval.shape), + xops.Broadcast(xops.Constant(c, np.array(1, dtype=dtype)), + x_aval.shape))] return [xops.Sign(x)] sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule) @@ -2764,7 +2766,8 @@ else: return [xops.Mul( xops.Sign(x), xops.Pow(xops.Abs(x), - xb.constant(ctx.builder, np.array(1/3, dtype=x_aval.dtype))))] + xla.pyval_to_ir_constant(ctx.builder, + np.array(1/3, dtype=x_aval.dtype))))] cbrt_p = standard_unop(_float, 'cbrt', translation_rule=_cbrt_translation_rule) @@ -2794,7 +2797,7 @@ def _integer_pow_translation_rule(ctx, avals_in, avals_out, x, *, y): # This should be kept in sync with the jax2tf translation rule. 
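# A hypothetical pure-NumPy reference for the semantics this rule lowers to
# HLO: y == 0 broadcasts a constant one, a negative exponent goes through the
# reciprocal, and the positive power is built by repeated squaring (the
# squaring loop itself sits outside the hunk shown here).
import numpy as np

def _integer_pow_reference(x, y):
  if y == 0:
    return np.ones_like(x)
  acc, base, n = np.ones_like(x), x, abs(y)
  while n:
    if n & 1:
      acc = acc * base
    base = base * base
    n >>= 1
  return np.ones_like(x) / acc if y < 0 else acc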
x_aval, = avals_in if y == 0: - one = xb.constant(ctx.builder, np.array(1, dtype=x_aval.dtype)) + one = xla.pyval_to_ir_constant(ctx.builder, np.array(1, dtype=x_aval.dtype)) return [xops.Broadcast(one, x_aval.shape)] is_reciprocal = y < 0 if is_reciprocal: @@ -4752,8 +4755,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *, upper_bound = operand_dims[intarray(dnums.start_index_map)] upper_bound -= intarray(slice_sizes)[intarray(dnums.start_index_map)] - mask = xops.And(xops.Ge(indices, xb.constant(c, intarray(0))), - xops.Le(indices, xb.constant(c, upper_bound), + mask = xops.And(xops.Ge(indices, xla.pyval_to_ir_constant(c, intarray(0))), + xops.Le(indices, xla.pyval_to_ir_constant(c, upper_bound), broadcast_dimensions=[num_batch_dims])) # Compute the conjunction of the mask elements across the dimensions in which @@ -4762,8 +4765,8 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *, scalar_pred = xla_client.Shape.array_shape(np.dtype(np.bool_), ()) xops.And(xb.parameter(and_builder, 0, scalar_pred), xb.parameter(and_builder, 1, scalar_pred)) - mask = xops.Reduce(c, [mask], [xb.constant(c, True)], and_builder.build(), - [num_batch_dims]) + mask = xops.Reduce(c, [mask], [xla.pyval_to_ir_constant(c, True)], + and_builder.build(), [num_batch_dims]) # Computes the output shape and the positions of the batch dimensions in the # output @@ -4778,7 +4781,7 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *, xops.Gather(operand, indices, dimensions, slice_sizes, indices_are_sorted=indices_are_sorted), xops.Broadcast( - xb.constant(c, np.array(fill_value, operand_aval.dtype)), + xla.pyval_to_ir_constant(c, np.array(fill_value, operand_aval.dtype)), aval_out.shape))] def _gather_jvp_rule(g, operand, indices, *, dimension_numbers, @@ -5023,8 +5026,9 @@ def _clamp_scatter_indices(c, indices, operand_shape, updates_shape, dnums): upper_bound -= intarray(slice_sizes)[intarray(dnums.scatter_dims_to_operand_dims)] upper_bound = np.minimum(upper_bound, np.iinfo(indices_dtype).max) return xops.Min( - xops.Max(xb.constant(c, np.array(0, dtype=indices_dtype)), indices), - xb.constant(c, upper_bound.astype(indices_dtype)), + xops.Max(xla.pyval_to_ir_constant(c, np.array(0, dtype=indices_dtype)), + indices), + xla.pyval_to_ir_constant(c, upper_bound.astype(indices_dtype)), broadcast_dimensions=[len(indices_shape.dimensions()) - 1]) def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices, @@ -5037,7 +5041,7 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices, indices = _clamp_scatter_indices(c, indices, operand_aval.shape, updates_aval.shape, dimension_numbers) - init_value = xb.constant(c, np.array(0, operand_aval.dtype)) + init_value = xla.pyval_to_ir_constant(c, np.array(0, operand_aval.dtype)) update_computation = _reduction_computation( c, update_jaxpr, update_consts, init_value) return [xops.Scatter( @@ -5690,7 +5694,7 @@ def _reduce_sum_translation_rule(ctx, avals_in, avals_out, operand, *, axes): scalar = ShapedArray((), operand_aval.dtype) return [xops.Reduce( ctx.builder, [operand], - [xb.constant(ctx.builder, np.array(0, operand_aval.dtype))], + [xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype))], xla.primitive_subcomputation(add_p, scalar, scalar), axes)] def _reduce_sum_transpose_rule(cotangent, operand, *, axes): @@ -5723,7 +5727,7 @@ def _reduce_prod_translation_rule(ctx, avals_in, avals_out, operand, *, axes): scalar = ShapedArray((), operand_aval.dtype) 
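# The reduce rules in this hunk differ only in their identity element and the
# primitive handed to xla.primitive_subcomputation (add -> 0, mul -> 1, and so
# on). A rough reference table as a hypothetical helper, covering only the
# floating and signed-integer cases:
import numpy as np

_REDUCE_IDENTITY = {
    'add': lambda dt: np.array(0, dt),
    'mul': lambda dt: np.array(1, dt),
    'max': lambda dt: np.array(-np.inf if np.issubdtype(dt, np.floating)
                               else np.iinfo(dt).min, dt),
    'min': lambda dt: np.array(np.inf if np.issubdtype(dt, np.floating)
                               else np.iinfo(dt).max, dt),
    'or':  lambda dt: np.array(False),
    'and': lambda dt: np.array(True),
}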
return [xops.Reduce( ctx.builder, [operand], - [xb.constant(ctx.builder, np.array(1, operand_aval.dtype))], + [xla.pyval_to_ir_constant(ctx.builder, np.array(1, operand_aval.dtype))], xla.primitive_subcomputation(mul_p, scalar, scalar), axes)] def _reduce_prod_jvp_rule(primals, tangents, *, axes): @@ -5749,7 +5753,8 @@ def _reduce_chooser_translation_rule(prim, identity, ctx, avals_in, avals_out, operand_aval, = avals_in scalar = ShapedArray((), operand_aval.dtype) return [xops.Reduce(ctx.builder, [operand], - [xb.constant(ctx.builder, identity(operand_aval.dtype))], + [xla.pyval_to_ir_constant(ctx.builder, + identity(operand_aval.dtype))], xla.primitive_subcomputation(prim, scalar, scalar), axes)] def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): @@ -5872,7 +5877,7 @@ def _reduce_logical_translation_rule(prim, identity, ctx, avals_in, avals_out, operand, *, axes): scalar = ShapedArray((), np.bool_) return [xops.Reduce(ctx.builder, [operand], - [xb.constant(ctx.builder, identity(np.bool_))], + [xla.pyval_to_ir_constant(ctx.builder, identity(np.bool_))], xla.primitive_subcomputation(prim, scalar, scalar), axes)] _reduce_or_translation_rule = partial(_reduce_logical_translation_rule, @@ -5958,7 +5963,8 @@ def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *, operand_aval, = avals_in scalar = ShapedArray((), operand_aval.dtype) return [xops.ReduceWindowWithGeneralPadding( - operand, xb.constant(ctx.builder, np.array(0, operand_aval.dtype)), + operand, + xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype)), xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions, window_strides, base_dilation, window_dilation, padding)] @@ -6012,7 +6018,8 @@ def _reduce_window_chooser_translation_rule( operand_aval, = avals_in scalar = ShapedArray((), operand_aval.dtype) return [xops.ReduceWindowWithGeneralPadding( - operand, xb.constant(ctx.builder, identity(operand_aval.dtype)), + operand, + xla.pyval_to_ir_constant(ctx.builder, identity(operand_aval.dtype)), xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions, window_strides, base_dilation, window_dilation, padding)] @@ -6162,7 +6169,7 @@ def _select_and_scatter_add_translation( select = xla.primitive_subcomputation(select_prim, scalar, scalar) scatter = xla.primitive_subcomputation(or_p if dtype == np.bool_ else add_p, scalar, scalar) - zero = xb.constant(c, np.array(0, dtype)) + zero = xla.pyval_to_ir_constant(c, np.array(0, dtype)) # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed. 
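# Why pre-padding with the select identity is safe here: an identity value
# (e.g. -inf for the ge_p/max select) can never win a window's select, so
# running SelectAndScatter on the pre-padded operand with zero padding picks
# the same source positions as the original padded call. Tiny NumPy
# illustration (the real rule derives the pad widths via make_padding_config):
import numpy as np

x = np.array([3., 7., 5.])
ident = -np.inf                                  # identity of the max select
padded = np.concatenate([[ident], x, [ident]])
# A window that overlaps the pad still selects the same element as the
# corresponding clipped window over the unpadded operand.
assert padded[0:2].max() == x[0:1].max()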
expand_padding = (expand_padding and not all(lo == 0 and hi == 0 for (lo, hi) in padding)) @@ -6171,7 +6178,7 @@ def _select_and_scatter_add_translation( identity = (_get_max_identity if select_prim is ge_p else _get_min_identity) pads = [(lo, hi, 0) for (lo, hi) in padding] - operand = xops.Pad(operand, xb.constant(c, identity(dtype)), + operand = xops.Pad(operand, xla.pyval_to_ir_constant(c, identity(dtype)), xc.make_padding_config(pads)) padding = [(0, 0) for _ in padding] output = xops.SelectAndScatterWithGeneralPadding( @@ -6287,8 +6294,7 @@ def _select_and_gather_add_translation( assert nbits <= max_bits double_word_reduction = nbits * 2 <= max_bits - const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype), - canonicalize_types=False) + const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype)) if double_word_reduction: # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so @@ -6381,8 +6387,7 @@ def _select_and_gather_add_translation_using_variadic_reducewindow( tangents_aval, operand_aval = avals_in dtype = operand_aval.dtype - const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype), - canonicalize_types=False) + const = lambda c, dtype, x: xops.Constant(c, np.array(x, dtype=dtype)) def reducer(): c = xc.XlaBuilder("select_and_gather_pair_reducer") @@ -6870,26 +6875,28 @@ def _rng_bit_generator_translation_rule( def _convert_4xU32_to_2xU64_without_bitcast(c, key): u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64')) - new_key = xb.constant(c, np.zeros(2, dtype=np.dtype('uint64')), - canonicalize_types=False) - _32 = xb.constant(c, np.uint64(32), canonicalize_types=False) + new_key = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64'))) + _32 = xops.Constant(c, np.array(32, np.uint64)) for i in [0, 2]: hi = xops.ConvertElementType(xops.Slice(key, [i] , [i+1], [1]), u64_etype) lo = xops.ConvertElementType(xops.Slice(key, [i+1], [i+2], [1]), u64_etype) elt = xops.Xor(xops.ShiftLeft(hi, _32), lo) - new_key = xops.DynamicUpdateSlice(new_key, elt, [xb.constant(c, i // 2)]) + new_key = xops.DynamicUpdateSlice(new_key, elt, + [xla.pyval_to_ir_constant(c, i // 2)]) return new_key def _convert_2xU64_to_4xU32_without_bitcast(c, key): u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32')) - new_key = xb.constant(c, np.zeros(4, dtype=np.dtype('uint32'))) - _32 = xb.constant(c, np.uint64(32), canonicalize_types=False) + new_key = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32'))) + _32 = xops.Constant(c, np.array(32, np.uint64)) for i in [0, 1]: elt = xops.Slice(key, [i], [i+1], [1]) hi = xops.ConvertElementType(xops.ShiftRightLogical(elt, _32), u32_etype) lo = xops.ConvertElementType(elt, u32_etype) - new_key = xops.DynamicUpdateSlice(new_key, hi, [xb.constant(c, 2 * i)]) - new_key = xops.DynamicUpdateSlice(new_key, lo, [xb.constant(c, 2 * i + 1)]) + new_key = xops.DynamicUpdateSlice(new_key, hi, + [xla.pyval_to_ir_constant(c, 2 * i)]) + new_key = xops.DynamicUpdateSlice(new_key, lo, + [xla.pyval_to_ir_constant(c, 2 * i + 1)]) return new_key def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 4a2064873..389a7b2cf 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -41,7 +41,6 @@ from jax._src.lib import cusparse from jax._src.lib import rocsolver from jax._src.lib import xla_client -from jax._src.lib import xla_bridge as xb from jax._src.lib import version as jaxlib_version xops = xla_client.ops @@ -346,9 +345,9 
@@ def _nan_like(c, operand): shape = c.get_shape(operand) dtype = shape.element_type() if jnp.issubdtype(dtype, np.complexfloating): - nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype)) + nan = xops.Constant(c, np.array(np.nan * (1. + 1j), dtype=dtype)) else: - nan = xb.constant(c, np.array(np.nan, dtype=dtype)) + nan = xops.Constant(c, np.array(np.nan, dtype=dtype)) return xops.Broadcast(nan, shape.dimensions()) def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand): @@ -696,7 +695,7 @@ def _triangular_solve_cpu_translation_rule( conjugate_a = False if len(shape.dimensions()) == 2 and np.dtype(dtype) in _cpu_lapack_types: return lapack.jax_trsm( - c, xb.constant(c, np.array(1, dtype=dtype)), + c, xops.Constant(c, np.array(1, dtype=dtype)), a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal) else: # Fall back to the HLO implementation for unsupported types or batching. diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d9f0a57ef..3cd886d80 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -32,7 +32,6 @@ from jax.interpreters import pxla from jax.interpreters import batching from jax._src import dtypes from jax._src.lib import xla_client as xc -from jax._src.lib import xla_bridge as xb from jax._src.numpy import lax_numpy from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis @@ -1321,9 +1320,10 @@ def _build_axis_index_lowering(c, axis_name, axis_env): axis_name, = axis_name axis_pos = list(axis_env.names).index(axis_name) nreplicas = axis_env.nreps // prod(axis_env.sizes) - div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]), - dtype=np.uint32)) - mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) + div = xops.Constant(c, + np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]), + dtype=np.uint32)) + mod = xops.Constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) return xops.ConvertElementType( unsigned_index, xla.dtype_to_primitive_type(np.dtype(np.int32))) diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index bd20b8736..93b372305 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -22,7 +22,7 @@ XLA. There are also a handful of related casting utilities. from functools import partial, lru_cache import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import warnings from absl import logging @@ -34,7 +34,6 @@ from jax._src.config import flags, bool_env from . import tpu_driver_client from . 
import xla_client from jax._src import util, traceback_util -from jax._src import dtypes import numpy as np import threading @@ -427,20 +426,6 @@ def host_ids(backend=None): ### utility functions -# TODO(mattjj,frostig): try to remove this function -def normalize_to_xla_dtypes(val): - """Normalize dtypes in a value.""" - if hasattr(val, '__array__') or np.isscalar(val): - return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val))) - elif isinstance(val, (tuple, list)): - return tuple(normalize_to_xla_dtypes(x) for x in val) - raise TypeError('Can\'t convert to XLA: {}'.format(val)) - -def _numpy_array_constant(builder, value, canonicalize_types=True): - if canonicalize_types: - value = normalize_to_xla_dtypes(value) - return [xops.ConstantLiteral(builder, value)] - def parameter(builder, num, shape, name=None, replicated=None): if name is None: name = '' @@ -453,36 +438,6 @@ def parameter(builder, num, shape, name=None, replicated=None): shape.with_major_to_minor_layout_if_absent(), name, replicated) - -def constant_general(builder, py_val, canonicalize_types=True): - """Translate a general constant `py_val` to a constant, canonicalizing its dtype. - - Args: - py_val: a Python value to be translated to a constant. - - Returns: - A representation of the constant as a list of xla ops. - """ - for t in type(py_val).mro(): - handler = _constant_handlers.get(t) - if handler: return handler(builder, py_val, canonicalize_types) - if hasattr(py_val, '__jax_array__'): - return constant(builder, py_val.__jax_array__(), canonicalize_types) - raise TypeError("No constant handler for type: {}".format(type(py_val))) - -def constant(builder, py_val, canonicalize_types=True): - """Translate constant `py_val` to a constant, canonicalizing its dtype. - - Args: - py_val: a Python value to be translated to a constant. - - Returns: - A representation of the constant, either a ComputationDataHandle or None - """ - const = constant_general(builder, py_val, canonicalize_types=canonicalize_types) - assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}" - return const[0] - # HLO instructions optionally can be annotated to say how the output should be # spatially partitioned (represented in XLA as OpSharding protos, see # _sharding_to_proto). For array outputs, the annotation is either an int per @@ -545,65 +500,3 @@ def set_sharding(builder, op, sharding: SpatialSharding): def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs): """Builds op_fn(*args, **kwargs) with sharding annotation.""" return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs) - - -def register_constant_handler(type_, handler_fun): - _constant_handlers[type_] = handler_fun -_constant_handlers: Dict[type, Callable] = {} - - -def _ndarray_constant_handler(c, val, canonicalize_types=True): - """Constant handler for ndarray literals, handling zero-size strides. - - This function essentially calls _numpy_array_constant(val) except it has - special handling of arrays with any strides of size zero: for those, it - generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose - to avoid staging in large literals that might arise from np.zeros or np.ones - or the output of lax.broadcast (which uses np.broadcast_to which in turn - uses size-zero strides). - - Args: - c: an XlaBuilder - val: an ndarray. - - Returns: - An XLA ComputationDataHandle / XlaOp representing the constant ndarray - staged into the XLA Computation. 
- """ - # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose - if dtypes.result_type(val) == dtypes.float0: - return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_)) - elif np.any(np.equal(0, val.strides)) and val.size > 0: - zero_stride_axes, = np.where(np.equal(0, val.strides)) - other_axes, = np.where(np.not_equal(0, val.strides)) - collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) - for ax in range(val.ndim))] - xla_val = xops.Broadcast( - _numpy_array_constant(c, collapsed_val, canonicalize_types)[0], - np.take(val.shape, zero_stride_axes)) - permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes)) - return [xops.Transpose(xla_val, permutation)] - else: - return _numpy_array_constant(c, val, canonicalize_types) -register_constant_handler(np.ndarray, _ndarray_constant_handler) - - -def _scalar_constant_handler(c, val, canonicalize_types=True): - return _numpy_array_constant(c, val, canonicalize_types) - -for scalar_type in [np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, - np.float16, np.float32, np.float64, - np.bool_, np.longlong, - xla_client.bfloat16]: - register_constant_handler(scalar_type, _scalar_constant_handler) - -# https://github.com/winpython/winpython/issues/613#issuecomment-380121523 -if hasattr(np, "float128"): - register_constant_handler(np.float128, _scalar_constant_handler) - -def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True): - return _numpy_array_constant(c, dtype.type(val)) - -for ptype, dtype in dtypes.python_scalar_dtypes.items(): - register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7105008ff..63b852689 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -28,7 +28,6 @@ from jax.dtypes import float0 from jax.interpreters import batching from jax.interpreters import xla from jax._src.api import jit, vmap -from jax._src.lib import xla_bridge from jax._src.lib import xla_client from jax._src.lib import cuda_prng import jax._src.pretty_printer as pp @@ -358,7 +357,7 @@ def _threefry2x32_gpu_translation_rule(c, k1, k2, x1, x2): rank = len(shape) if 0 in shape: zeros = xla_client.ops.Broadcast( - xla_bridge.constant(c, np.array(0, np.uint32)), shape) + xla_client.ops.Constant(c, np.array(0, np.uint32)), shape) return xla_client.ops.Tuple(c, [zeros, zeros]) def _broadcast(x): ndims = c.get_shape(x).rank() diff --git a/jax/experimental/djax.py b/jax/experimental/djax.py index 69f0d0529..09137354f 100644 --- a/jax/experimental/djax.py +++ b/jax/experimental/djax.py @@ -643,7 +643,8 @@ def _make_params(c, dim_in_avals, in_avals): def _xla_consts(c, consts): unique_consts = {id(const): const for const in consts} xla_consts = { - id_: [xb.constant(c, const)] for id_, const in unique_consts.items()} + id_: [xla.pyval_to_ir_constant(c, const)] + for id_, const in unique_consts.items()} return [xla_consts[id(const)] for const in consts] def djaxpr_subcomp(c, jaxpr, dim_args, args): @@ -654,7 +655,7 @@ def djaxpr_subcomp(c, jaxpr, dim_args, args): def read(v): if type(v) is core.Literal: - return [xb.constant(c, xla.canonicalize_dtype(v.val))] + return [xla.pyval_to_ir_constant(c, xla.canonicalize_dtype(v.val))] else: return env[v] diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 04fc4657a..de897b354 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -1360,20 +1360,21 @@ def _xmap_translation_rule_replica(c, axis_env, return 
xops.Tuple(c, outs) def _xla_tile_base_indices(c, axis_env, tile_shape, axes, axis_sizes): - zero = xb.constant(c, np.zeros((), dtype=np.int32)) + zero = xops.Constant(c, np.zeros((), dtype=np.int32)) linear_idxs = [zero] * len(tile_shape) strides = [1] * len(tile_shape) for name, axis in reversed(axes.items()): axis_index = _build_axis_index_lowering( c, axis_name=name, axis_env=axis_env) - stride_c = xb.constant(c, np.array(strides[axis], np.int32)) + stride_c = xops.Constant(c, np.array(strides[axis], np.int32)) if linear_idxs[axis] is zero and strides[axis] == 1: linear_idxs[axis] = axis_index else: linear_idxs[axis] = xops.Add(linear_idxs[axis], xops.Mul(axis_index, stride_c)) strides[axis] *= axis_sizes[name] return [zero if linear_idx is zero else - xops.Mul(linear_idx, xb.constant(c, np.array(tile_dim_size, np.int32))) + xops.Mul(linear_idx, + xops.Constant(c, np.array(tile_dim_size, np.int32))) for linear_idx, tile_dim_size in zip(linear_idxs, tile_shape)] def _xla_tile(c, axis_env, x, in_axes, axis_sizes): @@ -1405,7 +1406,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend): shape[axis] *= axis_sizes[name] base_idxs = _xla_tile_base_indices(c, axis_env, tile_shape, out_axes, axis_sizes) - padded = xops.Broadcast(xb.constant(c, np.array(0, x_dtype)), shape) + padded = xops.Broadcast(xops.Constant(c, np.array(0, x_dtype)), shape) padded = xops.DynamicUpdateSlice(padded, x, base_idxs) replica_groups_protos = xc.make_replica_groups( xla.axis_groups(axis_env, tuple(out_axes.keys()))) @@ -1413,7 +1414,7 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend): # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU if convert_bool: - nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32))) + nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32))) out = xops.ConvertElementType( nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_))) return out diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index ac102dca8..5e4608705 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -686,12 +686,13 @@ def _shard_sharded_device_array_slow_path(x, devices, indices): def _sharded_device_array_constant_handler(c, val, canonicalize_types=True): - return xb.constant_general(c, np.asarray(val), canonicalize_types=canonicalize_types) + return xla.pyval_to_ir_constants(c, np.asarray(val), + canonicalize_types=canonicalize_types) def _register_handlers_for_sharded_device_array(sda): shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path - xb.register_constant_handler(sda, _sharded_device_array_constant_handler) + xla.register_constant_handler(sda, _sharded_device_array_constant_handler) core.pytype_aval_mappings[sda] = ConcreteArray xla.device_put_handlers[sda] = xla._device_put_array @@ -876,7 +877,7 @@ def parallel_callable(fun: lu.WrappedFun, tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU c = xc.XlaBuilder("pmap_{}".format(fun.__name__)) - xla_consts = map(partial(xb.constant, c), consts) + xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts) replicated_args = [axis is None for axis in in_axes] xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args, replicated=replicated_args, @@ -1317,7 +1318,7 @@ def _xla_shard(c, aval, axis_env, x, in_axis): return x elif isinstance(aval, ShapedArray): dims = list(c.get_shape(x).dimensions()) - zero = xb.constant(c, np.zeros((), dtype=np.uint32)) + zero = xops.Constant(c, 
np.zeros((), dtype=np.uint32)) idxs = [zero] * (len(dims) - 1) idxs.insert(in_axis, _unravel_index(c, axis_env)) dims_unsqueezed = dims.copy() @@ -1344,9 +1345,10 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend): xla_shape = c.get_shape(x) dims = list(xla_shape.dimensions()) - padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())), - [axis_env.sizes[-1]] + dims) - zero = xb.constant(c, np.zeros((), dtype=np.uint32)) + padded = xops.Broadcast( + xops.Constant(c, np.array(0, xla_shape.numpy_dtype())), + [axis_env.sizes[-1]] + dims) + zero = xops.Constant(c, np.zeros((), dtype=np.uint32)) idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims) padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs) replica_groups_protos = xc.make_replica_groups( @@ -1360,7 +1362,7 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend): # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU if convert_bool: - nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32))) + nonzero = xops.Ne(out, xops.Constant(c, np.array(0, dtype=np.float32))) out = xops.ConvertElementType( nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_))) return out @@ -1368,8 +1370,9 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend): raise TypeError((aval, c.get_shape(x))) def _unravel_index(c, axis_env): - div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32)) - mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32)) + div = xops.Constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), + np.uint32)) + mod = xops.Constant(c, np.array(axis_env.sizes[-1], np.uint32)) return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod) # ------------------- xmap ------------------- @@ -1597,7 +1600,7 @@ def lower_mesh_computation( # 3. Build up the HLO c = xc.XlaBuilder(f"xmap_{fun.__name__}") - xla_consts = map(partial(xb.constant, c), consts) + xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts) tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU in_partitions: Optional[List] if spmd_lowering: diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index b572ad052..33cca68cc 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -139,7 +139,7 @@ def _sharded_callable( fun.__name__, nparts, global_abstract_args) c = xc.XlaBuilder("spjit_{}".format(fun.__name__)) - xla_consts = _map(partial(xb.constant, c), consts) + xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts) xla_args = _xla_sharded_args(c, global_abstract_args, in_parts) axis_env = xla.AxisEnv(nrep, (), ()) ctx = xla.TranslationContext( diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index b7e1eab54..35ff6e56d 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -70,28 +70,13 @@ XlaExecutable = xc.Executable _on_exit = False -def compile_or_get_cached(backend, computation, compile_options): - # Avoid import cycle between jax and jax.experimental - from jax.experimental.compilation_cache import compilation_cache as cc - # Persistent compilation cache only implemented on TPU. 
- # TODO(skye): add warning when initializing cache on unsupported default platform - if cc.is_initialized() and backend.platform == 'tpu': - cached_executable = cc.get_executable(computation, compile_options, backend) - if cached_executable is not None: - logging.info('Persistent compilation cache hit') - return cached_executable - else: - compiled = backend_compile(backend, computation, compile_options) - cc.put_executable(computation, compile_options, compiled, backend) - return compiled - return backend_compile(backend, computation, compile_options) - def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() # unit representation -def _make_unit_constant(c): return xb.constant_general(c, np.zeros((), dtype=np.dtype('bool'))) +def _make_unit_constant(c): return [ + xops.Constant(c, np.zeros((), dtype=np.dtype('bool')))] def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),) def _device_put_unit(_, device): backend = xb.get_device_backend(device) @@ -128,6 +113,8 @@ def make_op_metadata(primitive: core.Primitive, ### handlers +# Numpy dtypes -> XLA primitive types + _dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = { np.dtype('bool'): xc.PrimitiveType.PRED, np.dtype('int8'): xc.PrimitiveType.S8, @@ -156,7 +143,8 @@ def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType: except KeyError as err: raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err -xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c)) + +# JAX abstract values -> XLA shapes def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]: try: @@ -171,6 +159,123 @@ xla_shape_handlers: Dict[Type[core.AbstractValue], ConcreteArray: _make_array_shape, } + + +# IR constants + +_constant_handlers: Dict[type, Callable] = {} + +def pyval_to_ir_constants(builder, py_val, canonicalize_types=True): + """Translate a general constant `py_val` to a constant, canonicalizing its dtype. + + Args: + py_val: a Python value to be translated to a constant. + + Returns: + A representation of the constant as a list of xla ops. + """ + for t in type(py_val).mro(): + handler = _constant_handlers.get(t) + if handler: return handler(builder, py_val, canonicalize_types) + if hasattr(py_val, '__jax_array__'): + return pyval_to_ir_constants(builder, py_val.__jax_array__(), + canonicalize_types) + raise TypeError("No constant handler for type: {}".format(type(py_val))) + +def pyval_to_ir_constant(builder, py_val, canonicalize_types=True): + """Translate constant `py_val` to a constant, canonicalizing its dtype. + + Args: + py_val: a Python value to be translated to a constant. 
+ + Returns: + A representation of the constant, either a ComputationDataHandle or None + """ + const = pyval_to_ir_constants(builder, py_val, canonicalize_types=canonicalize_types) + assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}" + return const[0] + + +def register_constant_handler(type_, handler_fun): + _constant_handlers[type_] = handler_fun + +register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c)) + + +# TODO(mattjj,frostig): try to remove this function +def _normalize_to_xla_dtypes(val): + """Normalize dtypes in a value.""" + if hasattr(val, '__array__') or np.isscalar(val): + return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val))) + elif isinstance(val, (tuple, list)): + return tuple(_normalize_to_xla_dtypes(x) for x in val) + raise TypeError('Can\'t convert to XLA: {}'.format(val)) + +def _numpy_array_constant(builder, value, canonicalize_types=True): + if canonicalize_types: + value = _normalize_to_xla_dtypes(value) + return [xops.Constant(builder, value)] + + +def _ndarray_constant_handler(c, val, canonicalize_types=True): + """Constant handler for ndarray literals, handling zero-size strides. + + This function essentially calls _numpy_array_constant(val) except it has + special handling of arrays with any strides of size zero: for those, it + generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose + to avoid staging in large literals that might arise from np.zeros or np.ones + or the output of lax.broadcast (which uses np.broadcast_to which in turn + uses size-zero strides). + + Args: + c: an XlaBuilder + val: an ndarray. + + Returns: + An XLA ComputationDataHandle / XlaOp representing the constant ndarray + staged into the XLA Computation. 
+ """ + # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose + if dtypes.result_type(val) == dtypes.float0: + return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_)) + elif np.any(np.equal(0, val.strides)) and val.size > 0: + zero_stride_axes, = np.where(np.equal(0, val.strides)) + other_axes, = np.where(np.not_equal(0, val.strides)) + collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) + for ax in range(val.ndim))] + xla_val = xops.Broadcast( + _numpy_array_constant(c, collapsed_val, canonicalize_types)[0], + np.take(val.shape, zero_stride_axes)) + permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes)) + return [xops.Transpose(xla_val, permutation)] + else: + return _numpy_array_constant(c, val, canonicalize_types) +register_constant_handler(np.ndarray, _ndarray_constant_handler) + + +def _scalar_constant_handler(c, val, canonicalize_types=True): + return _numpy_array_constant(c, val, canonicalize_types) + +for scalar_type in [np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, + np.bool_, np.longlong, + dtypes.bfloat16]: + register_constant_handler(scalar_type, _scalar_constant_handler) + +# https://github.com/winpython/winpython/issues/613#issuecomment-380121523 +if hasattr(np, "float128"): + register_constant_handler(np.float128, _scalar_constant_handler) + +def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True): + return _numpy_array_constant(c, dtype.type(val)) + +for ptype, dtype in dtypes.python_scalar_dtypes.items(): + register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) + + +# Result handlers + def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable: try: @@ -434,7 +539,7 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr, # assert ctx.platform is not None def read(v): if type(v) is Literal: - return xb.constant_general(ctx.builder, canonicalize_dtype(v.val)) + return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(v.val)) else: return env[v] @@ -449,7 +554,8 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr, env[v] = node env: Dict[core.Var, Sequence[XlaOp]] = {} - _partitionmap(write, [core.unitvar], _make_unit_constant(ctx.builder)) + _partitionmap(write, [core.unitvar], + pyval_to_ir_constants(ctx.builder, core.unit)) _partitionmap(write, jaxpr.constvars, consts) _partitionmap(write, jaxpr.invars, args) for eqn in jaxpr.eqns: @@ -645,7 +751,7 @@ def _flatten_shape(s: XlaShape, index: Tuple[int, ...], def _xla_consts(c, consts): unique_consts = {id(const): const for const in consts} xla_consts = { - id_: xb.constant_general(c, const) for id_, const in unique_consts.items()} + id_: pyval_to_ir_constants(c, const) for id_, const in unique_consts.items()} return [c for const in consts for c in xla_consts[id(const)]] def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name, @@ -740,6 +846,23 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, out_avals, kept_var_idx) +def compile_or_get_cached(backend, computation, compile_options): + # Avoid import cycle between jax and jax.experimental + from jax.experimental.compilation_cache import compilation_cache as cc + # Persistent compilation cache only implemented on TPU. 
+ # TODO(skye): add warning when initializing cache on unsupported default platform + if cc.is_initialized() and backend.platform == 'tpu': + cached_executable = cc.get_executable(computation, compile_options, backend) + if cached_executable is not None: + logging.info('Persistent compilation cache hit') + return cached_executable + else: + compiled = backend_compile(backend, computation, compile_options) + cc.put_executable(computation, compile_options, compiled, backend) + return compiled + return backend_compile(backend, computation, compile_options) + + class XlaComputation: name: str _is_trivial: bool @@ -1183,7 +1306,7 @@ register_translation(xla_call_p, _xla_call_translation_rule) def zeros_like_translation_rule(c, x): shape = c.get_shape(x) assert not shape.is_tuple() - zero = xb.constant(c, np.array(0, shape.element_type())) + zero = xops.Constant(c, np.array(0, shape.element_type())) return xops.Broadcast(zero, shape.dimensions()) translations[ad_util.zeros_like_p] = zeros_like_translation_rule @@ -1539,9 +1662,9 @@ for device_array in [_CppDeviceArray, _DeviceArray]: canonicalize_dtype_handlers[device_array] = identity def _device_array_constant_handler(c, val, canonicalize_types=True): - return xb.constant_general(c, val.device_buffer.to_py()) -xb.register_constant_handler(_DeviceArray, _device_array_constant_handler) -xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler) + return pyval_to_ir_constants(c, val.device_buffer.to_py()) +register_constant_handler(_DeviceArray, _device_array_constant_handler) +register_constant_handler(_CppDeviceArray, _device_array_constant_handler) def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[Device]): x = _copy_device_array_to_device(x, device) @@ -1593,7 +1716,7 @@ masking.defvectorized(device_put_p) def _zeros(c, xla_shape): if xla_shape.is_array(): shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype() - zero = xb.constant(c, np.array(0, dtype=dtype)) + zero = xops.Constant(c, np.array(0, dtype=dtype)) return xops.Broadcast(zero, shape) else: # It is a token @@ -1608,10 +1731,10 @@ def _remat_using_cond(ctx, in_nodes, name, call_jaxpr): Conditional.""" # Fake condition which always selects True branch. 
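# The "fake" predicate built below (rng_uniform(0, 1) < 2) is always true at
# runtime but is data-dependent, so XLA cannot fold the Conditional away or
# CSE the recomputed branch with the primal computation, which is the point of
# lowering remat through a Conditional. The same trick at the jax.numpy level,
# as an illustration only:
import jax

def _always_true_but_opaque(key):
  u = jax.random.uniform(key, ())   # sample in [0, 1)
  return u < 2.0                    # True at runtime, unknowable to the compiler

# e.g. _always_true_but_opaque(jax.random.PRNGKey(0)) evaluates to True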
c = ctx.builder - rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)), - xb.constant(c, np.array(1, dtype=np.float32)), + rng = xops.RngUniform(xops.Constant(c, np.array(0, dtype=np.float32)), + xops.Constant(c, np.array(1, dtype=np.float32)), xc.Shape.array_shape(xc.PrimitiveType.F32, [])) - pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32))) + pred = xops.Lt(rng, xops.Constant(c, np.array(2, dtype=np.float32))) true_op = xops.Tuple(c, in_nodes) remat_subc = xc.XlaBuilder("remat_call_subcomputation") @@ -1648,22 +1771,22 @@ def _remat_using_while(ctx, in_nodes, name, call_jaxpr): dummy_subcomp_outs = jaxpr_subcomp(dummy_ctx, call_jaxpr, (), *dummy_args) out_node_shapes = [dummy_subc.get_shape(o) for o in dummy_subcomp_outs] - i_init = xb.constant(c, np.array(0, dtype=np.int32)) + i_init = xops.Constant(c, np.array(0, dtype=np.int32)) zeros_like_outs = [_zeros(c, s) for s in out_node_shapes] inputs = xops.Tuple(c, [i_init] + list(in_nodes) + zeros_like_outs) cond_subc = xc.XlaBuilder("remat_cond_subcomputation") input_op = xb.parameter(cond_subc, 0, c.get_shape(inputs), replicated=[]) i = xops.GetTupleElement(input_op, 0) - rng = xops.RngUniform(xb.constant(cond_subc, np.array(1, dtype=np.int32)), - xb.constant(cond_subc, np.array(2, dtype=np.int32)), + rng = xops.RngUniform(xops.Constant(cond_subc, np.array(1, dtype=np.int32)), + xops.Constant(cond_subc, np.array(2, dtype=np.int32)), xc.Shape.array_shape(xc.PrimitiveType.S32, [])) cond_subc = cond_subc.build(xops.Lt(i, rng)) body_subc = xc.XlaBuilder("remat_body_subcomputation") input_op = xb.parameter(body_subc, 0, c.get_shape(inputs), replicated=[]) i, *args = xla_destructure(body_subc, input_op)[:len(in_nodes)+1] - i_next = xops.Add(i, xb.constant(body_subc, np.array(1, dtype=np.int32))) + i_next = xops.Add(i, xops.Constant(body_subc, np.array(1, dtype=np.int32))) body_ctx = ctx.replace( builder=body_subc, name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'remat'))) diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 9a7ed4c11..56b1a71da 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -121,8 +121,8 @@ def sparse_array_device_put_handler(a, device): def sparse_array_constant_handler(c, val, canonicalize_dtypes): return ( - xb.constant(val.data, canonicalize_dtypes), - xb.constant(val.indices, canonicalize_dtypes) + xla.pyval_to_ir_constant(val.data, canonicalize_dtypes), + xla.pyval_to_ir_constant(val.indices, canonicalize_dtypes) ) core.pytype_aval_mappings[SparseArray] = lambda x: x.aval @@ -132,7 +132,7 @@ xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x xla.device_put_handlers[SparseArray] = sparse_array_device_put_handler xla.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler -xb.register_constant_handler(SparseArray, sparse_array_constant_handler) +xla.register_constant_handler(SparseArray, sparse_array_constant_handler) sp_indices_p = core.Primitive('sp_indices') diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index ba88a7be1..a75985b6d 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -38,8 +38,10 @@ from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu from jax import tree_util -from jax._src.lib import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import xla_bridge + +xops = xla_client.ops import numpy as np @@ -1761,7 +1763,7 
@@ class HostCallbackTapTest(jtu.JaxTestCase): "Consumer ID cannot be a reserved value: 0"): hcb._callback_handler_data.receiver.add_outfeed( comp, token, 0, - [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))]) + [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))]) def test_tap_error_different_shapes(self): """Try to register different shapes for the same consumer ID.""" @@ -1772,17 +1774,17 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb._initialize_outfeed_receiver() # Needed if this is the sole test hcb._callback_handler_data.receiver.add_outfeed( comp, token, 123, - [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))]) + [xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))]) with self.assertRaisesRegex( RuntimeError, ".*does not match previous shape element_type.*"): hcb._callback_handler_data.receiver.add_outfeed( comp, token, 123, - [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))]) + [xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))]) with self.assertRaisesRegex( RuntimeError, ".*does not match previous shape element_type.*"): hcb._callback_handler_data.receiver.add_outfeed( comp, token, 123, - [xla_bridge.constant(comp, np.zeros((2,), dtype=np.float32))]) + [xops.Constant(comp, np.zeros((2,), dtype=np.float32))]) def test_tap_id_tap_removed_kwargs(self): def func(x, transforms, y):
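# With the registry relocated, downstream code registers constant handlers via
# jax.interpreters.xla instead of xla_bridge, as the pxla and
# custom_object_test changes above do. A minimal, hypothetical example for a
# user type wrapping an ndarray (registration only; lowering such a type also
# needs the aval/result/device_put handlers shown in custom_object_test.py):
import numpy as np
from jax.interpreters import xla

class BoxedArray:
  def __init__(self, data):
    self.data = np.asarray(data)

def _boxed_constant_handler(c, val, canonicalize_types=True):
  # Delegate to the ndarray path; handlers return a list of XlaOps.
  return xla.pyval_to_ir_constants(c, val.data, canonicalize_types)

xla.register_constant_handler(BoxedArray, _boxed_constant_handler)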