Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract-value dtype canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context. The callers to which canonicalization was added were:

a) all callers of `ConcreteArray` inside the JAX source tree, and
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and then fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
This commit is contained in:
parent 56f029f7f0
commit 06cd1fedee
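In short: `ConcreteArray.__init__` gains an explicit leading `dtype` argument that must already be canonical, and a new `canonical_concrete_aval` helper in `jax._src.abstract_arrays` performs the canonicalization for callers that want the old behavior. A minimal before/after sketch (hypothetical values; the APIs are the ones introduced in the diff below):

    import numpy as np
    from jax import core
    from jax._src.abstract_arrays import canonical_concrete_aval

    x = np.array([1, 2], dtype=np.int32)

    # Before this change, the constructor canonicalized internally:
    #   aval = core.ConcreteArray(x)
    # After it, the caller passes an already-canonical dtype explicitly...
    aval = core.ConcreteArray(x.dtype, x)
    # ...or uses the helper, which canonicalizes on the caller's behalf.
    aval = canonical_concrete_aval(x)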
@@ -49,18 +49,23 @@ array_types = {np.ndarray, np.bool_,
                np.complex64, np.complex128,
                np.longlong, np.intc}

+def canonical_concrete_aval(val, weak_type=None):
+  return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
+                       weak_type=weak_type)
+
 for t in array_types:
-  core.pytype_aval_mappings[t] = ConcreteArray
+  core.pytype_aval_mappings[t] = canonical_concrete_aval
   ad_util.jaxval_zeros_likers[t] = zeros_like_array

 core.literalable_types.update(array_types)

 def _zeros_like_python_scalar(t, x):
-  aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
+  dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
+  aval = core.ShapedArray((), dtype, weak_type=True)
   return ad_util.zeros_like_aval(aval)

 def _make_concrete_python_scalar(t, x):
-  return ConcreteArray(
+  return canonical_concrete_aval(
       np.array(x, dtype=dtypes._scalar_type_to_dtype(t, x)),
       weak_type=True)
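A quick note on what `dtypes.canonicalize_dtype` does in the helper above, as a sketch assuming the default x64-disabled configuration: 64-bit dtypes map to their 32-bit counterparts, so the helper can legitimately build an aval whose dtype is narrower than the wrapped value's.

    import numpy as np
    from jax._src import dtypes

    # Under the default (x64 disabled) configuration:
    assert dtypes.canonicalize_dtype(np.dtype('float64')) == np.dtype('float32')
    assert dtypes.canonicalize_dtype(np.dtype('int32')) == np.dtype('int32')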
@@ -69,7 +69,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
                                      local_devices, process_index,
                                      process_count, host_id, host_ids,
                                      host_count, default_backend)
-from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
+from jax.core import ShapedArray, raise_to_shaped
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax.interpreters import pxla
@@ -23,6 +23,7 @@ import numpy as np

 from jax import core
 from jax._src.config import config
+from jax._src import abstract_arrays
 from jax._src import dtypes
 from jax._src import profiler
 from jax._src.lib import xla_client as xc
@@ -306,4 +307,4 @@ deleted_buffer = DeletedBuffer()

 device_array_types: List[type] = [xc.Buffer, _DeviceArray]
 for device_array in device_array_types:
   core.literalable_types.add(device_array)
-  core.pytype_aval_mappings[device_array] = core.ConcreteArray
+  core.pytype_aval_mappings[device_array] = abstract_arrays.canonical_concrete_aval
@@ -1480,7 +1480,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
     return carry, stacked_y

   x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
-  x_dtypes = [x.dtype for x in xs_flat]
+  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
   x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

   def _create_jaxpr(init):
@@ -2038,7 +2038,7 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
                  for new_c, c in zip(new_carry, carry)]
     return [i + 1] + new_carry + ys

-  aval = ShapedArray((), dtypes.int_)
+  aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
   const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
   return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
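The `dtypes.int_` fix follows the same pattern. Assuming `dtypes.int_` is a 64-bit integer type (consistent with why this line needed changing), passing it to `ShapedArray` unchanged would now mean a non-canonical dtype whenever x64 is disabled; canonicalizing first yields the 32-bit type in that mode. A sketch:

    import numpy as np
    from jax._src import dtypes

    print(np.dtype(dtypes.int_))                   # int64 (assumed default)
    print(dtypes.canonicalize_dtype(dtypes.int_))  # int32 with x64 disabled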
@@ -1240,7 +1240,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
   if core.symbolic_equal_dim(operand.shape[0], 0):
     output_shape = _gather_shape_rule(
         core.ShapedArray(operand.shape[1:], operand.dtype),
-        core.ShapedArray(indices.shape[1:], indices.dtype),
+        core.ShapedArray(indices.shape[1:],
+                         dtypes.canonicalize_dtype(indices.dtype)),
         dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
         unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
         mode=mode, fill_value=fill_value)
@@ -1456,8 +1457,8 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
   if mode == GatherScatterMode.CLIP:
     clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
                             new_style=True)
-    indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
-                       operand, indices, updates, dnums=dimension_numbers)
+    indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
+                       dnums=dimension_numbers)

   c = ctx.builder
@@ -1477,8 +1478,8 @@ def _scatter_add_translation_rule(
   if mode == GatherScatterMode.CLIP:
     clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
                             new_style=True)
-    indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
-                       operand, indices, updates, dnums=dimension_numbers)
+    indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
+                       dnums=dimension_numbers)

   dtype = operand_aval.dtype
   scatter_dims = _scatter_dimensions_proto(
@@ -60,8 +60,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   least_specialized = _max(map(type, avals),
                            key=operator.attrgetter('array_abstraction_level'))
   if least_specialized is core.ConcreteArray:
-    return core.ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
-                              weak_type=weak_type)
+    out = prim.impl(*[x.val for x in avals], **kwargs)
+    return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
   elif least_specialized is core.ShapedArray:
     return core.ShapedArray(shape_rule(*avals, **kwargs),
                             dtype_rule(*avals, **kwargs), weak_type=weak_type,
@@ -81,7 +81,7 @@ def standard_multi_result_abstract_eval(
   weak_types = weak_type_rule(*avals, **kwargs)
   if least_specialized is core.ConcreteArray:
     out_vals = prim.impl(*[x.val for x in avals], **kwargs)
-    return [core.ConcreteArray(val, weak_type=weak_type)
+    return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
             for val, weak_type in safe_zip(out_vals, weak_types)]
   elif least_specialized is core.ShapedArray:
     out_shapes = shape_rule(*avals, **kwargs)
jax/core.py (24 changed lines)
@@ -1041,7 +1041,7 @@ class UnshapedArray(AbstractValue):
   array_abstraction_level = 2

   def __init__(self, dtype, weak_type=False):
-    self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
+    self.dtype = np.dtype(dtype)
     self.weak_type = weak_type

   def update(self, dtype=None, weak_type=None):
@@ -1183,19 +1183,20 @@ class ConcreteArray(ShapedArray):
   __slots__ = ['val']
   array_abstraction_level = 0

-  def __init__(self, val, weak_type=None):
-    super().__init__(np.shape(val), np.result_type(val),
-                     weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
+  def __init__(self, dtype, val, weak_type=None):
+    super().__init__(
+        np.shape(val), dtype,
+        weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
+    # Note: canonicalized self.dtype doesn't necessarily match self.val
+    assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
     self.val = val
     assert self.dtype != np.dtype('O'), val

-  def update(self, val=None, weak_type=None):
-    if val is None:
-      val = self.val
-    if weak_type is None:
-      weak_type = self.weak_type
-    return ConcreteArray(val, weak_type)
+  def update(self, dtype=None, val=None, weak_type=None):
+    dtype = self.dtype if dtype is None else dtype
+    val = self.val if val is None else val
+    weak_type = self.weak_type if weak_type is None else weak_type
+    return ConcreteArray(dtype, val, weak_type)

   def __eq__(self, other):
     if (type(self) is type(other) and self.dtype == other.dtype
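The assertion added to `ConcreteArray.__init__` pins down the new contract: the supplied dtype must equal the canonicalized dtype of the value, though the value itself may stay 64-bit. A hedged illustration, assuming x64 is disabled:

    import numpy as np
    from jax import core
    from jax._src import dtypes

    x = np.float64(1.0)
    # core.ConcreteArray(np.dtype('float64'), x) would trip the assertion:
    # float64 is not canonical when x64 is disabled. Canonicalize first:
    aval = core.ConcreteArray(dtypes.canonicalize_dtype(np.dtype('float64')), x)
    # update() likewise threads the dtype through explicitly.
    aval2 = aval.update(val=np.float64(2.0))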
@@ -1271,7 +1272,8 @@ raise_to_shaped_mappings : Dict[type, Callable] = {
   Bot: lambda aval, _: aval,
   UnshapedArray: lambda aval, _: aval,
   ShapedArray: lambda aval, weak_type: ShapedArray(
-      aval.shape, aval.dtype, weak_type, aval.named_shape)
+      aval.shape, dtypes.canonicalize_dtype(aval.dtype), weak_type,
+      aval.named_shape)
 }

 ### Operations on shapes and dimension sizes.
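With this mapping, `raise_to_shaped` canonicalizes as it discards the concrete value, so lifting any ConcreteArray yields a ShapedArray with a canonical dtype. A small usage sketch:

    import numpy as np
    from jax import core

    aval = core.ConcreteArray(np.dtype('int32'),
                              np.array([1, 2], dtype=np.int32))
    shaped = core.raise_to_shaped(aval)  # ShapedArray((2,), int32)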
@@ -351,14 +351,14 @@ def _code_generator_and_avals(
   xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
   found_parameter_avals = [
       core.ShapedArray(found_xla_shape.dimensions(),
-                       found_xla_shape.numpy_dtype())
+                       dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype()))
       for found_xla_shape in xla_comp_parameter_shapes
   ]
   # Add the captured_inputs to args_flat_sig_tf
   expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
   expected_parameter_avals = [
       core.ShapedArray(tuple(arg_sig.shape.as_list()),
-                       arg_sig.dtype.as_numpy_dtype)
+                       dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype))
       for arg_sig in expected_args_flat_sig_tf]
   if found_parameter_avals != expected_parameter_avals:
     msg = ("Compiled TensorFlow function has unexpected parameter types " +
@@ -459,7 +459,8 @@ class MaskTracer(Tracer):

   @property
   def aval(self):
-    return ShapedArray(self.polymorphic_shape, self.dtype)
+    return ShapedArray(self.polymorphic_shape,
+                       dtypes.canonicalize_dtype(self.dtype))

   @property
   def dtype(self):
@@ -45,6 +45,7 @@ import numpy as np
 from jax._src.config import config
 from jax import core
 from jax import linear_util as lu
+from jax._src import abstract_arrays
 from jax._src.abstract_arrays import array_types
 from jax.core import ConcreteArray, ShapedArray
 from jax._src import device_array
@@ -740,7 +741,7 @@ def _register_handlers_for_sharded_device_array(sda):
   shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
   xla.register_constant_handler(sda, _sharded_device_array_constant_handler)

-  core.pytype_aval_mappings[sda] = ConcreteArray
+  core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
   dispatch.device_put_handlers[sda] = dispatch._device_put_array
   xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
   xla.canonicalize_dtype_handlers[sda] = identity
@@ -2729,7 +2729,8 @@ class APITest(jtu.JaxTestCase):

     @jit
     def f():
-      core.lattice_join(core.ConcreteArray(x), core.ConcreteArray(y))
+      core.lattice_join(core.ConcreteArray(x.dtype, x),
+                        core.ConcreteArray(y.dtype, y))

     f()  # doesn't crash
@@ -336,7 +336,8 @@ class CoreTest(jtu.JaxTestCase):
   def test_concrete_array_string_representation(self):
     # https://github.com/google/jax/issues/5364
     self.assertEqual(
-        str(core.ConcreteArray(np.array([1], dtype=np.int32))),
+        str(core.ConcreteArray(np.dtype(np.int32),
+                               np.array([1], dtype=np.int32))),
         'ConcreteArray([1], dtype=int32)')
@@ -21,6 +21,7 @@ import jax.numpy as jnp
 from jax import core, jit, lax, make_jaxpr
 from jax._src import device_array
 from jax._src import dispatch
+from jax._src import dtypes
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax._src.lib.mlir import ir
@@ -66,13 +67,15 @@ class AbstractSparseArray(core.ShapedArray):

   def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
                named_shape=None):
-    super().__init__(shape, dtype)
+    super().__init__(shape, dtypes.canonicalize_dtype(dtype))
     named_shape = {} if named_shape is None else named_shape
     self.index_dtype = index_dtype
     self.nnz = nnz
-    self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
-    self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
-                                         named_shape=named_shape)
+    self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
+                                      weak_type, named_shape)
+    self.indices_aval = core.ShapedArray(
+        (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
+        named_shape=named_shape)

   def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
              weak_type=None, named_shape=None):
@@ -56,8 +56,9 @@ class JaxJitTest(parameterized.TestCase):
     output_buffer = device_put_function(value, device=device)

     self.assertFalse(output_buffer.aval.weak_type)
+    dtype = dtypes.canonicalize_dtype(dtype)
     self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
-    self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
+    self.assertEqual(output_buffer.dtype, dtype)

   @parameterized.parameters([jax.device_put, _cpp_device_put])
   def test_device_put_on_numpy_arrays(self, device_put_function):
@@ -68,8 +69,9 @@
     output_buffer = device_put_function(value, device=device)

     self.assertFalse(output_buffer.aval.weak_type)
+    dtype = dtypes.canonicalize_dtype(dtype)
     self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
-    self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
+    self.assertEqual(output_buffer.dtype, dtype)
     np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
                                                           dtype=dtype))
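The updated jax_jit tests encode the user-visible effect: with x64 disabled, `device_put` canonicalizes 64-bit inputs, so the resulting buffer's dtype is the 32-bit counterpart, and the tests now canonicalize once, up front. A quick end-to-end check (assuming the default configuration):

    import numpy as np
    import jax

    y = jax.device_put(np.zeros((3, 4), dtype=np.float64))
    print(y.dtype)  # float32 under the default x64-disabled configuration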