update jaxlib version in readme

fixes #1297
will update notebooks in #1260
This commit is contained in:
Matthew Johnson 2019-09-02 07:25:06 -07:00
parent 110634d50d
commit c760b05f9b
4 changed files with 74 additions and 339 deletions

View File

@ -122,7 +122,7 @@ PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37
CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.23-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.27-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax # install jax
```

View File

@ -134,12 +134,7 @@ def _nan_like(c, operand):
nan = c.Constant(onp.array(onp.nan, dtype=dtype))
return c.Broadcast(nan, shape.dimensions())
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "potrf"):
_cpu_potrf = lapack.potrf
else:
_cpu_potrf = _unpack_tuple(lapack.jax_potrf, 2)
_cpu_potrf = lapack.potrf
def cholesky_cpu_translation_rule(c, operand):
shape = c.GetShape(operand)
@ -181,12 +176,7 @@ def eig_abstract_eval(operand):
raise NotImplementedError
return w, vl, vr
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "geev"):
_cpu_geev = lapack.geev
else:
_cpu_geev = _unpack_tuple(lapack.jax_geev, 4)
_cpu_geev = lapack.geev
def eig_cpu_translation_rule(c, operand):
shape = c.GetShape(operand)
@ -294,21 +284,13 @@ eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "syevd"):
_cpu_syevd = lapack.syevd
else:
_cpu_syevd = _unpack_tuple(lapack.jax_syevd, 3)
_cpu_syevd = lapack.syevd
xla.backend_specific_translations['cpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, _cpu_syevd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, cusolver.syevd)
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, cusolver.syevd)
batching.primitive_batchers[eigh_p] = eigh_batching_rule
@ -612,12 +594,7 @@ xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "getrf"):
_cpu_getrf = lapack.getrf
else:
_cpu_getrf = _unpack_tuple(lapack.jax_getrf, 3)
_cpu_getrf = lapack.getrf
xla.backend_specific_translations['cpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, _cpu_getrf)
@ -803,18 +780,10 @@ ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "gesdd"):
_cpu_gesdd = lapack.gesdd
else:
_cpu_gesdd = _unpack_tuple(lapack.jax_gesdd, 4)
_cpu_gesdd = lapack.gesdd
xla.backend_specific_translations['cpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, _cpu_gesdd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, cusolver.gesvd)
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, cusolver.gesvd)

View File

@ -41,16 +41,5 @@ from jaxlib import xla_client
from jaxlib import xrt
from jaxlib import lapack
# TODO(phawkins): make the import unconditional when the minimum Jaxlib version
# has been increased to 0.1.23.
try:
from jaxlib import pytree
except ImportError:
pytree = None
# TODO(phawkins): make the import unconditional when the minimum Jaxlib version
# has been increased to 0.1.23.
try:
from jaxlib import cusolver
except ImportError:
cusolver = None
from jaxlib import pytree
from jaxlib import cusolver

View File

@ -46,308 +46,85 @@ from .lib import pytree
from .util import unzip2, partial, safe_map
# TODO(phawkins): use the first case unconditionally when the minimum Jaxlib
# version has been increased to 0.1.23.
if pytree:
def tree_map(f, tree):
"""Map a function over a pytree to produce a new pytree.
def tree_map(f, tree):
"""Map a function over a pytree to produce a new pytree.
Args:
f: function to be applied at each leaf.
tree: a pytree to be mapped over.
Args:
f: function to be applied at each leaf.
tree: a pytree to be mapped over.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
"""
leaves, treedef = pytree.flatten(tree)
return treedef.unflatten(map(f, leaves))
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
"""
leaves, treedef = pytree.flatten(tree)
return treedef.unflatten(map(f, leaves))
def tree_multimap(f, tree, *rest):
"""Map a multi-input function over pytree args to produce a new pytree.
def tree_multimap(f, tree, *rest):
"""Map a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes `1 + len(rest)` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
*rest: a tuple of pytrees, each of which has the same structure as tree or
or has tree as a prefix.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding nodes in
`rest`.
"""
leaves, treedef = pytree.flatten(tree)
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
Args:
f: function that takes `1 + len(rest)` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
*rest: a tuple of pytrees, each of which has the same structure as tree or
or has tree as a prefix.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding nodes in
`rest`.
"""
leaves, treedef = pytree.flatten(tree)
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
def tree_leaves(tree):
return pytree.flatten(tree)[0]
def tree_leaves(tree):
return pytree.flatten(tree)[0]
def process_pytree(process_node, tree):
leaves, treedef = pytree.flatten(tree)
return treedef.walk(process_node, None, leaves), treedef
def process_pytree(process_node, tree):
leaves, treedef = pytree.flatten(tree)
return treedef.walk(process_node, None, leaves), treedef
tree_flatten = pytree.flatten
tree_flatten = pytree.flatten
def build_tree(treedef, xs):
return treedef.from_iterable_tree(xs)
def build_tree(treedef, xs):
return treedef.from_iterable_tree(xs)
def treedef_is_leaf(treedef):
return treedef.num_nodes == 1
def treedef_is_leaf(treedef):
return treedef.num_nodes == 1
def tree_unflatten(treedef, xs):
return treedef.unflatten(xs)
def tree_unflatten(treedef, xs):
return treedef.unflatten(xs)
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
flat, treedef = tree_flatten(pytree_to_transpose)
expected_treedef = outer_treedef.compose(inner_treedef)
if treedef != expected_treedef:
raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
flat, treedef = tree_flatten(pytree_to_transpose)
expected_treedef = outer_treedef.compose(inner_treedef)
if treedef != expected_treedef:
raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))
inner_size = inner_treedef.num_leaves
outer_size = outer_treedef.num_leaves
flat = iter(flat)
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
inner_size = inner_treedef.num_leaves
outer_size = outer_treedef.num_leaves
flat = iter(flat)
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
def tree_structure(tree):
_, treedef = pytree.flatten(tree)
return treedef
def tree_structure(tree):
_, treedef = pytree.flatten(tree)
return treedef
def treedef_tuple(trees):
return pytree.tuple(list(trees))
def treedef_tuple(trees):
return pytree.tuple(list(trees))
def treedef_children(treedef):
return treedef.children()
def treedef_children(treedef):
return treedef.children()
register_pytree_node = pytree.register_node
register_pytree_node = pytree.register_node
else:
def tree_map(f, tree):
"""Map a function over a pytree to produce a new pytree.
Args:
f: function to be applied at each leaf.
tree: a pytree to be mapped over.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
"""
node_type = _get_node_type(tree)
if node_type:
children, node_spec = node_type.to_iterable(tree)
new_children = [tree_map(f, child) for child in children]
return node_type.from_iterable(node_spec, new_children)
else:
return f(tree)
def tree_multimap(f, tree, *rest):
"""Map a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes `1 + len(rest)` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
*rest: a tuple of pytrees, each with the same structure as `tree`.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
"""
node_type = _get_node_type(tree)
if node_type:
children, aux_data = node_type.to_iterable(tree)
all_children = [children]
for other_tree in rest:
other_node_type = _get_node_type(other_tree)
if node_type != other_node_type:
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
other_children, other_aux_data = node_type.to_iterable(other_tree)
if other_aux_data != aux_data:
raise TypeError('Mismatch: {} != {}'.format(other_aux_data, aux_data))
all_children.append(other_children)
new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
return node_type.from_iterable(aux_data, new_children)
else:
return f(tree, *rest)
def _walk_pytree(f_node, f_leaf, tree):
node_type = _get_node_type(tree)
if node_type:
children, node_spec = node_type.to_iterable(tree)
proc_children, child_specs = unzip2([_walk_pytree(f_node, f_leaf, child)
for child in children])
tree_def = _PyTreeDef(node_type, node_spec, child_specs)
return f_node(proc_children), tree_def
else:
return f_leaf(tree), leaf
def process_pytree(process_node, tree):
return _walk_pytree(process_node, lambda x: x, tree)
def build_tree(treedef, xs):
if treedef is leaf:
return xs
else:
# We use 'iter' for clearer error messages
children = safe_map(build_tree, iter(treedef.children), iter(xs))
return treedef.node_type.from_iterable(treedef.node_data, children)
def tree_leaves(tree):
"""Generator that iterates over all leaves of a pytree."""
node_type = _get_node_type(tree)
if node_type:
children, _ = node_type.to_iterable(tree)
for child in children:
# TODO(mattjj,phawkins): use 'yield from' when PY2 is dropped
for leaf in tree_leaves(child):
yield leaf
else:
yield tree
def tree_flatten(tree):
itr, treedef = _walk_pytree(it.chain.from_iterable, lambda x: (x,), tree)
return list(itr), treedef
def _tree_unflatten(xs, treedef):
if treedef is leaf:
return next(xs)
else:
children = tuple(map(partial(_tree_unflatten, xs), treedef.children))
return treedef.node_type.from_iterable(treedef.node_data, children)
def tree_unflatten(treedef, xs):
return _tree_unflatten(iter(xs), treedef)
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
flat, treedef = tree_flatten(pytree_to_transpose)
expected_treedef = _nested_treedef(inner_treedef, outer_treedef)
if treedef != expected_treedef:
raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))
inner_size = _num_leaves(inner_treedef)
outer_size = _num_leaves(outer_treedef)
flat = iter(flat)
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
def _num_leaves(treedef):
return 1 if treedef is leaf else sum(map(_num_leaves, treedef.children))
def _nested_treedef(inner, outer):
# just used in tree_transpose error checking
if outer is leaf:
return inner
else:
children = map(partial(_nested_treedef, inner), outer.children)
return _PyTreeDef(outer.node_type, outer.node_data, tuple(children))
def tree_structure(tree):
_, spec = process_pytree(lambda _: None, tree)
return spec
class _PyTreeDef(object):
__slots__ = ("node_type", "node_data", "children")
def __init__(self, node_type, node_data, children):
self.node_type = node_type
self.node_data = node_data
self.children = children
def __repr__(self):
if self.node_data is None:
data_repr = ""
else:
data_repr = "[{}]".format(self.node_data)
return "PyTree({}{}, [{}])".format(self.node_type.name, data_repr,
','.join(map(repr, self.children)))
def __hash__(self):
return hash((self.node_type, self.node_data, tuple(self.children)))
def __eq__(self, other):
if other is leaf:
return False
else:
return (self.node_type == other.node_type and
self.node_data == other.node_data and
self.children == other.children)
def __ne__(self, other):
return not self == other
class _PyLeaf(object):
__slots__ = ()
def __repr__(self):
return '*'
leaf = _PyLeaf()
def treedef_is_leaf(treedef):
return treedef is leaf
def treedef_tuple(treedefs):
return _PyTreeDef(node_types[tuple], None, tuple(treedefs))
def treedef_children(treedef):
return treedef.children
def dict_to_iterable(xs):
keys = tuple(sorted(xs.keys()))
return tuple(map(xs.get, keys)), keys
class NodeType(object):
def __init__(self, name, to_iterable, from_iterable):
self.name = name
self.to_iterable = to_iterable
self.from_iterable = from_iterable
def __repr__(self):
return self.name
node_types = {}
def register_pytree_node(py_type, to_iterable, from_iterable):
assert py_type not in node_types
node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable)
register_pytree_node(tuple, lambda xs: (xs, None), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda xs: (tuple(xs), None), lambda _, xs: list(xs))
register_pytree_node(dict, dict_to_iterable, lambda keys, xs: dict(zip(keys, xs)))
# To handle namedtuples, we can't just use the standard table of node_types
# because every namedtuple creates its own type and thus would require its own
# entry in the table. Instead we use a heuristic check on the type itself to
# decide whether it's a namedtuple type, and if so treat it as a pytree node.
def _get_node_type(maybe_tree):
t = type(maybe_tree)
return node_types.get(t) or _namedtuple_node(t)
def _namedtuple_node(t):
if issubclass(t, tuple) and hasattr(t, '_fields'):
return NamedtupleNode
NamedtupleNode = NodeType('namedtuple',
lambda xs: (tuple(xs), type(xs)),
lambda t, xs: t(*xs))
def tree_reduce(f, tree):