implement lazy sublanguage

Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts, transposes,
and reshapes that add singleton dimensions (as above) would force otherwise
lazy values to be materialized, while after this commit broadcasts,
transposes, and reshapes are all lazy operations that only update metadata
on their input rather than compiling and executing XLA computations and
producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, and
np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431. See https://github.com/google/jax/pull/1668 for
more.

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import namedtuple
import functools
import operator as op
from typing import Any, Callable

import numpy as onp

from .util import safe_map, safe_zip, unzip2, subvals
from .lib import xla_bridge as xb

map = safe_map
zip = safe_zip


### util

# TODO(mattjj): replace with dataclass when Python 2 support is removed
def taggedtuple(name, fields) -> Callable[..., Any]:
  """Lightweight version of namedtuple where equality depends on the type."""
  def __new__(cls, *xs):
    return tuple.__new__(cls, (cls,) + xs)
  def __str__(self):
    return '{}{}'.format(name, tuple.__str__(self[1:]))
  class_namespace = {'__new__': __new__, '__str__': __str__}
  for i, f in enumerate(fields):
    class_namespace[f] = property(op.itemgetter(i+1))
  return type(name, (tuple,), class_namespace)
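
# Example (illustrative only; not executed by this module): instances store
# their class as element 0, so equal fields under different tags compare
# unequal, unlike standard namedtuples.
#
#   A = taggedtuple('A', ['x', 'y'])
#   B = taggedtuple('B', ['x', 'y'])
#   a, b = A(1, 2), B(1, 2)
#   a.x       # 1, fields are read out via positional properties
#   str(a)    # 'A(1, 2)'
#   a == b    # False, and hash(a) != hash(b) in general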


### lazy sublanguage

# There are two components to a LazyExpr: an input and a reindexing
# specification. The input represents a base array to which the reindexing
# specification is applied.
#
# An input can represent an array constructor (Iota, Eye, etc.) or it can be an
# ArrayVar which encodes that the base array is some exogenous array value (from
# an environment with only a single value in it). These LazyExprs are attached
# to DeviceArrays, so when the input part of the expression is ArrayVar that
# basically means the associated device buffer represents the input, while if
# the input is an array constructor then the associated device_buffer field of
# the DeviceArray should be set to a DeviceConstant sentinel value. For the
# array constructor expressions:
#   * Iota builds a 1D sequence [0, 1, ..., N-1],
#   * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros
#     elsewhere (like numpy.eye),
#   * Tri builds a triangular matrix with ones on and below a diagonal and
#     zeros elsewhere (like numpy.tri), and
#   * Delta builds a Kronecker delta array with ones along its multidimensional
#     main diagonal and zeros elsewhere (for use in tensor contractions).
#
# The reindexing specification encodes the shape of the final result and a list
# of dimensions, which are integers or Nones. The integer entries take on
# values 0, 1, ..., R-1 where R is the rank of the input array, and encode
# where the axes of the input array are to be mapped in the final output. When
# an entry is None that indicates that the corresponding axis of the result is
# a broadcasted one.
#
# Here are some examples of lazy expressions and the arrays they represent:
#
# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
#          shape=(3, 4), dims=(0, None))
# DeviceArray([[0., 0., 0., 0.],
#              [1., 1., 1., 1.],
#              [2., 2., 2., 2.]], dtype=float32)
#
# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
#          shape=(4, 3), dims=(None, 0))
# DeviceArray([[0., 1., 2.],
#              [0., 1., 2.],
#              [0., 1., 2.],
#              [0., 1., 2.]], dtype=float32)
#
# For performance, some functions on lazy expressions accept None as an input
# to stand for the identity lazy expression.
#
# We use the `taggedtuple` class constructor, rather than standard namedtuples,
# because two namedtuple instances of different types but equal elements hash
# to the same value, e.g.
#   A = namedtuple('A', ['x', 'y'])
#   B = namedtuple('B', ['x', 'y'])
#   hash(A(1, 2)) == hash(B(1, 2))  # True
# but we want hashes to be sensitive to the type tag (while still being fast).

# pytype: disable=wrong-arg-count
LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
ArrayVar = taggedtuple('ArrayVar', [])
Iota = taggedtuple('Iota', ['dtype', 'size'])           # like np.arange(N)
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset'])  # like np.eye
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset'])  # like np.tri
Delta = taggedtuple('Delta', ['dtype', 'shape'])        # Kronecker delta arrays
# pytype: enable=wrong-arg-count


def array(shape):
  return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))

def iota(dtype, size):
  return LazyExpr(Iota(dtype, size), (size,), (0,))

def eye(dtype, shape, offset):
  assert len(shape) == 2
  return LazyExpr(Eye(dtype, shape, offset), shape, (0, 1))

def tri(dtype, shape, offset):
  assert len(shape) == 2
  return LazyExpr(Tri(dtype, shape, offset), shape, (0, 1))

def delta(dtype, shape):
  return LazyExpr(Delta(dtype, shape), shape, tuple(range(len(shape))))

def broadcast(lexpr, shape, broadcast_dimensions):
  new_dims = [None] * len(shape)
  for i, d in enumerate(broadcast_dimensions):
    new_dims[d] = lexpr.dims[i]
  return LazyExpr(lexpr.input, shape, tuple(new_dims))

def transpose(lexpr, perm):
  new_shape = tuple(lexpr.shape[i] for i in perm)
  new_dims = tuple(lexpr.dims[i] for i in perm)
  return LazyExpr(lexpr.input, new_shape, new_dims)
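
# Illustrative composition (a sketch; these calls are not made here): the two
# LazyExprs pictured in the comment block above can be built as
#
#   expr = broadcast(iota(onp.dtype('float32'), 3), (3, 4), (0,))
#   # LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
#   #          shape=(3, 4), dims=(0, None))
#   expr_t = transpose(expr, (1, 0))
#   # LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
#   #          shape=(4, 3), dims=(None, 0))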

def is_constant(lexpr):
  return lexpr is not None and type(lexpr.input) is not ArrayVar

def is_trivial(lexpr):
  return (type(lexpr.input) is ArrayVar and
          lexpr.dims == tuple(range(len(lexpr.shape))))
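
# For example (illustrative only):
#
#   is_constant(iota(onp.dtype('float32'), 3))    # True: input is a constructor
#   is_trivial(array((2, 3)))                     # True: identity reindexing
#   is_trivial(transpose(array((2, 3)), (1, 0)))  # False: dims are permuted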


def eval_lexpr(lexpr, x):
  """Evaluate a lazy expression using NumPy.

  Args:
    lexpr: the LazyExpr to evaluate.
    x: ndarray or None, representing the value of ArrayVar if present.

  Returns:
    An ndarray representing the value of the lazy expression.
  """
  if is_trivial(lexpr):
    return x

  input_, shape, dims = lexpr

  # first create a starting ndarray from input_
  t = type(input_)
  if t is ArrayVar:
    assert x is not None and type(x) is onp.ndarray
  elif t is Iota:
    assert x is None
    x = onp.arange(input_.size, dtype=input_.dtype)
  elif t is Eye:
    assert x is None
    N, M = input_.shape
    x = onp.eye(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Tri:
    assert x is None
    N, M = input_.shape
    x = onp.tri(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Delta:
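    # a Kronecker delta has ones exactly where all indices coincide: build an
    # index iota along each axis, compare adjacent pairs, and AND the results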
    ones = [1] * len(input_.shape)
    iotas = [onp.arange(d).reshape(subvals(ones, [(i, -1)]))
             for i, d in enumerate(input_.shape)]
    eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
    x = onp.asarray(functools.reduce(op.and_, eyes), input_.dtype)
  else:
    assert False

  # then apply the reindexing operation
  perm = [d for d in dims if d is not None]
  if perm != list(range(len(perm))):
    x = onp.transpose(x, perm)
  if shape != x.shape:
    in_shape = [1 if d is None else s for d, s in zip(dims, shape)]
    x = onp.broadcast_to(onp.reshape(x, in_shape), shape)

  return x
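
# Illustrative only: evaluating the broadcast iota pictured earlier reproduces
# the corresponding array with plain NumPy:
#
#   expr = broadcast(iota(onp.dtype('float32'), 3), (3, 4), (0,))
#   eval_lexpr(expr, None)
#   # array([[0., 0., 0., 0.],
#   #        [1., 1., 1., 1.],
#   #        [2., 2., 2., 2.]], dtype=float32)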


def stage_lexpr(c, lexpr, x):
  """Stage a lazy expression into an XLA computation.

  Args:
    c: XLA ComputationBuilder into which to stage the expression.
    lexpr: a LazyExpr to evaluate (or None for the identity expression).
    x: XlaOp or None, representing the value of ArrayVar if present.

  Returns:
    An XlaOp representing the value of the lazy expression.
  """
  if lexpr is None or is_trivial(lexpr):
    return x

  input_, shape, dims = lexpr

  # first create a starting XlaOp from input_
  t = type(input_)
  if t is ArrayVar:
    assert x is not None
  elif t is Iota:
    assert x is None
    x = c.Iota(input_.dtype, input_.size)
  elif t is Eye:
    assert x is None
    N, M = input_.shape
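    # an N x M eye with offset k has ones where col == row + k, so compare a
    # broadcasted row-index iota (shifted by the offset) against a
    # column-index iota, then cast the boolean mask to the requested dtype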
    bool_eye = c.Eq(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0),
                          c.Constant(onp.array(input_.offset, onp.int32))),
                    c.BroadcastedIota(onp.int32, (N, M), 1))
    x = c.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype))
  elif t is Tri:
    assert x is None
    N, M = input_.shape
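    # np.tri(N, M, k) has ones where col <= row + k, hence Ge on the same pair
    # of broadcasted iotas used for Eye above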
    bool_tri = c.Ge(c.Add(c.BroadcastedIota(onp.int32, (N, M), 0),
                          c.Constant(onp.array(input_.offset, onp.int32))),
                    c.BroadcastedIota(onp.int32, (N, M), 1))
    x = c.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype))
  elif t is Delta:
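    # a Kronecker delta has ones exactly where all indices coincide: compare
    # adjacent index iotas pairwise and AND the boolean masks together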
    etype = xb.dtype_to_etype(input_.dtype)
    iotas = [c.BroadcastedIota(onp.uint32, input_.shape, i)
             for i in range(len(input_.shape))]
    eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
    x = c.ConvertElementType(functools.reduce(c.And, eyes), etype)
  else:
    assert False

  # then apply the reindexing operation
  bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None)
  if tuple(perm) != tuple(range(len(perm))):
    x = c.Transpose(x, perm)
  if shape != c.GetShape(x).dimensions():
    x = c.BroadcastInDim(x, shape, bcast_dims)

  return x
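
# A hypothetical staging sketch (the builder-constructor name below is an
# assumption, not something defined or called in this module):
#
#   c = xb.make_computation_builder('lazy_force')  # assumed xla_bridge helper
#   out = stage_lexpr(c, iota(onp.dtype('float32'), 3), None)
#   # `out` is an XlaOp computing [0., 1., 2.] when the built computation runs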