mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add helpers for internal static shape/dimension assertions
Also some typing fixes.
This commit is contained in:
parent
5be53524a7
commit
971afab587
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
from functools import partial, reduce
|
||||
from typing import Iterable, NamedTuple, Sequence, Union
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -31,9 +32,9 @@ map, zip = util.safe_map, util.safe_zip
|
||||
DInt = jax.Array
|
||||
Address = DInt
|
||||
XInt = Union[int, DInt]
|
||||
DShape = tuple[XInt]
|
||||
SShape = tuple[int]
|
||||
DType = jax.typing.DTypeLike
|
||||
DShape = tuple[XInt, ...]
|
||||
SShape = tuple[int, ...]
|
||||
DType = jnp.dtype
|
||||
|
||||
class Slab(NamedTuple):
|
||||
data: jax.Array
|
||||
@ -75,10 +76,10 @@ def _xmul(x: XInt, y: XInt) -> XInt:
|
||||
return x * y
|
||||
|
||||
def xadd(*xs: XInt) -> XInt:
|
||||
return reduce(_xadd, xs, 0)
|
||||
return reduce(_xadd, xs, typing.cast(XInt, 0))
|
||||
|
||||
def xmul(*xs: XInt) -> XInt:
|
||||
return reduce(_xmul, xs, 1)
|
||||
return reduce(_xmul, xs, typing.cast(XInt, 1))
|
||||
|
||||
def xsum(xs: Iterable[XInt]) -> XInt:
|
||||
return xadd(*list(xs))
|
||||
@ -86,7 +87,23 @@ def xsum(xs: Iterable[XInt]) -> XInt:
|
||||
def xprod(xs: Iterable[XInt]) -> XInt:
|
||||
return xmul(*list(xs))
|
||||
|
||||
def tile_shape(shape: DShape, dtype):
|
||||
def static_int(x: XInt) -> bool:
|
||||
return isinstance(core.get_aval(x), core.ConcreteArray)
|
||||
|
||||
def static_shape(s: DShape) -> bool:
|
||||
return all(map(static_int, s))
|
||||
|
||||
def assert_static_int(x: XInt) -> int:
|
||||
if not static_int(x):
|
||||
raise TypeError(f'{x} is not a static int')
|
||||
return int(x)
|
||||
|
||||
def assert_static_shape(s: DShape) -> SShape:
|
||||
if not static_shape(s):
|
||||
raise TypeError(f'{s} is not a static shape')
|
||||
return tuple(map(int, s))
|
||||
|
||||
def tile_shape(shape: DShape, dtype) -> SShape:
|
||||
# Units: (1, 1, ..., elements, 1)
|
||||
if len(shape) < 2:
|
||||
raise NotImplementedError('matrices or bust')
|
||||
@ -146,7 +163,8 @@ def reinterpret_cast(x: jax.Array, shape: SShape, dtype: DType):
|
||||
|
||||
def slab_read(slab, view, slice_base: DShape, slice_shape: SShape):
|
||||
view_tile_shape = tile_shape(view.shape, view.dtype)
|
||||
tiled_shape = map(xceil_div, slice_shape, view_tile_shape)
|
||||
tiled_shape = assert_static_shape(
|
||||
tuple(map(xceil_div, slice_shape, view_tile_shape)))
|
||||
slices = [
|
||||
jax.lax.dynamic_slice_in_dim(slab.data, addr, phrases)
|
||||
for addr, phrases in slab_slices(view, slice_base, slice_shape)]
|
||||
@ -290,9 +308,6 @@ def slab_download(slab, v):
|
||||
if not static_shape(v.shape): raise Exception
|
||||
return slab_read(slab, v, (0,) * v.ndim(), v.shape)
|
||||
|
||||
def static_shape(s: DShape) -> bool:
|
||||
return all(isinstance(core.get_aval(d), core.ConcreteArray) for d in s)
|
||||
|
||||
def slab_upload(slab, x):
|
||||
slab, xv = slab_alloc(slab, x.shape, x.dtype)
|
||||
slab = slab_write(slab, xv, (0,) * x.ndim, x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user