mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
An experiment for id_print implemented with outfeed
* Added print descriptors, support multiple types * Added a state-passing mechanism to XLA interpreter
This commit is contained in:
parent
970e475e0a
commit
de685c9d5a
@ -100,3 +100,10 @@ pytype_library(
|
||||
srcs_version = "PY3",
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_host_callback",
|
||||
srcs = ["experimental/host_callback.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
308
jax/experimental/host_callback.py
Normal file
308
jax/experimental/host_callback.py
Normal file
@ -0,0 +1,308 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of an experimental primitive for printing, including
|
||||
from transformed and compiled code.
|
||||
|
||||
See documentation for `id_print` below.
|
||||
For usage example, see tests/host_callback_test.py.
|
||||
|
||||
Implementation plan:
|
||||
* Write the API for the `id_print` primitive, using data-dependence as
|
||||
explained in `id_print` documentation (DONE).
|
||||
* Implement the transformations. DONE (except pmap)
|
||||
* Implement the JIT for CPU using CustomCall in C++. DONE (except unit tests
|
||||
do not run in OSS; also missing float16 and bfloat16).
|
||||
* Implement the JIT for GPU using also CustomCall in C++. DONE.
|
||||
* Explore how to pipe the printed data back to the Colab cell,
|
||||
when running in Colab. ?
|
||||
* Explore implementation using outfeed, hoping that it works for all
|
||||
platforms, and can pipe data more easily. STARTED.
|
||||
* Explore feeding the data back to the Python program (the `id_tap`
|
||||
primitive). ?
|
||||
* Explore a simpler API that uses Python program-order, instead of
|
||||
data dependency-order. Need to add support to JAX for stateful primitives.
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
import io
|
||||
import itertools
|
||||
|
||||
from jax import abstract_arrays
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax.lib import pytree, xla_bridge
|
||||
from jax.interpreters import ad, xla, batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax import util
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import xla_extension
|
||||
|
||||
import logging
|
||||
import numpy as onp
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
# TODO(necula): fix mypy errors if I define the type aliases below
|
||||
XlaOp = Any # xla_extension.XlaOp
|
||||
XlaShape = Any # xla_client.Shape
|
||||
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
|
||||
|
||||
id_print_p = core.Primitive("id_print")
|
||||
id_print_p.multiple_results = True
|
||||
|
||||
xops = xla_client._xla.ops
|
||||
|
||||
def id_print(*args, result=None, **kwargs):
|
||||
"""Behaves like the identify function for positional arguments, but prints all
|
||||
arguments on the host, even from transformed or compiled code.
|
||||
|
||||
The return value is a tuple with the value of `args` or the value of the
|
||||
keyword parameter `result` if present. If there is a single positional
|
||||
argument, it returns just that argument without packing it in a tuple.
|
||||
|
||||
The positional arguments must be JAX values. The keyword arguments are
|
||||
serialized to a string and printed along with the positional arguments.
|
||||
There are a few special keywork arguments that are not printed:
|
||||
|
||||
* `result`: is the result of `id_print`, must be a JAX value or a
|
||||
pytree of values.
|
||||
* `output_stream`: is the output stream where the values should be
|
||||
printed. (Note: does not yet work from under JIT).
|
||||
|
||||
Usage:
|
||||
>>> y = id_print(x * 2) # prints and returns 2x
|
||||
>>> y, z = id_print(x * 2, x * 3) # prints and returns 2x and 3x
|
||||
>>> y = id_print(x * 2, result=y) # prints 2x and returns y
|
||||
>>> y = id_print(x * 2, what='x') # prints what=x followed by 2x
|
||||
|
||||
The order of execution is by data dependency: after all the arguments are
|
||||
computed and before the result is used. At least one of the returned values
|
||||
must be used in the rest of the computation, or else this operation has
|
||||
no effect.
|
||||
|
||||
Upon JAX transformations, the transformed values are wrapped with
|
||||
`id_print`, and a special `transforms` tuple keyword argument is added with
|
||||
the sequence of transformations applied:
|
||||
|
||||
- For `vmap` the arguments are batched, and transforms=('vmap')
|
||||
- For `jvp` there will be an id_print for the primal values, and a
|
||||
separate `id_print` for the tangents with `transforms=('jvp')`.
|
||||
- For `grad` there will be an `id_print` for the primal values (if
|
||||
needed in the computation of `grad` and an `id_print` with the
|
||||
adjoints of the results, with transforms=('vjp').
|
||||
"""
|
||||
flat_args, args_treedef = pytree.flatten(args)
|
||||
|
||||
params = dict(kwargs) # copy
|
||||
if result is not None:
|
||||
flat_results, results_treedef = pytree.flatten(result)
|
||||
params["nr_results"] = len(flat_results)
|
||||
all_args = flat_args + flat_results
|
||||
else:
|
||||
all_args = flat_args
|
||||
flat_outs = id_print_p.bind(*all_args, **params) # Always returns a tuple of all args
|
||||
if result is not None:
|
||||
return results_treedef.unflatten(flat_outs[-params["nr_results"]:])
|
||||
else:
|
||||
res = args_treedef.unflatten(flat_outs)
|
||||
return res if len(args) > 1 else res[0]
|
||||
|
||||
|
||||
def _expand_params_transform(params: Dict, transform: str) -> Dict:
|
||||
"""Adds the `transform` to the params["transforms"]."""
|
||||
return dict(params, transforms=params.get("transforms", ()) + (transform,))
|
||||
|
||||
|
||||
def _id_print_impl(*args, **params):
|
||||
output_stream = params.get("output_stream")
|
||||
if output_stream is not None:
|
||||
print_params = dict(params)
|
||||
del print_params["output_stream"]
|
||||
else:
|
||||
import sys
|
||||
output_stream = sys.stdout
|
||||
print_params = params
|
||||
|
||||
# TODO: use the JITed version to do the actual printing.
|
||||
to_print = f"{args} {print_params}"
|
||||
output_stream.write(to_print)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
id_print_p.def_impl(_id_print_impl)
|
||||
|
||||
def _id_print_abstract_eval(*args_a: pe.AbstractValue, **params) \
|
||||
-> Sequence[pe.AbstractValue]:
|
||||
return args_a
|
||||
|
||||
|
||||
id_print_p.def_abstract_eval(_id_print_abstract_eval)
|
||||
|
||||
|
||||
# Each array sent to the outfeed is preceeded by a descriptor, which
|
||||
# is an array s32[32], with the format:
|
||||
# [0]: special header value
|
||||
# [1]: an encoding of the element_type()
|
||||
# [2]: the number of dimensions
|
||||
# [3:...]: the size of the dimensions
|
||||
# padded with 0s
|
||||
#
|
||||
_OUTFEED_DESCRIPTOR_PRINT_HEADER = 13579
|
||||
_OUTFEED_DESCRIPTOR_LENGTH = 32
|
||||
_CODE_TO_DTYPE = {
|
||||
0: onp.dtype(onp.int8),
|
||||
1: onp.dtype(onp.int16),
|
||||
2: onp.dtype(onp.int32),
|
||||
3: onp.dtype(onp.int64),
|
||||
4: onp.dtype(onp.uint8),
|
||||
5: onp.dtype(onp.uint16),
|
||||
6: onp.dtype(onp.uint32),
|
||||
7: onp.dtype(onp.uint64),
|
||||
8: onp.dtype(onp.float16),
|
||||
9: onp.dtype(onp.float32),
|
||||
10: onp.dtype(onp.float64),
|
||||
11: onp.dtype(dtypes.bfloat16),
|
||||
}
|
||||
_DTYPE_STR_TO_CODE = dict([(str(d), c) for c, d in _CODE_TO_DTYPE.items()])
|
||||
|
||||
|
||||
|
||||
def _id_print_translation_rule_outfeed(
|
||||
comp: XlaComputationBuilder,
|
||||
*args_op: XlaOp, **params):
|
||||
|
||||
# TODO: outfeed a whole tuple at once?
|
||||
def _outfeed_one_array(a: XlaOp, token: XlaOp) -> XlaOp:
|
||||
a_shape = comp.GetShape(a)
|
||||
dimensions = a_shape.dimensions()
|
||||
descriptor = [_OUTFEED_DESCRIPTOR_PRINT_HEADER,
|
||||
_DTYPE_STR_TO_CODE[str(onp.dtype(a_shape.element_type()))],
|
||||
len(dimensions)] + list(dimensions)
|
||||
if len(descriptor) > _OUTFEED_DESCRIPTOR_LENGTH:
|
||||
raise ValueError(f"Too many dimensions in array to print: {a_shape}")
|
||||
descriptor += [0] * (_OUTFEED_DESCRIPTOR_LENGTH - len(descriptor))
|
||||
|
||||
data = xops.ConstantLiteral(comp, onp.array(descriptor, dtype=onp.int32))
|
||||
token = xops.OutfeedWithToken(data, token, comp.GetShape(data))
|
||||
token = xops.OutfeedWithToken(a, token, a_shape)
|
||||
return token
|
||||
|
||||
prev_token = xla.computation_state_carry.current_token(comp)
|
||||
nr_args_to_emit = len(args_op) - params.get("nr_results", 0)
|
||||
for i, a_op in enumerate(args_op):
|
||||
if i < nr_args_to_emit:
|
||||
prev_token = _outfeed_one_array(a_op, prev_token)
|
||||
xla.computation_state_carry.set_current_token(comp, prev_token)
|
||||
return xops.Tuple(comp, args_op)
|
||||
|
||||
xla.translations[id_print_p] = _id_print_translation_rule_outfeed
|
||||
|
||||
|
||||
# TODO: find a better way to signal the end of printing
|
||||
_END_PRINTING = onp.int32(12345678)
|
||||
def end_printing(res):
|
||||
return id_print(_END_PRINTING, result=res)
|
||||
|
||||
@contextmanager
|
||||
def print_receiver(output_stream=None,
|
||||
receiver_name="",
|
||||
timeout_sec=10):
|
||||
# TODO: better timeout management
|
||||
"""Starts a receiver for the id_print outfeed.
|
||||
|
||||
Usage:
|
||||
with print_receiver():
|
||||
jax.jit(func)(args)
|
||||
|
||||
"""
|
||||
# TODO: start receivers for each device
|
||||
platform = xla_client.get_local_backend(None).platform
|
||||
def _get_data(data_shape: XlaShape) -> XlaShape:
|
||||
if platform == "gpu":
|
||||
return xla_client.transfer_from_outfeed(data_shape)
|
||||
else:
|
||||
return xla_client.transfer_from_outfeed(
|
||||
xla_client.Shape.tuple_shape((data_shape,)))[0]
|
||||
|
||||
def _consume_one_array_from_outfeed():
|
||||
descriptor_shape = xla_client.Shape.array_shape(onp.dtype(onp.int32),
|
||||
(_OUTFEED_DESCRIPTOR_LENGTH,))
|
||||
descriptor = _get_data(descriptor_shape)
|
||||
|
||||
logging.info(f"[{receiver_name}] Read descriptor: {descriptor}")
|
||||
if descriptor[0] != _OUTFEED_DESCRIPTOR_PRINT_HEADER:
|
||||
raise ValueError(f"Read unexpected print descriptor {descriptor} [{receiver_name}]")
|
||||
data_dimensions = tuple(descriptor[3:3+descriptor[2]])
|
||||
data_shape = xla_client.Shape.array_shape(_CODE_TO_DTYPE[descriptor[1]],
|
||||
data_dimensions)
|
||||
data = _get_data(data_shape)
|
||||
logging.info(f"[{receiver_name}] Read data of shape {data.dtype}{data.shape}")
|
||||
return data
|
||||
|
||||
def receiver_loop():
|
||||
i = 0
|
||||
while (True):
|
||||
got = _consume_one_array_from_outfeed()
|
||||
if not got.shape and got == _END_PRINTING:
|
||||
logging.info(f"[{receiver_name}] Received END_PRINTING")
|
||||
return
|
||||
got_str = onp.array2string(got, threshold=1024)
|
||||
logging.info(f"[{receiver_name}] Received {i} ({got.dtype}{got.shape}): {got_str}")
|
||||
if output_stream is not None:
|
||||
output_stream.write(got_str)
|
||||
i += 1
|
||||
|
||||
receiver = threading.Thread(target=receiver_loop)
|
||||
receiver.start()
|
||||
try:
|
||||
yield receiver
|
||||
finally:
|
||||
# TODO: proper termination
|
||||
receiver.join(timeout=timeout_sec)
|
||||
if receiver.is_alive():
|
||||
logging.error(f"[{receiver_name}] Receiver still alive")
|
||||
else:
|
||||
logging.info(f"[{receiver_name}] Receiver finished")
|
||||
|
||||
|
||||
def _id_print_jvp_rule(primals, tangents, **params):
|
||||
primals_out = id_print(primals, **params)
|
||||
tangents_out = id_print(tangents, **_expand_params_transform(params, "jvp"))
|
||||
return primals_out, tangents_out
|
||||
|
||||
|
||||
ad.primitive_jvps[id_print_p] = _id_print_jvp_rule
|
||||
|
||||
|
||||
def _id_print_transpose_rule(cts, *args, **params):
|
||||
assert all([ad.is_undefined_primal(x) for x in args])
|
||||
assert len(cts) == len(args)
|
||||
ct_args = id_print_p.bind(*cts,
|
||||
**_expand_params_transform(params, "transpose"))
|
||||
return ct_args
|
||||
|
||||
|
||||
ad.primitive_transposes[id_print_p] = _id_print_transpose_rule
|
||||
|
||||
|
||||
def _id_print_batching_rule(batched_args, batch_dims, **params):
|
||||
res = id_print_p.bind(*batched_args,
|
||||
**_expand_params_transform(params, "batch"))
|
||||
return res, batch_dims
|
||||
|
||||
|
||||
batching.primitive_batchers[id_print_p] = _id_print_batching_rule
|
@ -14,9 +14,11 @@
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
import itertools as it
|
||||
import operator as op
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Tuple
|
||||
|
||||
from absl import logging
|
||||
import numpy as onp
|
||||
@ -33,13 +35,14 @@ from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
|
||||
from ..core import Literal, pp_eqn_compact
|
||||
from ..pprint_util import pp
|
||||
from ..util import (partial, partialmethod, cache, prod, unzip2, memoize,
|
||||
extend_name_stack, wrap_name)
|
||||
extend_name_stack, wrap_name, split_list)
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from . import masking
|
||||
|
||||
|
||||
xe = xc._xla
|
||||
xops = xc._xla.ops
|
||||
|
||||
@ -48,6 +51,11 @@ Backend = Any # xc.LocalBackend (why does mypy not like this?)
|
||||
Device = Any # xc.Device
|
||||
PyLocalBuffer = Any
|
||||
|
||||
XlaOp = Any # xla_extension.XlaOp
|
||||
XlaShape = Any # xla_client.Shape
|
||||
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_debug_nans',
|
||||
bool_env('JAX_DEBUG_NANS', False),
|
||||
@ -162,6 +170,63 @@ pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
|
||||
pytype_aval_mappings.update(
|
||||
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
|
||||
|
||||
class _ComputationStateCarry(threading.local):
|
||||
"""Carries some state in global state.
|
||||
|
||||
For now the state is only a token, obtained from the last OutFeed.
|
||||
The translation rules for primitives can read-write from this class.
|
||||
|
||||
This assumes that the primitives are processed in order!
|
||||
"""
|
||||
_current_computation: Optional[XlaComputationBuilder]
|
||||
_current_token: Optional[XlaOp]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._current_computation = None
|
||||
self._current_token = None
|
||||
|
||||
def state_len(self):
|
||||
return 1
|
||||
|
||||
def current_state(self, comp: XlaComputationBuilder) -> List[XlaOp]:
|
||||
if comp is not self._current_computation:
|
||||
if self._current_computation is not None:
|
||||
# TODO: add more error checking
|
||||
logging.warning("Overwriting previous computation")
|
||||
self._current_computation = comp
|
||||
self._current_token = xops.CreateToken(comp)
|
||||
return [self._current_token]
|
||||
|
||||
def current_token(self, comp: XlaComputationBuilder) -> XlaOp:
|
||||
state = self.current_state(comp)
|
||||
return state[0]
|
||||
|
||||
def set_current_token(self, comp: XlaComputationBuilder, token: XlaOp):
|
||||
assert comp == self._current_computation
|
||||
self._current_token = token
|
||||
|
||||
def set_comp_and_current_state(self, comp: XlaComputationBuilder,
|
||||
state: Sequence[XlaOp]):
|
||||
self._current_computation = comp
|
||||
self._current_token = state[0]
|
||||
|
||||
def extract_state_from_tuple_op(self, comp: XlaComputationBuilder,
|
||||
tuple_op: XlaOp, nr_regular: int):
|
||||
"""Given a tuple extract the state from its tail.
|
||||
|
||||
We assume that the `tuple_op` represents a tuple with `nr_outs` regular
|
||||
elements, followed by some elements encoding the state. Stores the new state,
|
||||
and returns a Tuple with the regular elemennts.
|
||||
"""
|
||||
regular_ops = [xops.GetTupleElement(tuple_op, i) for i in range(nr_regular)]
|
||||
self._current_computation = comp
|
||||
current_state = [xops.GetTupleElement(tuple_op, i)
|
||||
for i in range(nr_regular, nr_regular + self.state_len())]
|
||||
self._current_token = current_state[0]
|
||||
return xops.Tuple(comp, regular_ops)
|
||||
|
||||
computation_state_carry = _ComputationStateCarry()
|
||||
|
||||
### op-by-op execution
|
||||
|
||||
def arg_spec(x):
|
||||
@ -622,12 +687,23 @@ def _xla_call_translation_rule(c, axis_env,
|
||||
in_nodes, name_stack, backend, name,
|
||||
call_jaxpr, device=None):
|
||||
del device # Ignored.
|
||||
prev_state = computation_state_carry.current_state(c)
|
||||
all_in_nodes = list(in_nodes) + list(prev_state)
|
||||
|
||||
subc = xb.make_computation_builder(f"jit_{name}")
|
||||
args = [xb.parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
|
||||
|
||||
all_args = [xb.parameter(subc, i, c.GetShape(n))
|
||||
for i, n in enumerate(all_in_nodes)]
|
||||
args, input_state = split_list(all_args, [len(in_nodes)])
|
||||
computation_state_carry.set_comp_and_current_state(subc, input_state)
|
||||
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
|
||||
extend_name_stack(name_stack, wrap_name(name, 'jit')), *args)
|
||||
subc = subc.Build(xops.Tuple(subc, out_nodes))
|
||||
return xops.Call(c, subc, list(in_nodes))
|
||||
result_state = computation_state_carry.current_state(subc)
|
||||
subc = subc.Build(xops.Tuple(subc, list(out_nodes) + result_state))
|
||||
call_op = xops.Call(c, subc, all_in_nodes)
|
||||
nr_outs = len(out_nodes)
|
||||
return computation_state_carry.extract_state_from_tuple_op(c, call_op, nr_outs)
|
||||
|
||||
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
|
||||
|
||||
|
||||
|
@ -254,13 +254,15 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
|
||||
batched = bool(cond_jaxpr.out_avals[0].shape)
|
||||
|
||||
prev_state = xla.computation_state_carry.current_state(c)
|
||||
|
||||
# Since jaxprs don't have tuples and have multiple return values, but we need
|
||||
# the HLO While loop to take a single tuple input and output a single boolean
|
||||
# (for the cond computation) or a single tuple output (for the body
|
||||
# computation), we build XLA computations that handle the tuple munging before
|
||||
# generating a Call into the computations formed from the jaxprs.
|
||||
|
||||
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
|
||||
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals + prev_state)
|
||||
|
||||
cond_c = xb.make_computation_builder("cond_computation")
|
||||
cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
|
||||
@ -279,6 +281,8 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
body_c = xb.make_computation_builder("body_computation")
|
||||
body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
|
||||
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
|
||||
xla.computation_state_carry.extract_state_from_tuple_op(
|
||||
body_c, body_carry, len(args))
|
||||
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
|
||||
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, body_c), body_jaxpr.literals),
|
||||
@ -289,10 +293,13 @@ def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
|
||||
extend_name_stack(name_stack, 'body_pred'), *(x + z))
|
||||
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
|
||||
assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z) # no broadcast
|
||||
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))
|
||||
result_state = xla.computation_state_carry.current_state(body_c)
|
||||
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z, result_state)))
|
||||
|
||||
ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
|
||||
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
|
||||
xla.computation_state_carry.extract_state_from_tuple_op(
|
||||
c, ans, len(args))
|
||||
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
|
||||
return xops.Tuple(c, z)
|
||||
|
||||
@ -556,23 +563,30 @@ def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
|
||||
pred, *args, true_jaxpr, false_jaxpr, linear):
|
||||
del linear # Unused.
|
||||
true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])
|
||||
|
||||
current_state = xla.computation_state_carry.current_state(c)
|
||||
def make_computation(name, jaxpr, op_shape):
|
||||
c = xb.make_computation_builder(name + '_comp')
|
||||
|
||||
op = xb.parameter(c, 0, op_shape)
|
||||
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
|
||||
xla.computation_state_carry.extract_state_from_tuple_op(
|
||||
c, op, len(jaxpr.in_avals))
|
||||
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
|
||||
_map(partial(xb.constant, c), jaxpr.literals),
|
||||
extend_name_stack(name_stack, name + '_fun'), *ops)
|
||||
return c.Build(xops.Tuple(c, outs))
|
||||
result_state = xla.computation_state_carry.current_state(c)
|
||||
return c.Build(xops.Tuple(c, list(outs) + result_state))
|
||||
|
||||
true_op = xops.Tuple(c, true_ops)
|
||||
true_op = xops.Tuple(c, true_ops + current_state)
|
||||
true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))
|
||||
|
||||
false_op = xops.Tuple(c, false_ops)
|
||||
false_op = xops.Tuple(c, false_ops + current_state)
|
||||
false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))
|
||||
cond_op = xops.Conditional(pred, true_op, true_c, false_op, false_c)
|
||||
nr_outs = len(true_jaxpr.out_avals)
|
||||
|
||||
return xops.Conditional(pred, true_op, true_c, false_op, false_c)
|
||||
return xla.computation_state_carry.extract_state_from_tuple_op(
|
||||
c, cond_op, nr_outs)
|
||||
|
||||
def _cond_pred_bcast_select(pred, x, y):
|
||||
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
|
||||
|
485
tests/host_callback_test.py
Normal file
485
tests/host_callback_test.py
Normal file
@ -0,0 +1,485 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as onp
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import api
|
||||
from jax import lax
|
||||
from jax import numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.lib import xla_bridge
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
def skip_if_jit_not_enabled():
|
||||
if os.getenv("JAX_ENABLE_JIT_PRINT", "false") == "false":
|
||||
raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.")
|
||||
|
||||
class _TestingOutputStream(object):
|
||||
"""Use as `output_stream` for tests."""
|
||||
|
||||
def __init__(self):
|
||||
self._output = []
|
||||
|
||||
def write(self, what: str) -> None:
|
||||
# Sometimes we get floating points in the output; we round them
|
||||
def repl(match_group):
|
||||
# TODO: why can't we use here np.around?
|
||||
matched = match_group.group(0)
|
||||
if matched == ".": return matched
|
||||
x = onp.around(float(matched), decimals=2)
|
||||
return f"{x:.2f}"
|
||||
|
||||
what = re.sub(r"\-?\d*\.[\-\def]*", repl, what)
|
||||
print(f"output_stream: {what}")
|
||||
self._output.append(what)
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
return "\n".join(self._output)
|
||||
|
||||
def __str__(self):
|
||||
return "TestingOutputStream"
|
||||
|
||||
def reset(self):
|
||||
self._output = []
|
||||
|
||||
|
||||
testing_stream = _TestingOutputStream()
|
||||
|
||||
|
||||
def fun1(a):
|
||||
y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
|
||||
y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y)
|
||||
return y**4 # Some computation to make the gradient interesting
|
||||
|
||||
|
||||
def fun1_equiv(a): # Numerical equivalent of fun`
|
||||
return (a * 2.)**4
|
||||
|
||||
|
||||
class HostCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
testing_stream.reset()
|
||||
|
||||
def helper_set_devices(self, nr_devices):
|
||||
flags_str = os.getenv("XLA_FLAGS", "")
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
flags_str +
|
||||
" --xla_force_host_platform_device_count={}".format(nr_devices))
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
return api.devices()
|
||||
|
||||
def helper_set_hlo_dump(self):
|
||||
flags_str = os.getenv("XLA_FLAGS", "")
|
||||
os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to=/tmp/xla_dump"
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
def helper_print_serialization(self, description, np_vals, params):
|
||||
"""Encode a print_descriptor and print it.
|
||||
|
||||
Args:
|
||||
np_vals: a list of np.ndarray from which to extract the shapes.
|
||||
"""
|
||||
encoded = hcb._make_id_print_metadata(
|
||||
[xla.aval_to_xla_shape(pe.get_aval(np_val)) for np_val in np_vals],
|
||||
params)
|
||||
print(f"test_serialization: {description}")
|
||||
print(", ".join([f"{b}" for b in encoded]))
|
||||
|
||||
def test_serialization(self):
|
||||
"""Prints encodings used in host_callback_test::TestParseDescriptor."""
|
||||
raise SkipTest("Not implemented")
|
||||
self.helper_print_serialization("no args, separator=sep, param=0", [],
|
||||
dict(param=0, separator="sep"))
|
||||
self.helper_print_serialization("1 scalar int, separator= param=1",
|
||||
[np.int32(0)], dict(param=1, separator=""))
|
||||
self.helper_print_serialization("1 array f32[2, 3], separator= param=2",
|
||||
[np.ones((2, 3), dtype=np.float32)],
|
||||
dict(param=2, separator=""))
|
||||
self.helper_print_serialization(
|
||||
"1 array f32[2, 3] and 1 f64, separator= param=3",
|
||||
[np.ones((2, 3), dtype=np.float32),
|
||||
np.float64(0)], dict(param=3, separator=""))
|
||||
|
||||
def test_with_tuple_result(self):
|
||||
|
||||
def func2(x):
|
||||
x1, y1 = hcb.id_print(x * 2., x * 3., output_stream=testing_stream)
|
||||
return x1 + y1
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.0
|
||||
c = mul a 3.0
|
||||
d e = id_print[ output_stream=TestingOutputStream ] b c
|
||||
f = add d e
|
||||
in (f,) }""", str(api.make_jaxpr(func2)(3.)))
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
(6.00, 9.00) {}""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_eval(self):
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.0
|
||||
c = id_print[ output_stream=TestingOutputStream
|
||||
what=a * 2 ] b
|
||||
d = mul c 3.0
|
||||
e f = id_print[ nr_results=1
|
||||
output_stream=TestingOutputStream
|
||||
what=y * 3 ] d c
|
||||
g = pow f 4.0
|
||||
in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
|
||||
self.assertEqual("", testing_stream.output)
|
||||
|
||||
self.assertEqual((5. * 2.)**4, fun1(5.))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
(10.00,) {'what': 'a * 2'}
|
||||
(30.00, 10.00) {'what': 'y * 3', 'nr_results': 1}""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_jit_simple(self):
|
||||
jit_fun1 = api.jit(lambda x: hcb.end_printing(3. * hcb.id_print(
|
||||
2. * x, what="here")))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a.
|
||||
let b = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; a.
|
||||
let b = mul a 2.0
|
||||
c = id_print[ what=here ] b
|
||||
d = mul c 3.0
|
||||
e f = id_print[ nr_results=1 ] 12345678 d
|
||||
in (f,) }
|
||||
device=None
|
||||
name=<lambda> ] a
|
||||
in (b,) }""", str(api.make_jaxpr(jit_fun1)(5.)))
|
||||
logging.info("%s: %s",
|
||||
self._testMethodName, api.xla_computation(jit_fun1)(5.).GetHloText())
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
res = jit_fun1(5.)
|
||||
self.assertAllClose(6. * 5., res, check_dtypes=True)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
10.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_simple_jit_sequencing(self):
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1")
|
||||
x2 = hcb.id_print(x1 + 1, where="2")
|
||||
return hcb.end_printing(x2)
|
||||
|
||||
logging.info("%s: %s", self._testMethodName,
|
||||
api.make_jaxpr(func)(1))
|
||||
logging.info("%s: %s", self._testMethodName,
|
||||
api.xla_computation(func)(1).GetHloText())
|
||||
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
1
|
||||
2""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_jit2(self):
|
||||
"""A sequence of JIT."""
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1")
|
||||
x2 = hcb.id_print(x1 + 1, where="2")
|
||||
return x2
|
||||
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertEqual(11, api.jit(func)(10))
|
||||
# Now send the end of printing
|
||||
api.jit(lambda x: hcb.end_printing(x))(0)
|
||||
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
1
|
||||
2
|
||||
10
|
||||
11""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_jit_nested(self):
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1")
|
||||
def func_nested(x):
|
||||
x2 = hcb.id_print(x, where="nested")
|
||||
return x2
|
||||
x3 = api.jit(func_nested)(x1)
|
||||
x2 = hcb.id_print(x3 + 1, where="2")
|
||||
return hcb.end_printing(x2)
|
||||
|
||||
logging.warning("%s: %s", self._testMethodName,
|
||||
api.make_jaxpr(func)(1))
|
||||
logging.warning("%s: %s", self._testMethodName,
|
||||
api.xla_computation(func)(1).GetHloText())
|
||||
return
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
pass # self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
1
|
||||
2""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_jit_cond1(self):
|
||||
"""A conditional"""
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1")
|
||||
x2 = hcb.id_print(x1 + 1, where="2")
|
||||
|
||||
x4 = lax.cond(x % 2 == 0,
|
||||
x2 + 1, lambda x: hcb.id_print(x, where="cond_t"),
|
||||
x2 + 1, lambda x: hcb.id_print(-1, where="cond_f", result=x))
|
||||
x5 = hcb.id_print(x4 + 1, where="w.2")
|
||||
return hcb.end_printing(x5)
|
||||
|
||||
logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
|
||||
logging.warning("%s: %s", self._testMethodName,
|
||||
api.xla_computation(func)(1).GetHloText())
|
||||
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
self.assertEqual(4, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
1
|
||||
2
|
||||
-1
|
||||
4""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
|
||||
def test_jit_while_cond(self):
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1")
|
||||
x2 = hcb.id_print(x1 + 1, where="2")
|
||||
def body(x):
|
||||
x3 = hcb.id_print(x, where="w.1")
|
||||
x4 = lax.cond(x % 2 == 0,
|
||||
x3 + 1, lambda x: hcb.id_print(x, where="w.t"),
|
||||
x3 + 1, lambda x: hcb.id_print(-1, where="w.f", result=x))
|
||||
return hcb.id_print(x4 + 1, where="w.2")
|
||||
x10 = lax.while_loop(lambda x: x < 10, body, x2)
|
||||
res = hcb.id_print(x10, where="10")
|
||||
return hcb.end_printing(res)
|
||||
logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
|
||||
logging.warning("%s: %s", self._testMethodName,
|
||||
api.xla_computation(func)(1).GetHloText())
|
||||
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
self.assertEqual(10, api.jit(func)(1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
1
|
||||
2
|
||||
2
|
||||
3
|
||||
4
|
||||
4
|
||||
5
|
||||
6
|
||||
6
|
||||
7
|
||||
8
|
||||
8
|
||||
9
|
||||
10
|
||||
10""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
dict(
|
||||
testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
nr_args=nr_args) for nr_args in [1, 2]
|
||||
for shape in [(), (2,), (2, 3), (2, 3, 4)]
|
||||
for dtype in jtu.supported_dtypes()))
|
||||
def test_jit_types(self, nr_args=2, dtype=np.int16, shape=(2,)):
|
||||
if dtype in (np.complex64, np.complex128, np.bool_):
|
||||
raise SkipTest(f"id_print jit not implemented for {dtype}.")
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if dtype in (np.int16,):
|
||||
raise SkipTest(f"transfering {dtype} not supported on TPU")
|
||||
self.helper_set_hlo_dump()
|
||||
args = [np.arange(np.prod(shape), dtype=dtype).reshape(shape)]
|
||||
if nr_args > 1:
|
||||
args = args * nr_args
|
||||
jit_fun1 = api.jit(lambda xs: hcb.end_printing(
|
||||
hcb.id_print(
|
||||
*xs,
|
||||
a_new_test="************",
|
||||
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")))
|
||||
with hcb.print_receiver(receiver_name=self._testMethodName):
|
||||
res = jit_fun1(args)
|
||||
# self.assertAllClose(args, res, check_dtypes=True)
|
||||
|
||||
def test_jit_large(self):
|
||||
arg = np.arange(10000, dtype=np.int32).reshape((10, 10, 5, -1))
|
||||
with hcb.print_receiver(output_stream=testing_stream,
|
||||
receiver_name=self._testMethodName):
|
||||
api.jit(lambda x: hcb.end_printing(hcb.id_print(x)))(arg)
|
||||
|
||||
def test_jvp(self):
|
||||
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a b.
|
||||
let c = mul a 2.0
|
||||
d = id_print[ output_stream=TestingOutputStream
|
||||
what=a * 2 ] c
|
||||
e = mul d 3.0
|
||||
f g = id_print[ nr_results=1
|
||||
output_stream=TestingOutputStream
|
||||
what=y * 3 ] e d
|
||||
h = pow g 4.0
|
||||
i = mul b 2.0
|
||||
j = id_print[ output_stream=TestingOutputStream
|
||||
transforms=('jvp',)
|
||||
what=a * 2 ] i
|
||||
k = mul j 3.0
|
||||
l m = id_print[ nr_results=1
|
||||
output_stream=TestingOutputStream
|
||||
transforms=('jvp',)
|
||||
what=y * 3 ] k j
|
||||
n = pow g 3.0
|
||||
o = mul 4.0 n
|
||||
p = mul m o
|
||||
in (h, p) }""",
|
||||
str(api.make_jaxpr(jvp_fun1)(np.float32(5.), np.float32(0.1))))
|
||||
|
||||
res_primals, res_tangents = jvp_fun1(np.float32(5.), np.float32(0.1))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
(DeviceArray(10.00, dtype=float32),) {'what': 'a * 2'}
|
||||
(DeviceArray(0.20, dtype=float32),) {'what': 'a * 2', 'transforms': ('jvp',)}
|
||||
(DeviceArray(30.00, dtype=float32), DeviceArray(10.00, dtype=float32)) {'what': 'y * 3', 'nr_results': 1}
|
||||
(DeviceArray(0.60, dtype=float32), DeviceArray(0.20, dtype=float32)) {'what': 'y * 3', 'nr_results': 1, 'transforms': ('jvp',)}
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_grad(self):
|
||||
raise SkipTest("failing with new implementation")
|
||||
grad_fun1 = api.grad(fun1)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.0
|
||||
c = id_print[ output_stream=TestingOutputStream
|
||||
what=a * 2 ] b
|
||||
d = mul c 3.0
|
||||
e = id_print[ output_stream=TestingOutputStream
|
||||
what=y * 3 ] d
|
||||
f = tie_in e c
|
||||
g = pow f 3.0
|
||||
h = mul 4.0 g
|
||||
i = mul 1.0 h
|
||||
j = id_print[ output_stream=TestingOutputStream
|
||||
transforms=('jvp', 'transpose')
|
||||
what=a * 2 ] i
|
||||
k = mul j 2.0
|
||||
in (k,) }""", str(api.make_jaxpr(grad_fun1)(5.)))
|
||||
|
||||
# This comes from the actual partial evaluation
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
(Zero,) {'what': 'y * 3', 'transforms': ('jvp', 'transpose')}
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
res_grad = grad_fun1(np.float32(5.))
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
(DeviceArray(10.00, dtype=float32),) {'what': 'a * 2'}
|
||||
(DeviceArray(30.00, dtype=float32),) {'what': 'y * 3'}
|
||||
(Zero,) {'what': 'y * 3', 'transforms': ('jvp', 'transpose')}
|
||||
(DeviceArray(4000.00, dtype=float32),) {'what': 'a * 2', 'transforms': ('jvp', 'transpose')}
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_vmap(self):
|
||||
vmap_fun1 = api.vmap(fun1)
|
||||
vargs = np.array([np.float32(4.), np.float32(5.)])
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
{ lambda ; a.
|
||||
let b = mul a 2.0
|
||||
c = id_print[ output_stream=TestingOutputStream
|
||||
transforms=('batch',)
|
||||
what=a * 2 ] b
|
||||
d = mul c 3.0
|
||||
e f = id_print[ nr_results=1
|
||||
output_stream=TestingOutputStream
|
||||
transforms=('batch',)
|
||||
what=y * 3 ] d c
|
||||
g = pow f 4.0
|
||||
in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
|
||||
|
||||
res_vmap = vmap_fun1(vargs)
|
||||
self.assertMultiLineStrippedEqual(
|
||||
"""
|
||||
(DeviceArray([ 8.00, 10.00], dtype=float32),) {'what': 'a * 2', 'transforms': ('batch',)}
|
||||
(DeviceArray([24.00, 30.00], dtype=float32), DeviceArray([ 8.00, 10.00], dtype=float32)) {'what': 'y * 3', 'nr_results': 1, 'transforms': ('batch',)}
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_pmap(self):
|
||||
skip_if_jit_not_enabled()
|
||||
self.helper_set_devices(4)
|
||||
vargs = np.arange(api.local_device_count(), dtype=np.float32)
|
||||
|
||||
pmap_fun1 = api.pmap(fun1, axis_name="i")
|
||||
res = pmap_fun1(vargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user