Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
Parent: 393bca122d
Commit: b0fdf10a63
@@ -15,7 +15,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   the `concrete` option, following the previous version's deprecation; see
   [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
 * Changes
-  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.
+  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
 
 ## jax 0.3.16
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
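A minimal usage sketch of the API the entry above describes (not part of this diff; it assumes only the public `jax.pure_callback` signature added here and the existing `jax.ShapeDtypeStruct` helper):

import jax
import jax.numpy as jnp
import numpy as np

def host_log1p(x):
  # Runs on the host and receives/returns NumPy arrays.
  return np.log1p(x)

@jax.jit
def f(x):
  # result_shape_dtypes only describes the output; the callback produces it.
  return jax.pure_callback(host_log1p,
                           jax.ShapeDtypeStruct(x.shape, x.dtype),
                           x)

f(jnp.arange(4.0))  # == np.log1p([0., 1., 2., 3.])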
|
@@ -3233,37 +3233,45 @@ def block_until_ready(x):
 
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                   *args: Any, **kwargs: Any):
-  """Calls a pure Python callback function from staged out JAX programs.
+  """Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
 
   ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
   The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
-  should also return NumPy arrays. Execution takes place on the CPU host.
+  should also return NumPy arrays. Execution takes place on CPU, like any
+  Python+NumPy function.
 
-  The callback is treated as "pure" meaning it can be called multiple times when
-  transformed (for example in a ``vmap`` or ``pmap``), and it can also
-  potentially be removed from JAX programs via dead-code elimination. Pure
-  callbacks can also be reordered if data-dependence allows.
+  The callback is treated as functionally pure, meaning it has no side-effects
+  and its output value depends only on its argument values. As a consequence, it
+  is safe to be called multiple times (e.g. when transformed by ``vmap`` or
+  ``pmap``), or not to be called at all when e.g. the output of a
+  `jit`-decorated function has no data dependence on its value. Pure callbacks
+  may also be reordered if data-dependence allows.
 
-  When ``pmap``-ed, the pure callback will be called several times (one on each axis
-  of the map). When `vmap`-ed the behavior will depend on the value of the
-  ``rank_polymorphic`` keyword argument. If the callback is indicated as rank
-  polymorphic, the callback will be called directly on batched inputs (where the
-  batch axis is the leading dimension). Additionally, the callbacks should
-  return outputs that also have a leading batch axis. If not rank polymorphic,
-  ``callback`` will be mapped sequentially across the batched axis.
+  When ``pmap``-ed, the pure callback will be called several times (one on each
+  axis of the map). When `vmap`-ed the behavior will depend on the value of the
+  ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
+  is assumed to obey
+  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
+  Therefore, the callback will be called directly on batched inputs (where the
+  batch axes are the leading dimensions). Additionally, the callbacks should
+  return outputs that have corresponding leading batch axes. If not vectorized,
+  ``callback`` will be mapped sequentially across the batched axis.
+  For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
+  to set ``vectorized=True`` because the ``np.matmul`` function handles
+  arbitrary leading batch dimensions.
 
   Args:
-    callback: A Python callable. The callable will be passed in NumPy arrays and
-      should return a PyTree of NumPy arrays that matches
-      ``result_shape_dtypes``.
-    result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
-      ``dtype`` properties that correspond to the shape and dtypes of the
-      outputs of ``callback``.
+    callback: A Python callable. The callable will be passed PyTrees of NumPy
+      arrays as arguments, and should return a PyTree of NumPy arrays that
+      matches ``result_shape_dtypes``.
+    result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
+      and ``dtype`` attributes which represent the shapes and dtypes of the
+      value of ``callback`` applied to ``args`` and ``kwargs``.
     *args: The positional arguments to the callback. Must be PyTrees of JAX
      types.
-    rank_polymorphic: A boolean that indicates whether or not ``callback`` is
-      rank polymorphic, meaning it can handle arrays with additional leading
-      dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
+    vectorized: A boolean that indicates whether or not ``callback`` is
+      vectorized, meaning it can handle arrays with additional leading
+      dimensions. If ``vectorized`` is `True`, when the callback is mapped
       via `jax.vmap`, it will be called directly on inputs with leading batch
       dimensions instead of executing ``callback`` on each mapped input
       individually. The callback should also return outputs batched across the
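As a brief illustration of the ``vectorized`` semantics documented above (a sketch, not part of this diff; it mirrors the tests further down and assumes only the public `jax.pure_callback` API):

import jax
import jax.numpy as jnp
import numpy as np

def cb(x):
  # np.sin already handles arbitrary leading batch dimensions.
  return np.sin(x)

@jax.jit
@jax.vmap
def f(x):
  # Default (vectorized=False): cb is applied to one mapped element at a time.
  return jax.pure_callback(cb, x, x)

@jax.jit
@jax.vmap
def g(x):
  # vectorized=True: cb is called once on the whole stacked batch.
  return jax.pure_callback(cb, x, x, vectorized=True)

xs = jnp.arange(4.)
np.testing.assert_allclose(f(xs), np.sin(np.arange(4.)))
np.testing.assert_allclose(g(xs), np.sin(np.arange(4.)))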
|
@@ -36,15 +36,15 @@ map, unsafe_map = util.safe_map, map
 
 @pure_callback_p.def_impl
 def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
-                       rank_polymorphic: bool):
-  del rank_polymorphic, result_avals
+                       vectorized: bool):
+  del vectorized, result_avals
   return callback(*args)
 
 
 @pure_callback_p.def_abstract_eval
 def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
-                                result_avals, rank_polymorphic: bool):
-  del avals, callback, rank_polymorphic
+                                result_avals, vectorized: bool):
+  del avals, callback, vectorized
   return result_avals
 
 
@@ -67,26 +67,29 @@ def pure_callback_transpose_rule(*args, **kwargs):
 ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
 
 
-def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
+def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
                                 result_avals: Sequence[core.ShapedArray]):
   axis_size = next(a.shape[0] for a, d in zip(args, dims)
                    if d is not batching.not_mapped)
-  new_args = []
-  for arg, dim in zip(args, dims):
-    new_args.append(batching.moveaxis(arg, dim, 0))
-  if rank_polymorphic:
+  new_args = [arg if dim is batching.not_mapped else
+              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
+  if vectorized:
     result_avals = tuple(
         core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
         for aval in result_avals)
     outvals = pure_callback_p.bind(
-        *new_args, callback=callback, rank_polymorphic=rank_polymorphic,
+        *new_args, callback=callback, vectorized=vectorized,
         result_avals=result_avals)
   else:
+    is_batched = [d is not batching.not_mapped for d in dims]
+    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
+    def _batch_fun(*batched_args):
+      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
+      return pure_callback_p.bind(
+          *merged_args, callback=callback, result_avals=result_avals,
+          vectorized=vectorized)
     from jax._src.lax.control_flow import map as lax_map
-    outvals = lax_map(
-        functools.partial(pure_callback_p.bind, callback=callback,
-                          rank_polymorphic=rank_polymorphic, result_avals=result_avals),
-        *new_args)
+    outvals = lax_map(_batch_fun, *batched_args)
   return tuple(outvals), (0,) * len(outvals)
 
 
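A rough model of what the non-vectorized branch above computes (not the library code; `batched_call` and `_batch_fun` are illustrative names only): unbatched operands stay fixed in a closure while ``lax.map`` scans the batched operands along their leading axis.

import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

def batched_call(callback, batched_x, unbatched_y):
  def _batch_fun(x):
    # The unbatched operand is captured, loosely mirroring the
    # util.partition_list / util.merge_lists bookkeeping in the rule above.
    return callback(x, unbatched_y)
  return lax.map(_batch_fun, batched_x)

out = batched_call(lambda x, y: jnp.sin(x) + y, jnp.arange(4.), 4.0)
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.0)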
@@ -115,7 +118,7 @@ mlir.register_lowering(pure_callback_p, pure_callback_lowering)
 
 
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
-                  *args: Any, rank_polymorphic: bool = False, **kwargs: Any):
+                  *args: Any, vectorized: bool = False, **kwargs: Any):
   def _flat_callback(*flat_args):
     args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
     return tree_util.tree_leaves(callback(*args, **kwargs))
@@ -126,5 +129,5 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
   flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
   out_flat = pure_callback_p.bind(
       *flat_args, callback=_flat_callback,
-      result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
+      result_avals=tuple(flat_result_avals), vectorized=vectorized)
   return tree_util.tree_unflatten(out_tree, out_flat)
|
@@ -152,6 +152,7 @@ class RuntimeTokenSet(threading.local):
       token[0].block_until_ready()
     for token in self.output_runtime_tokens.values():
       token.block_until_ready()
+    self.clear()
 
 runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
 
@@ -711,7 +712,7 @@ def _check_special(name, xla_shape, buf):
     raise FloatingPointError(f"invalid value (inf) encountered in {name}")
 
 def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
-                device: Device, input_bufs):
+                has_host_callbacks: bool, device: Device, input_bufs):
   tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
   tokens_flat = flatten(tokens)
   input_bufs = [*tokens_flat, *input_bufs]
@@ -720,7 +721,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
     num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                                 has_unordered_effects)
     token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
-    if has_unordered_effects:
+    if has_unordered_effects or has_host_callbacks:
       if can_execute_with_token:
         runtime_tokens.set_output_runtime_token(device, runtime_token)
       else:
@@ -738,14 +739,15 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
                       result_handler: Callable,
                       has_unordered_effects: bool,
                       ordered_effects: List[core.Effect],
-                      kept_var_idx, host_callbacks, *args):
+                      kept_var_idx, has_host_callbacks: bool, *args):
   device, = compiled.local_devices()
   args, env = input_handler(args) if input_handler else (args, None)
   in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                     if i in kept_var_idx)
-  if has_unordered_effects or ordered_effects:
+  if has_unordered_effects or ordered_effects or has_host_callbacks:
     in_flat, token_handler = _add_tokens(
-        has_unordered_effects, ordered_effects, device, in_flat)
+        has_unordered_effects, ordered_effects, has_host_callbacks, device,
+        in_flat)
   if can_execute_with_token:
     out_flat, runtime_token = compiled.execute_with_token(in_flat)
   else:
@@ -755,7 +757,7 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
     out_flat = compiled.execute(in_flat)
   check_special(name, out_flat)
   out_bufs = unflatten(out_flat, output_buffer_counts)
-  if ordered_effects or has_unordered_effects:
+  if ordered_effects or has_unordered_effects or has_host_callbacks:
     out_bufs = token_handler(out_bufs, runtime_token)
   return result_handler(env, out_bufs)
 
@@ -766,7 +768,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
                         result_handler: Callable,
                         has_unordered_effects: bool,
                         ordered_effects: List[core.Effect],
-                        kept_var_idx, host_callbacks, *args):
+                        kept_var_idx, has_host_callbacks: bool, *args):
   if has_unordered_effects or ordered_effects:
     # TODO(sharadmv): support jit-of-pmap with effects
     raise NotImplementedError(
@@ -965,7 +967,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,  # type: ignore  # noqa: F811
                           result_handler, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
 
@@ -988,7 +990,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     result_handlers = map(partial(aval_to_result_handler, device), out_avals)
     unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
                           result_handlers, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
 
|
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import functools
 import io
 import textwrap
 import unittest
@@ -25,6 +26,7 @@ from jax import core
 from jax import lax
 from jax import tree_util
 from jax._src import debugging
+from jax._src import dispatch
 from jax._src import lib as jaxlib
 from jax._src import test_util as jtu
 from jax._src import util
@@ -459,6 +461,10 @@ class PythonCallbackTest(jtu.JaxTestCase):
 
 class PurePythonCallbackTest(jtu.JaxTestCase):
 
+  def tearDown(self):
+    super().tearDown()
+    dispatch.runtime_tokens.clear()
+
   @jtu.skip_on_devices(*disabled_backends)
   def test_simple_pure_callback(self):
 
@@ -504,8 +510,8 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     @jax.jit
     def f(x):
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x)
+      # Calling a function with two return values that expects one return value
+      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x, x)
 
     with self.assertRaises(RuntimeError):
       f(2.)
@@ -513,8 +519,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     @jax.jit
     def g():
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda: None, (
+      return jax.pure_callback(lambda: (), (
         core.ShapedArray((1,), np.float32), core.ShapedArray((2,), np.float32)))
 
     with self.assertRaises(RuntimeError):
@@ -564,7 +569,14 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
     np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
 
-  def test_vmap_rank_polymorphic_callback(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback(self):
 
     def cb(x):
       self.assertTupleEqual(x.shape, ())
@@ -575,7 +587,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     def f(x):
       return jax.pure_callback(cb, x, x)
 
-    _ = f(jnp.arange(4.))
+    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
 
     def cb2(x):
       self.assertTupleEqual(x.shape, (4,))
@@ -585,11 +597,19 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def g(x):
-      return jax.pure_callback(cb2, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb2, x, x, vectorized=True)
 
-    _ = g(jnp.arange(4.))
+    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
 
-  def test_vmap_rank_polymorphic_callback_errors_if_returns_wrong_shape(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
+                               vectorized=True)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
 
     def cb(x):
       # Reduces over all dimension when it shouldn't
@@ -598,7 +618,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def f(x):
-      return jax.pure_callback(cb, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb, x, x, vectorized=True)
 
     with self.assertRaises(RuntimeError):
       f(jnp.arange(4.))
@@ -743,13 +763,13 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     np.testing.assert_allclose(out, jnp.arange(1., 41.))
 
   @jtu.skip_on_devices(*disabled_backends)
-  def test_rank_polymorphic_callback_inside_xmap(self):
+  def test_vectorized_callback_inside_xmap(self):
 
     def _callback(x):
       return (x + 1.).astype(x.dtype)
 
     def f(x):
-      return jax.pure_callback(_callback, x, x, rank_polymorphic=True)
+      return jax.pure_callback(_callback, x, x, vectorized=True)
 
     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})
|