use Partial to make ravel_pytree unflatteners jit-friendly

Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Matthew Johnson 2023-03-13 10:47:45 -07:00
parent 233911c001
commit a6d3ae1446
4 changed files with 64 additions and 43 deletions
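
Background: the previous ravel_pytree built its unravel function as a closure, and a fresh closure hashes by object identity, so passing it to jax.jit as a static argument forced a retrace on every call. A minimal standalone sketch of that behaviour (illustration only, not code from this commit):

import jax
import jax.numpy as jnp

def apply(x, fn):
  # `fn` becomes part of the compilation cache key via its hash/equality.
  return fn(x)

apply_jit = jax.jit(apply, static_argnums=1)

# Two behaviourally identical lambdas are still distinct objects that hash by
# identity, so the second call misses the cache and compiles again.
apply_jit(jnp.arange(3.0), lambda x: x + 1)
apply_jit(jnp.arange(3.0), lambda x: x + 1)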


@@ -17,7 +17,7 @@ import warnings
 import numpy as np
 from jax._src.tree_util import tree_flatten, tree_unflatten
-from jax._src.util import safe_zip, unzip2
+from jax._src.util import safe_zip, unzip2, HashablePartial
 import jax.numpy as jnp
 from jax._src import dtypes
@@ -47,39 +47,43 @@ def ravel_pytree(pytree):
   """
   leaves, treedef = tree_flatten(pytree)
   flat, unravel_list = _ravel_list(leaves)
-  unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
-  return flat, unravel_pytree
+  return flat, HashablePartial(unravel_pytree, treedef, unravel_list)
+
+def unravel_pytree(treedef, unravel_list, flat):
+  return tree_unflatten(treedef, unravel_list(flat))
 
 def _ravel_list(lst):
   if not lst: return jnp.array([], jnp.float32), lambda _: []
-  from_dtypes = [dtypes.dtype(l) for l in lst]
+  from_dtypes = tuple(dtypes.dtype(l) for l in lst)
   to_dtype = dtypes.result_type(*from_dtypes)
   sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
-  indices = np.cumsum(sizes)
+  indices = tuple(np.cumsum(sizes))
 
   if all(dt == to_dtype for dt in from_dtypes):
     # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
     # See https://github.com/google/jax/issues/7809.
     del from_dtypes, to_dtype
-    def unravel(arr):
-      chunks = jnp.split(arr, indices[:-1])
-      return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
     raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
-    return raveled, unravel
+    return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)
 
   # When there is more than one distinct input dtype, we perform type
   # conversions and produce a dtype-specific unravel function.
-  def unravel(arr):
-    arr_dtype = dtypes.dtype(arr)
-    if arr_dtype != to_dtype:
-      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
-                      f"but expected dtype {to_dtype}")
-    chunks = jnp.split(arr, indices[:-1])
-    with warnings.catch_warnings():
-      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
-      return [lax.convert_element_type(chunk.reshape(shape), dtype)
-              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
-
   ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
   raveled = jnp.concatenate([ravel(e) for e in lst])
-  return raveled, unravel
+  unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype)
+  return raveled, unrav
+
+def _unravel_list_single_dtype(indices, shapes, arr):
+  chunks = jnp.split(arr, indices[:-1])
+  return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
+
+def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
+  arr_dtype = dtypes.dtype(arr)
+  if arr_dtype != to_dtype:
+    raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
+                    f"but expected dtype {to_dtype}")
+  chunks = jnp.split(arr, indices[:-1])
+  with warnings.catch_warnings():
+    warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
+    return [lax.convert_element_type(chunk.reshape(shape), dtype)
+            for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
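
With the closure replaced by HashablePartial, two calls to ravel_pytree on pytrees of the same structure and dtypes return unravelers that compare and hash equal, which is what lets jit reuse a single cache entry. A rough sketch of the expected behaviour (uses the public jax.flatten_util.ravel_pytree entry point; the equality relies on the HashablePartial semantics shown below):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

flat1, unravel1 = ravel_pytree(params)
flat2, unravel2 = ravel_pytree(params)

# Same treedef and per-leaf metadata, same underlying function: the two
# unravelers are interchangeable as static arguments.
assert unravel1 == unravel2 and hash(unravel1) == hash(unravel2)

restored = unravel1(flat1)  # round-trips back to the original dict structure
assert restored['w'].shape == (2, 2) and restored['b'].shape == (2,)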


@@ -433,6 +433,23 @@ class HashableFunction:
 def as_hashable_function(closure):
   return lambda f: HashableFunction(f, closure)
 
+class HashablePartial:
+  def __init__(self, f, *args, **kwargs):
+    self.f = f
+    self.args = args
+    self.kwargs = kwargs
+
+  def __eq__(self, other):
+    return (type(other) is HashablePartial and
+            self.f.__code__ == other.f.__code__ and
+            self.args == other.args and self.kwargs == other.kwargs)
+
+  def __hash__(self):
+    return hash((self.f.__code__, self.args, tuple(self.kwargs.items())))
+
+  def __call__(self, *args, **kwargs):
+    return self.f(*self.args, *args, **self.kwargs, **kwargs)
+
 def maybe_named_axis(axis, if_pos, if_named):
   try:
     pos = operator.index(axis)
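
HashablePartial binds leading positional (and keyword) arguments like functools.partial, but defines equality and hashing in terms of the wrapped function's code object and the bound arguments, so two instances built from the same function with equal, hashable arguments act as one cache key. A rough usage sketch (jax._src.util is an internal module path, so importing it directly is for illustration only):

from jax._src.util import HashablePartial

def scale(factor, x):
  return factor * x

double_a = HashablePartial(scale, 2)
double_b = HashablePartial(scale, 2)

# Same __code__ and equal bound args: the wrappers compare and hash equal.
assert double_a == double_b and hash(double_a) == hash(double_b)
assert double_a(10) == 20

Note that the bound arguments must themselves be hashable (tuples rather than lists or arrays), which is why _ravel_list above converts from_dtypes and indices to tuples.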


@@ -39,8 +39,9 @@ from jax._src import util
 from jax._src.core import Tracer
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, fft, linalg)
-from jax._src.util import (HashableFunction, unzip2, as_hashable_function,
-                           memoize, partition_list, merge_lists)
+from jax._src.util import (HashableFunction, HashablePartial, unzip2,
+                           as_hashable_function, memoize, partition_list,
+                           merge_lists)
 from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
 from jax.interpreters import batching
 from jax._src.interpreters import mlir
@@ -984,23 +985,3 @@ def _pe_custom_ctx(params):
 pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
     partial(pe.call_partial_eval_custom_rule, 'jaxpr', _pe_custom_params,
             res_aval=_pe_custom_res, ctx=_pe_custom_ctx)
-
-# Misc
-
-# TODO(mattjj): move this to _src/util.py
-class HashablePartial:
-  def __init__(self, f, *args, **kwargs):
-    self.f = f
-    self.args = args
-    self.kwargs = kwargs
-
-  def __eq__(self, other):
-    return (type(other) is HashablePartial and
-            self.f.__code__ == other.f.__code__ and
-            self.args == other.args and self.kwargs == other.kwargs)
-
-  def __hash__(self):
-    return hash((self.f.__code__, self.args, tuple(self.kwargs.items())))
-
-  def __call__(self, *args, **kwargs):
-    return self.f(*self.args, *args, **self.kwargs, **kwargs)


@@ -604,6 +604,25 @@ class RavelUtilTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, 'but expected dtype'):
       _ = unravel(y)
 
+  def test_no_recompile(self):
+    x1 = jnp.array([1, 2])
+    x2 = jnp.array([3, 4])
+
+    x_flat1, unravel1 = flatten_util.ravel_pytree((x1, x2))
+    x_flat2, unravel2 = flatten_util.ravel_pytree((x1, x2))
+
+    num_traces = 0
+
+    def run(flat, unravel):
+      nonlocal num_traces
+      num_traces += 1
+      flat = flat + 1
+      return unravel(flat)
+
+    run = jax.jit(run, static_argnums=1)
+
+    run(x_flat1, unravel1)
+    run(x_flat2, unravel2)
+    self.assertEqual(num_traces, 1)
+
 
 class TreePrefixErrorsTest(jtu.JaxTestCase):