Expose shape and dtype on ArgInfo and mark aval as private. Aval is an internal property of JAX and shouldn't have been exposed to users. Users can create their own SDS with shape and dtype until we expose ArrayDuck.

PiperOrigin-RevId: 640577261
This commit is contained in:
Yash Katariya 2024-06-05 10:49:48 -07:00 committed by jax authors
parent da87e4470a
commit ebc9de3dbc

View File

@ -369,15 +369,24 @@ class XlaLowering(Lowering):
# -- Public-facing API, plus helpers
@dataclass
@dataclass(frozen=True)
class ArgInfo:
aval: core.AbstractValue
_aval: core.AbstractValue
donated: bool
@dataclass
@property
def shape(self):
return self._aval.shape # pytype: disable=attribute-error
@property
def dtype(self):
return self._aval.dtype # pytype: disable=attribute-error
@dataclass(frozen=True)
class OutInfo:
shape: tuple[int, ...]
dtype: Any
dtype: jax.typing.DTypeLike
sharding: jax.sharding.Sharding
@ -392,7 +401,7 @@ class Stage:
@property
def in_avals(self):
"""Tree of input avals."""
return tree_util.tree_map(lambda x: x.aval, self.args_info)
return tree_util.tree_map(lambda x: x._aval, self.args_info)
@property
def donate_argnums(self):