rocm_jax/tests/custom_object_test.py
Peter Hawkins a87b21148c [MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context.

Change it to

```
def rule(ctx, *args, **jaxpr_params):
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
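For illustration only, here is a minimal sketch of what the migration looks like for a hypothetical identity-style primitive (`noop_p` is not part of this change; the avals are read from the new per-rule context):

```
from jax import core
from jax.interpreters import mlir

noop_p = core.Primitive('noop')  # hypothetical primitive, for illustration only

# Old-style rule: avals are explicit arguments and `ctx` is the module-wide context.
def _noop_lowering_old(ctx, avals_in, avals_out, x):
  return [x]

# New-style rule: `ctx` is a per-rule context; the same information is available
# as ctx.module_context, ctx.avals_in, and ctx.avals_out.
def _noop_lowering_new(ctx, x):
  assert len(ctx.avals_in) == len(ctx.avals_out) == 1
  return [x]

mlir.register_lowering(noop_p, _noop_lowering_new)
```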

PiperOrigin-RevId: 416698663
2021-12-15 19:06:58 -08:00


# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest, parameterized
import numpy as np
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
xc = xla_client
xb = xla_bridge
from jax.config import config
config.parse_flags_with_absl()
# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the
# dictionaries associated with the following objects.
# Define a sparse array data structure. The important feature here is that
# it is a jaxpr object that is backed by two device buffers.
class SparseArray:
"""Simple sparse COO array data structure."""
def __init__(self, aval, data, indices):
self.aval = aval
self.shape = aval.shape
self.data = data
self.indices = indices
@property
def index_dtype(self):
return self.indices.dtype
@property
def dtype(self):
return self.data.dtype
@property
def nnz(self):
return self.data.shape[0]
def __repr__(self):
return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
class AbstractSparseArray(core.ShapedArray):
__slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval']
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
named_shape=None):
super().__init__(shape, dtypes.canonicalize_dtype(dtype))
named_shape = {} if named_shape is None else named_shape
self.index_dtype = index_dtype
self.nnz = nnz
self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
weak_type, named_shape)
self.indices_aval = core.ShapedArray(
(nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
named_shape=named_shape)
def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
weak_type=None, named_shape=None):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if index_dtype is None:
      index_dtype = self.index_dtype
if nnz is None:
nnz = self.nnz
if weak_type is None:
weak_type = self.weak_type
if named_shape is None:
named_shape = self.named_shape
return AbstractSparseArray(
shape, dtype, index_dtype, nnz, weak_type, named_shape)
def strip_weak_type(self):
return self
@core.aval_property
def data(self):
return sp_data_p.bind(self)
@core.aval_property
def indices(self):
return sp_indices_p.bind(self)
class ConcreteSparseArray(AbstractSparseArray):
pass
def sparse_array_result_handler(device, aval):
def build_sparse_array(data_buf, indices_buf):
data = device_array.make_device_array(aval.data_aval, device, data_buf)
indices = device_array.make_device_array(aval.indices_aval, device, indices_buf)
return SparseArray(aval, data, indices)
return build_sparse_array
def sparse_array_shape_handler(a):
return (
xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape),
xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape),
)
def sparse_array_device_put_handler(a, device):
return (
xb.get_device_backend(device).buffer_from_pyval(a.data, device),
xb.get_device_backend(device).buffer_from_pyval(a.indices, device)
)
def sparse_array_constant_handler(c, val, canonicalize_dtypes):
return (
xla.pyval_to_ir_constant(val.data, canonicalize_dtypes),
xla.pyval_to_ir_constant(val.indices, canonicalize_dtypes)
)
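# Register the handlers that let JAX's core, dispatch, and XLA machinery
# abstract, canonicalize, device-put, and rebuild SparseArray values (each
# backed by two device buffers).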
core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
dispatch.device_put_handlers[SparseArray] = sparse_array_device_put_handler
dispatch.result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
def sparse_array_mlir_type_handler(a):
return (
ir.RankedTensorType.get(
a.data_aval.shape, mlir.dtype_to_ir_type(a.data_aval.dtype)),
ir.RankedTensorType.get(
a.indices_aval.shape, mlir.dtype_to_ir_type(a.indices_aval.dtype)),
)
mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler
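# Primitives that expose the indices and data buffers of a SparseArray to
# traced code; they back the aval_property accessors defined above.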
sp_indices_p = core.Primitive('sp_indices')
@sp_indices_p.def_impl
def _sp_indices_impl(mat):
return mat.indices
@sp_indices_p.def_abstract_eval
def _sp_indices_abstract_eval(mat):
return mat.indices_aval
def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
return [indices]
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
def _sp_indices_mhlo_lowering(ctx, data_and_indices):
return [data_and_indices[1]]
mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
sp_data_p = core.Primitive('sp_data')
@sp_data_p.def_impl
def _sp_data_impl(mat):
return mat.data
@sp_data_p.def_abstract_eval
def _sp_data_abstract_eval(mat):
return mat.data_aval
def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
return [data]
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
xla.register_translation(sp_data_p, _sp_data_translation_rule)
def _sp_data_mhlo_lowering(ctx, data_and_indices):
return [data_and_indices[0]]
mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
def identity(x):
return identity_p.bind(x)
identity_p = core.Primitive('identity')
@identity_p.def_impl
def _identity_impl(mat):
return mat
@identity_p.def_abstract_eval
def _identity_abstract_eval(mat):
return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
xla.register_translation(
identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
new_style=True))
mlir.register_lowering(
identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))
def split(x):
return split_p.bind(x)
split_p = core.Primitive('split')
split_p.multiple_results = True
@split_p.def_impl
def _split_impl(mat):
return mat, mat
@split_p.def_abstract_eval
def _split_abstract_eval(mat):
m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz)
return m, m
xla.register_translation(
split_p, xla.lower_fun(_split_impl, multiple_results=True, new_style=True))
def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)
size = int(np.prod(shape))
if 0 < nnz < 1:
nnz = nnz * size
nnz = int(nnz)
if nnz == 0:
mat = np.zeros_like(mat)
elif nnz < size:
# TODO(jakevdp): do we care about duplicates?
cutoff = np.sort(mat.ravel())[nnz]
mat[mat >= cutoff] = 0
nz = (mat != 0)
data = jnp.array(mat[nz])
indices = jnp.array(np.where(nz)).T
aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices))
return SparseArray(aval, data, indices)
def matvec(mat, v):
v = jnp.asarray(v)
assert v.ndim == 1
assert len(mat.shape) == 2
assert v.shape[0] == mat.shape[1]
rows = mat.indices[:, 0]
cols = mat.indices[:, 1]
dv = mat.data * v[cols]
return jnp.zeros(mat.shape[0], dtype=dv.dtype).at[rows].add(dv)
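# A second custom object: an Empty value backed by zero device buffers, used
# below to check that such objects lower to nothing.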
class Empty:
def __init__(self, aval):
self.aval = aval
class AbstractEmpty(core.AbstractValue):
def join(self, other):
assert isinstance(other, self.__class__), other
return self
def __hash__(self):
return hash(())
def __eq__(self, other):
return isinstance(other, AbstractEmpty)
class ConcreteEmpty(AbstractEmpty):
pass
core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
dispatch.device_put_handlers[Empty] = lambda _, __: ()
dispatch.result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.num_buffers_handlers[AbstractEmpty] = lambda _: 0
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
class CustomObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
"compile": compile, "primitive": primitive}
for primitive in [True, False]
for compile in [True, False]))
def testSparseIdentity(self, compile, primitive):
f = identity if primitive else (lambda x: x)
f = jit(f) if compile else f
rng = jtu.rand_default(self.rng())
M = make_sparse_array(rng, (10,), jnp.float32)
M2 = f(M)
jaxpr = make_jaxpr(f)(M).jaxpr
core.check_jaxpr(jaxpr)
self.assertEqual(M.dtype, M2.dtype)
self.assertEqual(M.index_dtype, M2.index_dtype)
self.assertAllClose(M.data, M2.data)
self.assertAllClose(M.indices, M2.indices)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}".format(compile),
"compile": compile}
for compile in [True, False]))
def testSparseSplit(self, compile):
f = jit(split) if compile else split
rng = jtu.rand_default(self.rng())
M = make_sparse_array(rng, (10,), jnp.float32)
M2, M3 = f(M)
jaxpr = make_jaxpr(f)(M).jaxpr
core.check_jaxpr(jaxpr)
for MM in M2, M3:
self.assertEqual(M.dtype, MM.dtype)
self.assertEqual(M.index_dtype, MM.index_dtype)
self.assertArraysEqual(M.data, MM.data)
self.assertArraysEqual(M.indices, MM.indices)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
"compile": compile, "primitive": primitive}
for primitive in [True, False]
for compile in [True, False]))
def testSparseLaxLoop(self, compile, primitive):
rng = jtu.rand_default(self.rng())
f = identity if primitive else (lambda x: x)
f = jit(f) if compile else f
body_fun = lambda _, A: f(A)
M = make_sparse_array(rng, (10,), jnp.float32)
lax.fori_loop(0, 10, body_fun, M)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_attr={}".format(attr), "attr": attr}
for attr in ["data", "indices"]))
def testSparseAttrAccess(self, attr):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
f = lambda x: getattr(x, attr)
self._CompileAndCheck(f, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(3, 3), (2, 6), (6, 2)]
for dtype in jtu.dtypes.floating))
def testSparseMatvec(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)]
self._CompileAndCheck(matvec, args_maker)
def testLowerToNothing(self):
empty = Empty(AbstractEmpty())
jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
core.check_jaxpr(jaxpr)
# cannot return a unit, because CompileAndCheck assumes array output.
testfunc = lambda e: None
args_maker = lambda: [empty]
self._CompileAndCheck(testfunc, args_maker)
def testConstantHandler(self):
def make_const_array():
data = np.arange(3.0)
indices = np.arange(3)[:, None]
shape = (5,)
aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices))
return SparseArray(aval, data, indices)
out1 = make_const_array()
out2 = jit(make_const_array)()
self.assertArraysEqual(out1.data, out2.data)
self.assertArraysEqual(out1.indices, out2.indices)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())