mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 19:46:07 +00:00
245 lines
8.8 KiB
Python
245 lines
8.8 KiB
Python
# Copyright 2019 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from collections import namedtuple
|
|
import functools
|
|
import operator as op
|
|
from typing import Any, Callable
|
|
|
|
import numpy as onp
|
|
|
|
from .util import safe_map, safe_zip, unzip2, subvals
|
|
from .lib import xla_bridge as xb
|
|
|
|
map = safe_map
|
|
zip = safe_zip
|
|
|
|
|
|
### util
|
|
|
|
# TODO(mattjj): replace with dataclass when Python 2 support is removed
|
|
def taggedtuple(name, fields) -> Callable[..., Any]:
|
|
"""Lightweight version of namedtuple where equality depends on the type."""
|
|
def __new__(cls, *xs):
|
|
return tuple.__new__(cls, (cls,) + xs)
|
|
def __str__(self):
|
|
return '{}{}'.format(name, tuple.__str__(self[1:]))
|
|
class_namespace = {'__new__' : __new__, '__str__': __str__}
|
|
for i, f in enumerate(fields):
|
|
class_namespace[f] = property(op.itemgetter(i+1))
|
|
return type(name, (tuple,), class_namespace)
|
|
|
|
|
|
### lazy sublanguage
|
|
|
|
# There are two components to a LazyExpr: an input and a reindexing
|
|
# specification. The input represents a base array to which the reindexing
|
|
# specification is applied.
|
|
#
|
|
# An input can represent an array constructor (Iota, Eye, etc.) or it can be an
|
|
# ArrayVar which encodes that the base array is some exogenous array value (from
|
|
# an environment with only a single value in it). These LazyExprs are attached
|
|
# to DeviceArrays, so when the input part of the expression is ArrayVar that
|
|
# basically means the associated device buffer represents the input, while if
|
|
# the input is an array constructor then the associated device_buffer field of
|
|
# the DeviceArray should be set to a DeviceConstant sentinel value. For the
|
|
# array constructor expressions:
|
|
# * Iota builds a 1D sequence [0, 1, ..., N-1],
|
|
# * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros
|
|
# elsewhere (like numpy.eye),
|
|
# * Tri builds a triangular matrix with ones on and below a diagonal and zeros
|
|
# elsewhere (like numpy.tri), and
|
|
# * Delta builds a Kronecker delta array with ones along its multidimensional
|
|
# main diagonal and zeros elsewhere (for use in tensor contractions).
|
|
#
|
|
# The reindexing specification encodes the shape of the final result and a list
|
|
# of dimensions, which are integers or Nones. The integer entries take on values
|
|
# 0, 1, ..., R-1 where R is the rank of the input array, and encode where the
|
|
# axes of the input array are to be mapped in the final output. When an entry is
|
|
# None that indicates that the corresponding axis of the result is a broadcasted
|
|
# one.
|
|
#
|
|
# Here are some examples of lazy expressions and the arrays they represent:
|
|
#
|
|
# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
|
|
# shape=(3, 4), dims=(0, None))
|
|
# DeviceArray([[0., 0., 0., 0.],
|
|
# [1., 1., 1., 1.],
|
|
# [2., 2., 2., 2.]], dtype=float32)
|
|
#
|
|
# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
|
|
# shape=(4, 3), dims=(None, 0))
|
|
# DeviceArray([[0., 1., 2.],
|
|
# [0., 1., 2.],
|
|
# [0., 1., 2.],
|
|
# [0., 1., 2.]], dtype=float32)
|
|
#
|
|
# For performance, some functions on lazy expressions accept None as an input to
|
|
# stand for the identity lazy expression.
|
|
#
|
|
# We use the `taggedtuple` class constructor, rather than standard namedtuples,
|
|
# because two namedtuple instances of different types but equal elements hash to
|
|
# the same value, e.g.
|
|
# A = namedtuple('A', ['x', 'y'])
|
|
# B = namedtuple('B', ['x', 'y'])
|
|
# hash(A(1, 2)) == hash(B(1, 2)) # True
|
|
# but we want hashes to be sensitive to the type tag (while still being fast).
|
|
|
|
# pytype: disable=wrong-arg-count
|
|
LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
|
|
ArrayVar = taggedtuple('ArrayVar', [])
|
|
Iota = taggedtuple('Iota', ['dtype', 'size']) # like np.arange(N)
|
|
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset']) # like np.eye
|
|
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset']) # like np.tri
|
|
Delta = taggedtuple('Delta', ['dtype', 'shape']) # kronecker delta arrays
|
|
# pytype: enable=wrong-arg-count
|
|
|
|
def array(shape):
|
|
return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))
|
|
|
|
def iota(dtype, size):
|
|
return LazyExpr(Iota(dtype, size), (size,), (0,))
|
|
|
|
def eye(dtype, shape, offset):
|
|
assert len(shape) == 2
|
|
return LazyExpr(Eye(dtype, shape, offset), shape, (0, 1))
|
|
|
|
def tri(dtype, shape, offset):
|
|
assert len(shape) == 2
|
|
return LazyExpr(Tri(dtype, shape, offset), shape, (0, 1))
|
|
|
|
def delta(dtype, shape):
|
|
return LazyExpr(Delta(dtype, shape), shape, tuple(range(len(shape))))
|
|
|
|
def broadcast(lexpr, shape, broadcast_dimensions):
|
|
new_dims = [None] * len(shape)
|
|
for i, d in enumerate(broadcast_dimensions):
|
|
new_dims[d] = lexpr.dims[i]
|
|
return LazyExpr(lexpr.input, shape, tuple(new_dims))
|
|
|
|
def transpose(lexpr, perm):
|
|
new_shape = tuple(lexpr.shape[i] for i in perm)
|
|
new_dims = tuple(lexpr.dims[i] for i in perm)
|
|
return LazyExpr(lexpr.input, new_shape, new_dims)
|
|
|
|
def is_constant(lexpr):
|
|
return lexpr is not None and type(lexpr.input) is not ArrayVar
|
|
|
|
def is_trivial(lexpr):
|
|
return (type(lexpr.input) is ArrayVar and
|
|
lexpr.dims == tuple(range(len(lexpr.shape))))
|
|
|
|
|
|
def eval_lexpr(lexpr, x):
|
|
"""Evaluate a lazy expression using NumPy.
|
|
Args:
|
|
lexpr: the LazyExpr to evaluate.
|
|
x: ndarray or None, representing the value of ArrayVar if present.
|
|
Returns:
|
|
An ndarray representing the value of the lazy expression.
|
|
"""
|
|
if is_trivial(lexpr):
|
|
return x
|
|
|
|
input_, shape, dims = lexpr
|
|
|
|
# first create a starting ndarray from input_
|
|
t = type(input_)
|
|
if t is ArrayVar:
|
|
assert x is not None and type(x) is onp.ndarray
|
|
elif t is Iota:
|
|
assert x is None
|
|
x = onp.arange(input_.size, dtype=input_.dtype)
|
|
elif t is Eye:
|
|
assert x is None
|
|
N, M = input_.shape
|
|
x = onp.eye(N, M, dtype=input_.dtype, k=input_.offset)
|
|
elif t is Tri:
|
|
assert x is None
|
|
N, M = input_.shape
|
|
x = onp.tri(N, M, dtype=input_.dtype, k=input_.offset)
|
|
elif t is Delta:
|
|
ones = [1] * len(input_.shape)
|
|
iotas = [onp.arange(d).reshape(subvals(ones, [(i, -1)]))
|
|
for i, d in enumerate(input_.shape)]
|
|
eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
|
|
x = onp.asarray(functools.reduce(op.and_, eyes), input_.dtype)
|
|
else:
|
|
assert False
|
|
|
|
# then apply the reindexing operation
|
|
perm = [d for d in dims if d is not None]
|
|
if perm != list(range(len(perm))):
|
|
x = onp.transpose(x, perm)
|
|
if shape != x.shape:
|
|
in_shape = [1 if d is None else s for d, s in zip(dims, shape)]
|
|
x = onp.broadcast_to(onp.reshape(x, in_shape), shape)
|
|
|
|
return x
|
|
|
|
|
|
def stage_lexpr(c, lexpr, x):
|
|
"""Stage a lazy expression into an XLA computation.
|
|
Args:
|
|
c: XLA ComputationBuilder into which to stage the expression.
|
|
lexpr: a LazyExpr to evaluate (or None for the identity expression).
|
|
x: XlaOp or None, representing the value of ArrayVar if present.
|
|
Returns:
|
|
An XlaOp representing the value of the lazy expression.
|
|
"""
|
|
if lexpr is None or is_trivial(lexpr):
|
|
return x
|
|
|
|
input_, shape, dims = lexpr
|
|
|
|
# first create a starting XlaOp from input_
|
|
t = type(input_)
|
|
if t is ArrayVar:
|
|
assert x is not None
|
|
elif t is Iota:
|
|
assert x is None
|
|
x = c.Iota(input_.dtype, input_.size)
|
|
elif t is Eye:
|
|
assert x is None
|
|
N, M = input_.shape
|
|
bool_eye = c.Eq(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0),
|
|
c.Constant(onp.array(input_.offset, onp.int32))),
|
|
c.BroadcastedIota(onp.int32, (N, M), 1))
|
|
x = c.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype))
|
|
elif t is Tri:
|
|
assert x is None
|
|
N, M = input_.shape
|
|
bool_tri = c.Ge(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0),
|
|
c.Constant(onp.array(input_.offset, onp.int32))),
|
|
c.BroadcastedIota(onp.int32, (N, M), 1))
|
|
x = c.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype))
|
|
elif t is Delta:
|
|
etype = xb.dtype_to_etype(input_.dtype)
|
|
iotas = [c.BroadcastedIota(onp.uint32, input_.shape, i)
|
|
for i in range(len(input_.shape))]
|
|
eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
|
|
x = c.ConvertElementType(functools.reduce(c.And, eyes), etype)
|
|
else:
|
|
assert False
|
|
|
|
# then apply the operations encoded in reindex
|
|
bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None)
|
|
if tuple(perm) != tuple(range(len(perm))):
|
|
x = c.Transpose(x, perm)
|
|
if shape != c.GetShape(x).dimensions():
|
|
x = c.BroadcastInDim(x, shape, bcast_dims)
|
|
|
|
return x
|