Remove UnshapedArray values from JAX (it remains as an abstract class).

Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
This commit is contained in:
Dougal Maclaurin 2024-10-30 18:53:16 -07:00 committed by jax authors
parent 7f4a34e12b
commit f355dcf34b
5 changed files with 8 additions and 65 deletions

View File

@ -1482,15 +1482,11 @@ class UnshapedArray(AbstractValue):
array_abstraction_level = 4
def __init__(self, dtype, weak_type=False):
# Is it silly to initialize this object and then complain that we should
# never create one? Yes. But otherwise pytype complains.
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
def update(self, dtype=None, weak_type=None):
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
return UnshapedArray(dtype, weak_type)
raise Exception("We should never create an UnshapedArray object")
def __eq__(self, other):
return (type(self) is type(other) and self.dtype == other.dtype and
@ -1517,19 +1513,6 @@ class UnshapedArray(AbstractValue):
_oct = concretization_function_error(oct)
_index = concretization_function_error(operator.index)
def to_tangent_aval(self) -> AbstractValue:
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
def join(self, other):
if self.dtype == other.dtype:
if self.weak_type == other.weak_type:
return self
else:
return UnshapedArray(self.dtype, weak_type=False)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
@ -1537,13 +1520,6 @@ class UnshapedArray(AbstractValue):
"""Returns a copy of the aval with weak_type=False."""
return self.update(weak_type=False)
@property
def shape(self):
msg = ("UnshapedArray has no shape. Please open an issue at "
"https://github.com/jax-ml/jax/issues because it's unexpected for "
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
@ -1670,8 +1646,6 @@ class ShapedArray(UnshapedArray):
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError(self, other)
@ -1753,8 +1727,6 @@ class ConcreteArray(ShapedArray):
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type)
else:
raise TypeError(self, other)
@ -1838,8 +1810,6 @@ class DShapedArray(UnshapedArray):
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError(self, other)
@ -1996,6 +1966,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
aval_type = type(aval)
if aval_type is ShapedArray and weak_type is None:
return aval
if aval_type is DShapedArray and weak_type is None:
return aval
if weak_type is None:
weak_type = getattr(aval, 'weak_type', False)
for typ in aval_type.__mro__:
@ -2011,8 +1983,8 @@ def _shaped_array_mapping(aval, weak_type):
raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: _shaped_array_mapping,
DShapedArray: lambda aval, _: aval,
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),

View File

@ -812,7 +812,7 @@ def _threefry2x32_abstract_eval(*args):
shape = lax_internal.broadcasting_shape_rule(*args)
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
else:
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
raise TypeError(f"Arguments to threefry2x32 must all be arrays, got {args}")
return (aval,) * 2

View File

@ -932,23 +932,6 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False,
rtol=jtu.default_gradient_tolerance)
def testIssue387(self):
# https://github.com/jax-ml/jax/issues/387
R = self.rng().rand(100, 2)
def dist_sq(R):
dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
zero = jnp.zeros_like(dR)
dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
return jnp.sum(dR ** 2, axis=2)
@jit
def f(R):
_ = dist_sq(R)
return jnp.sum(R ** 2)
_ = hessian(f)(R) # don't crash on UnshapedArray
@jax.legacy_prng_key('allow')
def testIssue489(self):
# https://github.com/jax-ml/jax/issues/489

View File

@ -33,14 +33,13 @@ from jax._src import core
from jax._src import linear_util as lu
from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
from jax._src.core import ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
config.parse_flags_with_absl()
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
def call(f, *args):

View File

@ -1362,17 +1362,6 @@ class StateControlFlowTest(jtu.JaxTestCase):
class GeneralRefTest(jtu.JaxTestCase):
def test_unshaped_ref(self):
def f(x_ref):
x = x_ref[...]
x_ref[...] = x
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
def test_token(self):
def f(x_ref):
x = x_ref[...]