Move dtype canonicalization out of core.AbstractValue subclasses.

This is a strictly mechanical change that moves dtype canonicalization of abstract values out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an x32 context (i.e. when 64-bit dtypes are disabled).

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX source tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
This commit is contained in:
Peter Hawkins 2021-12-07 06:12:32 -08:00 committed by jax authors
parent 56f029f7f0
commit 06cd1fedee
14 changed files with 56 additions and 38 deletions

View File

@ -49,18 +49,23 @@ array_types = {np.ndarray, np.bool_,
np.complex64, np.complex128,
np.longlong, np.intc}
def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
weak_type=weak_type)
for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
core.pytype_aval_mappings[t] = canonical_concrete_aval
ad_util.jaxval_zeros_likers[t] = zeros_like_array
core.literalable_types.update(array_types)
def _zeros_like_python_scalar(t, x):
aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
aval = core.ShapedArray((), dtype, weak_type=True)
return ad_util.zeros_like_aval(aval)
def _make_concrete_python_scalar(t, x):
return ConcreteArray(
return canonical_concrete_aval(
np.array(x, dtype=dtypes._scalar_type_to_dtype(t, x)),
weak_type=True)

View File

@ -69,7 +69,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.core import ShapedArray, raise_to_shaped
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla

View File

@ -23,6 +23,7 @@ import numpy as np
from jax import core
from jax._src.config import config
from jax._src import abstract_arrays
from jax._src import dtypes
from jax._src import profiler
from jax._src.lib import xla_client as xc
@ -306,4 +307,4 @@ deleted_buffer = DeletedBuffer()
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
for _device_array in device_array_types:
core.literalable_types.add(_device_array)
core.pytype_aval_mappings[device_array] = core.ConcreteArray
core.pytype_aval_mappings[device_array] = abstract_arrays.canonical_concrete_aval

View File

@ -1480,7 +1480,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
return carry, stacked_y
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
x_dtypes = [x.dtype for x in xs_flat]
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
def _create_jaxpr(init):
@ -2038,7 +2038,7 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
for new_c, c in zip(new_carry, carry)]
return [i + 1] + new_carry + ys
aval = ShapedArray((), dtypes.int_)
aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

View File

@ -1240,7 +1240,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
if core.symbolic_equal_dim(operand.shape[0], 0):
output_shape = _gather_shape_rule(
core.ShapedArray(operand.shape[1:], operand.dtype),
core.ShapedArray(indices.shape[1:], indices.dtype),
core.ShapedArray(indices.shape[1:],
dtypes.canonicalize_dtype(indices.dtype)),
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value)
@ -1456,8 +1457,8 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
if mode == GatherScatterMode.CLIP:
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)
c = ctx.builder
@ -1477,8 +1478,8 @@ def _scatter_add_translation_rule(
if mode == GatherScatterMode.CLIP:
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)
dtype = operand_aval.dtype
scatter_dims = _scatter_dimensions_proto(

View File

@ -60,8 +60,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
least_specialized = _max(map(type, avals),
key=operator.attrgetter('array_abstraction_level'))
if least_specialized is core.ConcreteArray:
return core.ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
weak_type=weak_type)
out = prim.impl(*[x.val for x in avals], **kwargs)
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
return core.ShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type=weak_type,
@ -81,7 +81,7 @@ def standard_multi_result_abstract_eval(
weak_types = weak_type_rule(*avals, **kwargs)
if least_specialized is core.ConcreteArray:
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
return [core.ConcreteArray(val, weak_type=weak_type)
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
for val, weak_type in safe_zip(out_vals, weak_types)]
elif least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)

View File

