mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types. PiperOrigin-RevId: 691626481
This commit is contained in:
parent
7f4a34e12b
commit
f355dcf34b
@ -1482,15 +1482,11 @@ class UnshapedArray(AbstractValue):
|
||||
array_abstraction_level = 4
|
||||
|
||||
def __init__(self, dtype, weak_type=False):
|
||||
# Is it silly to initialize this object and then complain that we should
|
||||
# never create one? Yes. But otherwise pytype complains.
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
|
||||
def update(self, dtype=None, weak_type=None):
|
||||
if dtype is None:
|
||||
dtype = self.dtype
|
||||
if weak_type is None:
|
||||
weak_type = self.weak_type
|
||||
return UnshapedArray(dtype, weak_type)
|
||||
raise Exception("We should never create an UnshapedArray object")
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other) and self.dtype == other.dtype and
|
||||
@ -1517,19 +1513,6 @@ class UnshapedArray(AbstractValue):
|
||||
_oct = concretization_function_error(oct)
|
||||
_index = concretization_function_error(operator.index)
|
||||
|
||||
def to_tangent_aval(self) -> AbstractValue:
|
||||
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type)
|
||||
|
||||
def join(self, other):
|
||||
if self.dtype == other.dtype:
|
||||
if self.weak_type == other.weak_type:
|
||||
return self
|
||||
else:
|
||||
return UnshapedArray(self.dtype, weak_type=False)
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
|
||||
@ -1537,13 +1520,6 @@ class UnshapedArray(AbstractValue):
|
||||
"""Returns a copy of the aval with weak_type=False."""
|
||||
return self.update(weak_type=False)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
msg = ("UnshapedArray has no shape. Please open an issue at "
|
||||
"https://github.com/jax-ml/jax/issues because it's unexpected for "
|
||||
"UnshapedArray instances to ever be produced.")
|
||||
raise TypeError(msg)
|
||||
|
||||
def _canonicalize_dimension(dim: DimSize) -> DimSize:
|
||||
# Dimensions are most commonly integral (by far), so we check that first.
|
||||
try:
|
||||
@ -1670,8 +1646,6 @@ class ShapedArray(UnshapedArray):
|
||||
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
|
||||
weak_type = self.weak_type and other.weak_type
|
||||
return self.update(weak_type=weak_type)
|
||||
elif self.dtype == other.dtype:
|
||||
return UnshapedArray(self.dtype)
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
@ -1753,8 +1727,6 @@ class ConcreteArray(ShapedArray):
|
||||
elif self.shape == other.shape and self.dtype == other.dtype:
|
||||
weak_type = self.weak_type and other.weak_type
|
||||
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
|
||||
elif self.dtype == other.dtype:
|
||||
return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type)
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
@ -1838,8 +1810,6 @@ class DShapedArray(UnshapedArray):
|
||||
self.dtype == other.dtype):
|
||||
weak_type = self.weak_type and other.weak_type
|
||||
return self.update(weak_type=weak_type)
|
||||
elif self.dtype == other.dtype:
|
||||
return UnshapedArray(self.dtype)
|
||||
else:
|
||||
raise TypeError(self, other)
|
||||
|
||||
@ -1996,6 +1966,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
|
||||
aval_type = type(aval)
|
||||
if aval_type is ShapedArray and weak_type is None:
|
||||
return aval
|
||||
if aval_type is DShapedArray and weak_type is None:
|
||||
return aval
|
||||
if weak_type is None:
|
||||
weak_type = getattr(aval, 'weak_type', False)
|
||||
for typ in aval_type.__mro__:
|
||||
@ -2011,8 +1983,8 @@ def _shaped_array_mapping(aval, weak_type):
|
||||
raise_to_shaped_mappings: dict[type, Callable] = {
|
||||
AbstractToken: lambda aval, _: aval,
|
||||
Bot: lambda aval, _: aval,
|
||||
UnshapedArray: lambda aval, _: aval,
|
||||
ShapedArray: _shaped_array_mapping,
|
||||
DShapedArray: lambda aval, _: aval,
|
||||
DConcreteArray: lambda aval, weak_type: DShapedArray(
|
||||
aval.shape, aval.dtype, weak_type
|
||||
),
|
||||
|
@ -812,7 +812,7 @@ def _threefry2x32_abstract_eval(*args):
|
||||
shape = lax_internal.broadcasting_shape_rule(*args)
|
||||
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
|
||||
else:
|
||||
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
|
||||
raise TypeError(f"Arguments to threefry2x32 must all be arrays, got {args}")
|
||||
return (aval,) * 2
|
||||
|
||||
|
||||
|
@ -932,23 +932,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False,
|
||||
rtol=jtu.default_gradient_tolerance)
|
||||
|
||||
def testIssue387(self):
|
||||
# https://github.com/jax-ml/jax/issues/387
|
||||
R = self.rng().rand(100, 2)
|
||||
|
||||
def dist_sq(R):
|
||||
dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
|
||||
zero = jnp.zeros_like(dR)
|
||||
dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
|
||||
return jnp.sum(dR ** 2, axis=2)
|
||||
|
||||
@jit
|
||||
def f(R):
|
||||
_ = dist_sq(R)
|
||||
return jnp.sum(R ** 2)
|
||||
|
||||
_ = hessian(f)(R) # don't crash on UnshapedArray
|
||||
|
||||
@jax.legacy_prng_key('allow')
|
||||
def testIssue489(self):
|
||||
# https://github.com/jax-ml/jax/issues/489
|
||||
|
@ -33,14 +33,13 @@ from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax._src.core import ShapedArray, DBIdx
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
|
||||
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))
|
||||
|
||||
def call(f, *args):
|
||||
|
@ -1362,17 +1362,6 @@ class StateControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
class GeneralRefTest(jtu.JaxTestCase):
|
||||
|
||||
def test_unshaped_ref(self):
|
||||
def f(x_ref):
|
||||
x = x_ref[...]
|
||||
x_ref[...] = x
|
||||
ref_addupdate(x_ref, (), x)
|
||||
return [x]
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
|
||||
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
|
||||
|
||||
def test_token(self):
|
||||
def f(x_ref):
|
||||
x = x_ref[...]
|
||||
|
Loading…
x
Reference in New Issue
Block a user