mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Expose shape
and dtype
on ArgInfo and mark aval
as private. Aval is an internal property of JAX and shouldn't have been exposed to users. Users can create their own SDS with shape and dtype until we expose ArrayDuck
.
PiperOrigin-RevId: 640577261
This commit is contained in:
parent
da87e4470a
commit
ebc9de3dbc
@ -369,15 +369,24 @@ class XlaLowering(Lowering):
|
||||
|
||||
# -- Public-facing API, plus helpers
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class ArgInfo:
|
||||
aval: core.AbstractValue
|
||||
_aval: core.AbstractValue
|
||||
donated: bool
|
||||
|
||||
@dataclass
|
||||
@property
|
||||
def shape(self):
|
||||
return self._aval.shape # pytype: disable=attribute-error
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._aval.dtype # pytype: disable=attribute-error
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OutInfo:
|
||||
shape: tuple[int, ...]
|
||||
dtype: Any
|
||||
dtype: jax.typing.DTypeLike
|
||||
sharding: jax.sharding.Sharding
|
||||
|
||||
|
||||
@ -392,7 +401,7 @@ class Stage:
|
||||
@property
|
||||
def in_avals(self):
|
||||
"""Tree of input avals."""
|
||||
return tree_util.tree_map(lambda x: x.aval, self.args_info)
|
||||
return tree_util.tree_map(lambda x: x._aval, self.args_info)
|
||||
|
||||
@property
|
||||
def donate_argnums(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user