@ -1041,7 +1041,7 @@ class UnshapedArray(AbstractValue):
array_abstraction_level = 2
def __init__(self, dtype, weak_type=False):
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
self.dtype = np.dtype(dtype)
self.weak_type = weak_type
def update(self, dtype=None, weak_type=None):
@ -1183,19 +1183,20 @@ class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
def __init__(self, val, weak_type=None):
super().__init__(np.shape(val), np.result_type(val),
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
def __init__(self, dtype, val, weak_type=None):
super().__init__(
np.shape(val), dtype,
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
# Note: canonicalized self.dtype doesn't necessarily match self.val
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
self.val = val
assert self.dtype != np.dtype('O'), val
def update(self, val=None, weak_type=None):
if val is None:
val = self.val
if weak_type is None:
weak_type = self.weak_type
return ConcreteArray(val, weak_type)
def update(self, dtype=None, val=None, weak_type=None):
dtype = self.dtype if dtype is None else dtype
val = self.val if val is None else val
weak_type = self.weak_type if weak_type is None else weak_type
return ConcreteArray(dtype, val, weak_type)
def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
@ -1271,7 +1272,8 @@ raise_to_shaped_mappings : Dict[type, Callable] = {
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape)
aval.shape, dtypes.canonicalize_dtype(aval.dtype), weak_type,
aval.named_shape)
}
### Operations on shapes and dimension sizes.

View File

@ -351,14 +351,14 @@ def _code_generator_and_avals(
xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
found_parameter_avals = [
core.ShapedArray(found_xla_shape.dimensions(),
found_xla_shape.numpy_dtype())
dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype()))
for found_xla_shape in xla_comp_parameter_shapes
]
# Add the captured_inputs to args_flat_sig_tf
expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
expected_parameter_avals = [
core.ShapedArray(tuple(arg_sig.shape.as_list()),
arg_sig.dtype.as_numpy_dtype)
dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype))
for arg_sig in expected_args_flat_sig_tf]
if found_parameter_avals != expected_parameter_avals:
msg = ("Compiled TensorFlow function has unexpected parameter types " +

View File

@ -459,7 +459,8 @@ class MaskTracer(Tracer):
@property
def aval(self):
return ShapedArray(self.polymorphic_shape, self.dtype)
return ShapedArray(self.polymorphic_shape,
dtypes.canonicalize_dtype(self.dtype))
@property
def dtype(self):

View File

@ -45,6 +45,7 @@ import numpy as np
from jax._src.config import config
from jax import core
from jax import linear_util as lu
from jax._src import abstract_arrays
from jax._src.abstract_arrays import array_types
from jax.core import ConcreteArray, ShapedArray
from jax._src import device_array
@ -740,7 +741,7 @@ def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
core.pytype_aval_mappings[sda] = ConcreteArray
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity

View File

@ -2729,7 +2729,8 @@ class APITest(jtu.JaxTestCase):
@jit
def f():
core.lattice_join(core.ConcreteArray(x), core.ConcreteArray(y))
core.lattice_join(core.ConcreteArray(x.dtype, x),
core.ConcreteArray(y.dtype, y))
f() # doesn't crash

View File

@ -336,7 +336,8 @@ class CoreTest(jtu.JaxTestCase):
def test_concrete_array_string_representation(self):
# https://github.com/google/jax/issues/5364
self.assertEqual(
str(core.ConcreteArray(np.array([1], dtype=np.int32))),
str(core.ConcreteArray(np.dtype(np.int32),
np.array([1], dtype=np.int32))),
'ConcreteArray([1], dtype=int32)')

View File

@ -21,6 +21,7 @@ import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib.mlir import ir
@ -66,13 +67,15 @@ class AbstractSparseArray(core.ShapedArray):
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
named_shape=None):
super().__init__(shape, dtype)
super().__init__(shape, dtypes.canonicalize_dtype(dtype))
named_shape = {} if named_shape is None else named_shape
self.index_dtype = index_dtype
self.nnz = nnz
self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
named_shape=named_shape)
self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
weak_type, named_shape)
self.indices_aval = core.ShapedArray(
(nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
named_shape=named_shape)
def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
weak_type=None, named_shape=None):

View File

@ -56,8 +56,9 @@ class JaxJitTest(parameterized.TestCase):
output_buffer = device_put_function(value, device=device)
self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
self.assertEqual(output_buffer.dtype, dtype)
@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_numpy_arrays(self, device_put_function):
@ -68,8 +69,9 @@ class JaxJitTest(parameterized.TestCase):
output_buffer = device_put_function(value, device=device)
self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
self.assertEqual(output_buffer.dtype, dtype)
np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
dtype=dtype))