[mgpu] Pointwise min

PiperOrigin-RevId: 700175724
This commit is contained in:
Christos Perivolaropoulos 2024-11-25 19:12:56 -08:00 committed by jax authors
parent 627debc78b
commit f828f2d7d0
3 changed files with 60 additions and 16 deletions

View File

@ -622,10 +622,10 @@ class FragmentedArray:
reg, self.shape, new_layout, is_signed=self.is_signed
)
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False):
# If our layout is a splat, then we should either dispatch to a non-splat
# layout, or broadcast ourselves to the output shape first.
if isinstance(self.layout, WGSplatFragLayout):
if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout):
output_shape = self.shape
for i, o in enumerate(other):
if not isinstance(o, FragmentedArray):
@ -642,7 +642,7 @@ class FragmentedArray:
output_shape = np.broadcast_shapes(output_shape, o.shape)
# If we get here then we haven't found any non-splat layout.
return self.broadcast(output_shape)._pointwise(
op, *other, output_is_signed=output_is_signed
op, *other, output_is_signed=output_is_signed, force_no_dispatch=True,
)
other_arrs = []
@ -884,7 +884,17 @@ class FragmentedArray:
arith.maxsi if self.is_signed else arith.maxui, other
)
else:
return NotImplemented
return NotImplementedError
def min(self, other):
  """Return the pointwise minimum of this array and ``other``.

  Float element types are lowered through ``arith.minimumf``; integer
  element types use ``arith.minsi`` when ``self.is_signed`` is truthy and
  ``arith.minui`` otherwise. The actual application (and any handling of
  ``other``) is delegated to ``self._pointwise``.

  Args:
    other: The right-hand operand, forwarded unchanged to ``_pointwise``.

  Returns:
    Whatever ``self._pointwise`` returns for the selected op.

  Raises:
    NotImplementedError: If ``self.mlir_dtype`` is neither a float nor an
      integer type.
  """
  if ir.FloatType.isinstance(self.mlir_dtype):
    return self._pointwise(arith.minimumf, other)
  if ir.IntegerType.isinstance(self.mlir_dtype):
    return self._pointwise(
        arith.minsi if self.is_signed else arith.minui, other
    )
  # Raise instead of returning the exception class: a returned class would
  # silently flow into the caller as if it were the result of the op.
  raise NotImplementedError(
      f"min is not implemented for dtype {self.mlir_dtype}"
  )
def exp(self, *, approx: bool = False):
if not ir.FloatType.isinstance(self.mlir_dtype):

View File

@ -1256,6 +1256,7 @@ class FragmentedArrayTest(TestCase):
operator.add,
operator.mul,
operator.sub,
(lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum),
(lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
),
dtype=[jnp.float32, jnp.int32, jnp.uint32],
@ -1285,6 +1286,32 @@ class FragmentedArrayTest(TestCase):
ref_rhs = scalar_rhs or ref_x
np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs))
def test_minimum_np_compatibility(self):
  """Compare ``FragmentedArray.min`` with ``np.minimum`` on zeros and NaN.

  The numpy side builds a boolean "expectation" mask from ``np.minimum``
  applied to (-0, +0) and (NaN, 1) and asserts it is all-True. The kernel
  evaluates the analogous predicates with ``FragmentedArray.min`` and stores
  them as int8 so the two masks can be compared elementwise.
  """
  shape = (128, 128)
  ref_one = np.ones(shape).astype(np.float32)
  ref_negz = ref_one * -0.
  ref_posz = ref_one * 0.
  ref_nan = ref_one * np.nan
  # NOTE(review): `==` treats -0. and +0. as equal in numpy, so the zero
  # check cannot tell the two signs apart - confirm this is intended.
  zero_matches = np.minimum(ref_negz, ref_posz) == ref_negz
  nan_differs = np.minimum(ref_nan, ref_one) != ref_one
  expectation = zero_matches & nan_differs
  assert np.all(expectation), expectation

  def kernel(ctx, dst, _):
    f32 = ir.F32Type.get()

    def splat(value):
      return mgpu.FragmentedArray.splat(c(value, f32), (128, 128))

    neg_zero = splat(-0.)
    pos_zero = splat(0.)
    nan_arr = splat(np.nan)
    one_arr = splat(1.)
    # The predicates are built in the same left-to-right order as a single
    # chained `&` expression would evaluate them.
    zero_ok = neg_zero.min(pos_zero) == neg_zero
    nan_rhs_ok = one_arr.min(nan_arr) != one_arr
    partial_ok = zero_ok & nan_rhs_ok
    nan_lhs_ok = nan_arr.min(one_arr) != one_arr
    res = partial_ok & nan_lhs_ok
    i8 = ir.IntegerType.get_signless(8)
    res.astype(i8, is_signed=False).store_untiled(dst)

  out_shape = jax.ShapeDtypeStruct((128, 128), np.int8)
  run = mgpu.as_gpu_kernel(
      kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
  )
  result = run()
  # astype() uses extsi so i1=True becomes -1
  np.testing.assert_array_equal(result == -1, expectation)
@parameterized.product(
op=[operator.truediv, operator.floordiv, operator.mod],
dtype=[jnp.float32, jnp.int32, jnp.uint32],

View File

@ -83,6 +83,25 @@ class PallasCallTest(PallasTest):
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol)
@parameterized.named_parameters(
    ("add", lambda lhs, rhs: lhs + rhs),
    ("mul", lambda lhs, rhs: lhs * rhs),
    ("div", lambda lhs, rhs: lhs / rhs),
    ("min", lambda lhs, rhs: jnp.minimum(lhs, rhs)),
    ("max", lambda lhs, rhs: jnp.maximum(lhs, rhs)),
)
def test_binary_op(self, bop):
  """Run ``bop`` inside a Pallas kernel and compare with calling it directly.

  Both operands are float32 vectors of length 256; the second is the first
  plus one. The kernel output must equal ``bop`` applied outside the kernel.
  """
  def kernel(x_ref, y_ref, o_ref):
    # Load both full inputs, combine them, and write the full output.
    o_ref[...] = bop(x_ref[...], y_ref[...])

  # Equivalent to decorating `kernel` with a partial of pl.pallas_call.
  run = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
  )

  lhs = jnp.arange(256).astype(jnp.float32)
  rhs = lhs + 1
  expected = bop(lhs, rhs)
  np.testing.assert_array_equal(run(lhs, rhs), expected)
def test_add_first(self):
@functools.partial(
pl.pallas_call,
@ -111,18 +130,6 @@ class PallasCallTest(PallasTest):
x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
def test_add_xy(self):
  """Add two float32 vectors of length 256 in a Pallas kernel.

  The kernel output is compared with the same addition done outside it.
  """
  def kernel(x_ref, y_ref, o_ref):
    # Elementwise sum of the two full input refs.
    o_ref[...] = x_ref[...] + y_ref[...]

  # Equivalent to decorating `kernel` with a partial of pl.pallas_call.
  run = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
  )

  lhs = jnp.arange(256).astype(jnp.float32)
  rhs = lhs + 1
  np.testing.assert_array_equal(run(lhs, rhs), lhs + rhs)
def test_add_xy_indexed(self):
@functools.partial(
pl.pallas_call,