Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00.
[mgpu] Pointwise min

PiperOrigin-RevId: 700175724

commit f828f2d7d0 (parent 627debc78b)
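This change adds a pointwise `min` to Mosaic GPU's `FragmentedArray`, mirroring the existing `max`: floats lower to `arith.minimumf`, integers to `arith.minsi` or `arith.minui` depending on signedness. To support it, `_pointwise` gains a `force_no_dispatch` flag that guards against infinite recursion when all operands have splat layouts. Tests are extended in both the Mosaic GPU and Pallas suites, and the now-redundant `test_add_xy` is removed.

For orientation, here is a minimal user-level sketch of what this enables, modeled on the `test_binary_op("min")` case added below. Backend selection (which the test harness handles) is elided, and the names follow the test rather than a guaranteed public API:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
)
def kernel(x_ref, y_ref, o_ref):
  # Inside a Mosaic GPU kernel, jnp.minimum now lowers to the new
  # pointwise FragmentedArray.min.
  o_ref[...] = jnp.minimum(x_ref[...], y_ref[...])

x = jnp.arange(256).astype(jnp.float32)
y = x + 1
np.testing.assert_array_equal(kernel(x, y), jnp.minimum(x, y))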
@@ -622,10 +622,10 @@ class FragmentedArray:
         reg, self.shape, new_layout, is_signed=self.is_signed
     )
 
-  def _pointwise(self, op, *other, output_is_signed: bool | None = None):
+  def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False):
     # If our layout is a splat, then we should either dispatch to a non-splat
     # layout, or broadcast ourselves to the output shape first.
-    if isinstance(self.layout, WGSplatFragLayout):
+    if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout):
       output_shape = self.shape
       for i, o in enumerate(other):
         if not isinstance(o, FragmentedArray):
@@ -642,7 +642,7 @@ class FragmentedArray:
           output_shape = np.broadcast_shapes(output_shape, o.shape)
       # If we get here then we haven't found any non-splat layout.
       return self.broadcast(output_shape)._pointwise(
-          op, *other, output_is_signed=output_is_signed
+          op, *other, output_is_signed=output_is_signed, force_no_dispatch=True,
       )
 
     other_arrs = []
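Why the flag is needed: when every operand is a splat, there is no non-splat layout to dispatch to, `np.broadcast_shapes` leaves the shape as-is, and `self.broadcast(output_shape)` returns yet another splat, so the recursive `_pointwise` call would take this same branch again, indefinitely. Passing `force_no_dispatch=True` on the recursive call forces it through to the regular register-wise path. A runnable toy model of the guard (all names here are invented for illustration; this is not the Mosaic GPU implementation):

class Splat:
  """Toy 'splat' array: one scalar logically broadcast to a shape."""
  def __init__(self, value, shape):
    self.value, self.shape = value, shape

  def broadcast(self, shape):
    return Splat(self.value, shape)  # a splat stays a splat

def pointwise(a, op, b, *, force_no_dispatch=False):
  if not force_no_dispatch and isinstance(a, Splat) and isinstance(b, Splat):
    # Both operands are splats: broadcast() changes nothing, so without
    # force_no_dispatch=True this call would re-enter this branch forever.
    return pointwise(a.broadcast(b.shape), op, b, force_no_dispatch=True)
  return Splat(op(a.value, b.value), a.shape)  # stand-in for the real path

print(pointwise(Splat(1.0, (4,)), min, Splat(2.0, (4,))).value)  # 1.0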
@@ -884,7 +884,17 @@ class FragmentedArray:
           arith.maxsi if self.is_signed else arith.maxui, other
       )
     else:
-      return NotImplemented
+      return NotImplementedError
+
+  def min(self, other):
+    if ir.FloatType.isinstance(self.mlir_dtype):
+      return self._pointwise(arith.minimumf, other)
+    elif ir.IntegerType.isinstance(self.mlir_dtype):
+      return self._pointwise(
+          arith.minsi if self.is_signed else arith.minui, other
+      )
+    else:
+      return NotImplementedError
 
   def exp(self, *, approx: bool = False):
     if not ir.FloatType.isinstance(self.mlir_dtype):
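Note the float lowering: `arith.minimumf` (like `arith.maximumf` in `max`) has IEEE-754 `minimum` semantics, i.e. NaNs propagate, which is what `np.minimum` does; the NaN-discarding alternative in the arith dialect would be `arith.minnumf`. The new `test_minimum_np_compatibility` below pins down exactly this behavior. For reference, the NumPy semantics being matched:

import numpy as np

# np.minimum propagates a NaN from either operand; FragmentedArray.min
# is tested against this behavior below.
assert np.isnan(np.minimum(np.nan, 1.0))
assert np.isnan(np.minimum(1.0, np.nan))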
@@ -1256,6 +1256,7 @@ class FragmentedArrayTest(TestCase):
           operator.add,
           operator.mul,
           operator.sub,
+          (lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum),
           (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
       ),
       dtype=[jnp.float32, jnp.int32, jnp.uint32],
@@ -1285,6 +1286,32 @@ class FragmentedArrayTest(TestCase):
     ref_rhs = scalar_rhs or ref_x
     np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs))
 
+  def test_minimum_np_compatibility(self):
+    one = np.ones((128, 128)).astype(np.float32)
+    negz = one * -0.
+    posz = one * 0.
+    nan = one * np.nan
+    expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one)
+    assert np.all(expectation), expectation
+
+    def kernel(ctx, dst, _):
+      f32 = ir.F32Type.get()
+      splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128))
+      negz = splat(-0.)
+      posz = splat(0.)
+      nan = splat(np.nan)
+      one = splat(1.)
+      res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one)
+      i8 = ir.IntegerType.get_signless(8)
+      res.astype(i8, is_signed=False).store_untiled(dst)
+
+    out_shape = jax.ShapeDtypeStruct((128, 128), np.int8)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
+    )()
+    # astype() uses extsi so i1=True becomes -1
+    np.testing.assert_array_equal(result == -1, expectation)
+
   @parameterized.product(
       op=[operator.truediv, operator.floordiv, operator.mod],
       dtype=[jnp.float32, jnp.int32, jnp.uint32],
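A detail worth flagging in the test above: the boolean (`i1`) result is widened to `i8` before being stored, and `astype()` sign-extends (`extsi`). Sign-extending the 1-bit value 1 replicates the bit into `0xFF`, i.e. `-1` as a signed byte, which is why the host-side check compares against `-1`. In NumPy terms:

import numpy as np

# Sign-extending the 1-bit value 1 yields eight one-bits (0xFF), which
# reads back as -1 when interpreted as a signed 8-bit integer.
assert np.uint8(0xFF).astype(np.int8) == -1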
@@ -83,6 +83,25 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol)
 
+  @parameterized.named_parameters(
+      ("add", lambda x, y: x + y),
+      ("mul", lambda x, y: x * y),
+      ("div", lambda x, y: x / y),
+      ("min", lambda x, y: jnp.minimum(x, y)),
+      ("max", lambda x, y: jnp.maximum(x, y)),
+  )
+  def test_binary_op(self, bop):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = bop(x_ref[...], y_ref[...])
+
+    x = jnp.arange(256).astype(jnp.float32)
+    y = x + 1
+    np.testing.assert_array_equal(kernel(x, y), bop(x, y))
+
   def test_add_first(self):
     @functools.partial(
         pl.pallas_call,
@@ -111,18 +130,6 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
 
-  def test_add_xy(self):
-    @functools.partial(
-        pl.pallas_call,
-        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
-    )
-    def kernel(x_ref, y_ref, o_ref):
-      o_ref[...] = x_ref[...] + y_ref[...]
-
-    x = jnp.arange(256).astype(jnp.float32)
-    y = x + 1
-    np.testing.assert_array_equal(kernel(x, y), x + y)
-
   def test_add_xy_indexed(self):
     @functools.partial(
         pl.pallas_call,
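The removed `test_add_xy` is subsumed by the new parameterized `test_binary_op` above: its body is identical to the `("add", lambda x, y: x + y)` case.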