Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 04:46:06 +00:00
Add build time support for AOT compilation to TF graphs.
PiperOrigin-RevId: 417392920
This commit is contained in:
parent 52cf360e4a
commit 2a6147af1b
@@ -17,9 +17,23 @@ licenses(["notice"])
 package(default_visibility = ["//visibility:public"])
 
 py_library(
-    name = "jax_to_hlo",
-    srcs = ["jax_to_hlo.py"],
+    name = "jax_to_ir",
+    srcs = ["jax_to_ir.py"],
+    tags = [
+        "ignore_for_dep=third_party.py.jax.experimental.jax2tf",
+        "ignore_for_dep=third_party.py.tensorflow",
+    ],
     deps = [
         "//third_party/py/jax",
     ],
 )
+
+py_library(
+    name = "jax_to_ir_with_tensorflow",
+    srcs = ["jax_to_ir.py"],
+    deps = [
+        "//third_party/py/jax",
+        "//third_party/py/jax/experimental/jax2tf",
+        "//third_party/py/tensorflow",
+    ],
+)
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""JAX tools."""
+
 def _shell_quote(s):
     """Copy of bazel-skylib's shell.quote.
 
@@ -28,7 +30,13 @@ def _shell_quote(s):
     return "'" + s.replace("'", "'\\''") + "'"
 
 def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
-    """Creates a genrule that uses jax_to_hlo.py to make HLO from a JAX func.
+    jax_to_ir(name, deps, fn, input_shapes, constants = constants, format = "HLO")
+
+def jax_to_tf(name, deps, fn, input_shapes, constants = None):
+    jax_to_ir(name, deps, fn, input_shapes, constants = constants, format = "TF")
+
+def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
+    """Creates a genrule that uses jax_to_ir.py to make proto from a JAX func.
 
     Suppose we have
 
@@ -46,9 +54,9 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
     Then we can invoke this macro as follows.
 
     $ cat your/thing/BUILD
-    load("//jax/tools:build_defs.bzl", "jax_to_hlo")
+    load("//jax/tools:build_defs.bzl", "jax_to_ir")
 
-    jax_to_hlo(
+    jax_to_ir(
         name = "prog_hlo",
         deps = ["//your/thing:prog"],
         fn = "your.thing.prog.fn",  # Fully-qualified module name
@@ -59,6 +67,7 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
         constants = {
             "z": "3.14159",
         },
+        format = 'HLO',
     )
 
     This generates two build rules, named
@@ -73,7 +82,7 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
     That is, the above macro will create a program which accepts parameters y
     and x, in that order.
 
-    Skylark doesn't support floating-point numbers without a special option, so
+    Starlark doesn't support floating-point numbers without a special option, so
     we need special rules to be able to pass fp values in `constants`. Each
     dict value `v` is transformed as follows.
 
@@ -103,37 +112,43 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
         will be in the order specified here.
       constants: Python dictionary mapping arg names to constant values they
         should take on, e.g. {"z": "float(3.14159")}.
+      format: Either HLO or TF.
     """
     if not constants:
         constants = {}
 
     # Our goal here is to create a py_binary which depends on `deps` and
-    # invokes the main function defined in jax_to_hlo.py.
+    # invokes the main function defined in jax_to_ir.py.
     #
     # At first blush it seems that we can do this in a straightforward way:
     #
-    # native.py_binary(main = "jax_to_hlo.py").
+    # native.py_binary(main = "jax_to_ir.py").
     #
     # The problem with this is that this py_binary lives in the user's package,
-    # whereas jax_to_hlo.py lives inside JAX's package. Bazel delivers a stern
+    # whereas jax_to_ir.py lives inside JAX's package. Bazel delivers a stern
     # warning if you name a file in `main` that's outside of your package.
     #
     # To avoid the warning, we generate a simple "main file" in the user's
     # package. It only complicates the rules a bit.
-    runner = name + "_jax_to_hlo_main"
+    runner = name + "_jax_to_ir_main"
     native.genrule(
         name = runner + "_gen",
         outs = [runner + ".py"],
         cmd = """cat <<EOF > '$(location {runner}.py)'
 from absl import app
-import jax.tools.jax_to_hlo as jax_to_hlo
+import jax.tools.jax_to_ir as jax_to_ir
 
-jax_to_hlo.set_up_flags()
-app.run(jax_to_hlo.main)
+jax_to_ir.set_up_flags()
+app.run(jax_to_ir.main)
 EOF
 """.format(runner = runner),
     )
 
+    if format == "TF":
+        jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow"
+    else:
+        jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir"
+
     native.py_binary(
         name = runner,
         srcs = [
@@ -142,7 +157,7 @@ EOF
         python_version = "PY3",
         deps = deps + [
             "//third_party/py/jax/jaxlib",
-            "//third_party/py/jax/tools:jax_to_hlo",
+            jax_to_ir_rule,
         ],
     )
 
@@ -151,7 +166,7 @@ EOF
     # Set JAX_PLATFORM_NAME to "cpu" to silence the "no GPU/TPU backend found,
     # falling back to CPU" warning.
     native.genrule(
-        name = name + "_jax_to_hlo_genrule",
+        name = name + "_jax_to_ir_genrule",
         outs = [name + ".pb", name + ".txt"],
         exec_tools = [runner],
         cmd = """
@@ -160,13 +175,15 @@ EOF
           --fn {fn} \
           --input_shapes {input_shapes} \
           --evaled_constants {constants} \
-          --hlo_proto_dest '$(location {name}.pb)' \
-          --hlo_text_dest '$(location {name}.txt)' \
+          --ir_format {format} \
+          --ir_dest '$(location {name}.pb)' \
+          --ir_human_dest '$(location {name}.txt)' \
         """.format(
             name = name,
             fn = fn,
             input_shapes = _shell_quote(str(input_shapes)),
             constants = _shell_quote(str(constants)),
             runner = runner,
+            format = _shell_quote(format),
         ),
     )
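
For the new TF path, the jax_to_tf wrapper defined above is invoked exactly like the jax_to_ir example in the macro docstring; a minimal BUILD sketch, with illustrative package, labels, and shapes that simply mirror that docstring example:

    load("//jax/tools:build_defs.bzl", "jax_to_tf")

    jax_to_tf(
        name = "prog_tf",
        deps = ["//your/thing:prog"],
        fn = "your.thing.prog.fn",  # Fully-qualified module name
        input_shapes = [
            ("y", "f32[128,32]"),
            ("x", "f32[8,128]"),
        ],
        constants = {
            "z": "3.14159",
        },
    )

As with the HLO path, the macro's genrule writes <name>.pb and <name>.txt; with the TF format these hold the serialized GraphDef and its text dump.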
@@ -12,11 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-r"""Tool to convert a JAX function to an HLO proto.
+r"""Tool to convert a JAX function to serialized representations.
 
 This script is meant to be used as part of a genrule that converts a JAX program
-into an HLO proto. The HLO proto represents an XLA program, and can be run from
-e.g. a C++ program, without involving any Python.
+into an IR that can be consumed by another system (e.g. a compiler).
+
+Convert to HLO
+==============
+
+For example, you can generate an HLO proto for the XLA compiler. The HLO proto
+represents an XLA program, and can be run from e.g. a C++ program, without
+involving any Python.
 
 This lets you use JAX as a convenient frontend for writing "XLA programs". From
 another perspective, this script lets you make JAX into an ahead-of-time JAX ->
@@ -33,12 +39,13 @@ Usage:
   def fn(x, y, z):
     return jnp.dot(x, y) / z
 
-$ python jax_to_hlo.py \
+$ python jax_to_ir.py \
   --fn prog.fn \
   --input_shapes '[("y", "f32[128,32]"), ("x", "f32[8,128]")]' \
   --constants '{"z": 3.14159}' \
-  --hlo_text_dest /tmp/fn_hlo.txt \
-  --hlo_proto_dest /tmp/fn_hlo.pb
+  --ir_format HLO \
+  --ir_human_dest /tmp/fn_hlo.txt \
+  --ir_dest /tmp/fn_hlo.pb
 
 Alternatively, you can use this script via a genrule. This way bazel will
 generate the hlo text/proto as part of compilation, and then e.g. a C++ program
@@ -58,38 +65,45 @@ Note that XLA's backwards-compatibility guarantees for saved HLO are currently
 and the XLA team won't (and in fact will be unable to) help. One way to be sure
 it won't break is to use the same version of XLA to build the HLO as you use to
 run it. The genrule above makes this easy.
-
-Implementation note: This script must be python2 compatible for now, because
-Google's genrules still run with python2, b/66712815.
 """
 
 
 from ast import literal_eval
 import importlib
 import functools
+import re
 
 from absl import app
 from absl import flags
 import jax
 import jax.numpy as jnp
-from jax._src.lib import xla_client
+
+try:
+  from jax.experimental import jax2tf
+except ImportError:
+  jax2tf = None  # type: ignore[assignment]
+
+try:
+  import tensorflow as tf
+except ImportError:
+  tf = None  # type: ignore
 
 FLAGS = flags.FLAGS
 
 
-def jax_to_hlo(fn, input_shapes, constants=None):
-  """Converts a JAX function to an HLO module.
+def jax_to_ir(fn, input_shapes, *, constants=None, format):
+  """Converts a JAX function to a serialized ir and a debug txt dump.
 
   Args:
     fn: Function to convert.
-    input_shapes: List of tuples (arg name, xla_client.Shape),
+    input_shapes: List of tuples (arg name, jax.ShapedArray),
       indicating the shapes of the arguments to fn. The order of parameters in
      the resulting XLA program will match the order in this list.
     constants: Dict mapping function argument name to a Python value. Specified
      arguments these values as compile-time constants.
+    format: Which IR format to use. Supported values are 'HLO' and 'TF'.
 
   Returns:
-    A tuple (serialized_hlo_proto, hlo_text).
+    A tuple of (compiler_suitable_ir, human_readable_ir).
   """
   if not constants:
     constants = {}
@@ -101,25 +115,8 @@ def jax_to_hlo(fn, input_shapes, constants=None):
         'Arguments appear in both `input_shapes` and `constants`: %s' %
         ', '.join(sorted(overlapping_args)))
 
-  args = []
-  for arg_name, shape in input_shapes:
-    if not shape.is_array():
-      raise ValueError('Shape %s is not an array, but currently only arrays '
-                       'are supported (i.e., no tuples, nor tokens).' % str(shape))
-
-    # Check that `shape` either doesn't have a layout or has the default layout.
-    #
-    # TODO(jlebar): This could be simpler if the Shape class exposed its layout,
-    # or if Shape exposed a function to unconditionally use the default layout.
-    shape_with_default_layout = xla_client.Shape.array_shape(
-        shape.xla_element_type(),
-        shape.dimensions()).with_major_to_minor_layout_if_absent()
-    if (shape.with_major_to_minor_layout_if_absent() !=
-        shape_with_default_layout):
-      raise ValueError('Shape %s has a non-default layout, but only '
-                       'the default layout is allowed.' % str(shape))
-
-    args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
+  # TODO(tomhennigan): Ideally we could avoid creating actual values here.
+  args = [jnp.zeros(s.shape, s.dtype) for _, s in input_shapes]
 
   # Curry `constants` into the function.
   fn_curried = functools.partial(fn, **constants)
@@ -130,23 +127,51 @@
     arg_names = [arg_name for arg_name, _ in input_shapes]
     return fn_curried(**dict(zip(arg_names, args)))
 
-  comp = jax.xla_computation(ordered_wrapper)(*args)
-  return (comp.as_serialized_hlo_module_proto(), comp.as_hlo_text())
+  if format == 'HLO':
+    comp = jax.xla_computation(ordered_wrapper)(*args)
+    serialized_proto = comp.as_serialized_hlo_module_proto()
+    debug_txt = comp.as_hlo_text()
+  else:
+    assert format == 'TF'
+    if tf is None:
+      raise ValueError(
+          'Conversion to TF graph requires TensorFlow to be installed.')
+
+    f = jax2tf.convert(ordered_wrapper)
+    f = tf_wrap_with_input_names(f, input_shapes)
+    f = tf.function(f, autograph=False)
+    g = f.get_concrete_function(*args).graph.as_graph_def()
+    serialized_proto = g.SerializeToString()
+    debug_txt = str(g)
+
+  return serialized_proto, debug_txt
+
+
+def tf_wrap_with_input_names(f, input_shapes):
+  def wrapper(*args):
+    args = tuple(
+        tf.identity(a, name=name) for a, (name, _) in zip(args, input_shapes))
+    # NOTE: Output names already set via `jax2tf.convert(..)`.
+    return f(*args)
+  return wrapper
+
+
+jax_to_hlo = functools.partial(jax_to_ir, format='HLO')
+jax_to_tf = functools.partial(jax_to_ir, format='TF')
 
 
 def main(argv):
   if len(argv) != 1:
     raise app.UsageError('No positional arguments are accepted.')
 
-  if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
-    raise app.Error('At least one of --hlo_proto_dest and '
-                    '--hlo_text_dest is required.')
+  if not FLAGS.ir_dest and not FLAGS.ir_human_dest:
+    raise app.Error('At least one of --ir_dest and '
+                    '--ir_human_dest is required.')
 
   module_name, fn_name = FLAGS.fn.rsplit('.', 1)
   module = importlib.import_module(module_name)
   fn = getattr(module, fn_name)
 
-  input_shapes = [(name, xla_client.Shape(shape_str))
+  input_shapes = [(name, parse_shape_str(shape_str))
                   for name, shape_str in literal_eval(FLAGS.input_shapes)]
 
   # Parse --constants and --evaled_constants.
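
For reference, the renamed entry points can also be driven directly from Python rather than through the genrule; a minimal sketch mirroring the new tests (parse_shape_str is defined a little further down in this file, and the TF path additionally needs jax2tf and TensorFlow installed):

    import jax.numpy as jnp
    from jax.tools import jax_to_ir

    def axpy(a, x, y):
      return a * x + y[:, jnp.newaxis]

    shapes = [('y', jax_to_ir.parse_shape_str('f32[128]')),
              ('a', jax_to_ir.parse_shape_str('f32[]')),
              ('x', jax_to_ir.parse_shape_str('f32[128,2]'))]

    # HLO path: a serialized HloModuleProto plus its text dump.
    hlo_proto, hlo_text = jax_to_ir.jax_to_hlo(axpy, shapes)

    # TF path: a serialized GraphDef plus str(graph_def).
    tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, shapes)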
@@ -166,15 +191,38 @@ def main(argv):
           'Argument appears in both --constants and --evaled_constants: %s' % k)
     constants[k] = v
 
-  hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)
+  ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
+                           format=FLAGS.ir_format)
 
-  if FLAGS.hlo_proto_dest:
-    with open(FLAGS.hlo_proto_dest, 'wb') as f:
-      f.write(hlo_proto)
+  if FLAGS.ir_dest:
+    with open(FLAGS.ir_dest, 'wb') as f:
+      f.write(ir)
 
-  if FLAGS.hlo_text_dest:
-    with open(FLAGS.hlo_text_dest, 'w') as f:
-      f.write(hlo_text)
+  if FLAGS.ir_human_dest:
+    with open(FLAGS.ir_human_dest, 'w') as f:
+      f.write(debug_ir)
+
+
+def parse_shape_str(s):
+  match = _SHAPE_RE.match(s)
+  if not match:
+    raise ValueError(f'Invalid shape {s}. Valid example: "f32[1,2,3]".'
+                     f'Note that dtype must be one of {list(_DT)}')
+  dtype = _DT[match.group(1)]
+  if match.group(2):
+    shape = tuple(int(d.strip()) for d in match.group(2).split(","))
+  else:
+    shape = ()
+  return jax.ShapedArray(shape, dtype)
+
+
+_DT = {'pred': jnp.bool_,
+       'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
+       's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
+       'bf16': jnp.bfloat16,
+       'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
+       'c64': jnp.complex64, 'c128': jnp.complex128}
+_SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
 
 
 def set_up_flags():
   flags.DEFINE_string(
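
The shape strings accepted on the command line are handled by parse_shape_str above; a quick illustration of the syntax (the same behavior is exercised by the new tests below):

    import jax.numpy as jnp
    from jax.tools import jax_to_ir

    s = jax_to_ir.parse_shape_str('f32[1,2,3]')
    assert s.shape == (1, 2, 3) and s.dtype == jnp.float32
    assert jax_to_ir.parse_shape_str('pred[]').dtype == jnp.bool_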
@@ -188,8 +236,10 @@ def set_up_flags():
       'Python dict giving constant values for some params. '
       'Values in this dict that are of type str are evaluated '
       'using ast.literal_eval.')
-  flags.DEFINE_string('hlo_proto_dest', None, 'File to write HLO proto')
-  flags.DEFINE_string('hlo_text_dest', None, 'File to write HLO text')
+  flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
+  flags.DEFINE_string('ir_dest', None, 'File to write IR to')
+  flags.DEFINE_string('ir_human_dest', None,
+                      'File to write human readable debug output')
   flags.mark_flag_as_required('fn')
   flags.mark_flag_as_required('input_shapes')
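
A consumer of the TF output does not need JAX at all: the .pb written via --ir_dest is a serialized GraphDef that plain TensorFlow can re-import (a sketch; the file name is illustrative, and the new test below performs the same re-import in memory):

    import tensorflow as tf

    gdef = tf.compat.v1.GraphDef()
    with open('prog_tf.pb', 'rb') as f:
      gdef.ParseFromString(f.read())

    g = tf.Graph()
    with g.as_default():
      tf.import_graph_def(gdef, name='')
    # Inputs keep the names given in input_shapes; the output is named by
    # jax2tf (e.g. 'jax2tf_out' in the test below).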
@@ -1,76 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl.testing import absltest
-from jax._src.lib import xla_client
-import jax.numpy as jnp
-from jax.tools.jax_to_hlo import jax_to_hlo
-from jax._src import test_util as jtu
-
-
-class JaxToHloTest(absltest.TestCase):
-
-  def test_convert_axpy(self):
-
-    def axpy(a, x, y):
-      return a * x + y[:, jnp.newaxis]
-
-    hlo_proto, hlo_text = jax_to_hlo(
-        axpy, [
-            ('y', xla_client.Shape('f32[128]')),
-            ('a', xla_client.Shape('f32[]')),
-            ('x', xla_client.Shape('f32[128,2]')),
-        ])
-
-    # Check that hlo_text contains a broadcast, add, and multiply.
-    self.assertIn('broadcast', hlo_text)
-    self.assertIn('add', hlo_text)
-    self.assertIn('multiply', hlo_text)
-
-    # Check that the HLO parameters are in the order we specified in the
-    # jax_to_hlo call.
-    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
-    self.assertIn('f32[] parameter(1)', hlo_text)
-    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)
-
-    # Check that the parameters are in the expected order.
-
-    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
-    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
-    # moment, so the best we seem to be able to do is check that it's nonempty.
-    assert hlo_proto
-
-  def test_convert_with_constants(self):
-
-    def fn(a, b, x, y):
-      return a / b * x + y
-
-    _, hlo_text = jax_to_hlo(
-        fn,
-        input_shapes=[
-            ('x', xla_client.Shape('f32[128]')),
-            ('y', xla_client.Shape('f32[128]')),
-        ],
-        constants={
-            'a': 123456,
-            'b': 4,
-        })
-    # Because we passed `a` and `b` as constants, they get constant-folded away
-    # by Python/JAX to a/b = 30864.
-    self.assertIn('constant(30864)', hlo_text)
-    self.assertNotIn('123456', hlo_text)
-
-
-if __name__ == '__main__':
-  absltest.main(testLoader=jtu.JaxTestLoader())
tests/jax_to_ir_test.py (new file, 144 lines)
@@ -0,0 +1,144 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from absl.testing import absltest
+import jax.numpy as jnp
+from jax.tools import jax_to_ir
+from jax._src import test_util as jtu
+
+try:
+  import tensorflow as tf
+except ImportError:
+  tf = None  # type: ignore
+
+
+def axpy(a, x, y):
+  return a * x + y[:, jnp.newaxis]
+
+
+class JaxToIRTest(absltest.TestCase):
+
+  def test_jax_to_hlo_axpy(self):
+    hlo_proto, hlo_text = jax_to_ir.jax_to_hlo(axpy, [
+        ('y', jax_to_ir.parse_shape_str('f32[128]')),
+        ('a', jax_to_ir.parse_shape_str('f32[]')),
+        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
+    ])
+
+    # Check that hlo_text contains a broadcast, add, and multiply.
+    self.assertIn('broadcast', hlo_text)
+    self.assertIn('add', hlo_text)
+    self.assertIn('multiply', hlo_text)
+
+    # Check that the HLO parameters are in the order we specified in the
+    # jax_to_hlo call.
+    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
+    self.assertIn('f32[] parameter(1)', hlo_text)
+    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)
+
+    # Check that the parameters are in the expected order.
+
+    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
+    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
+    # moment, so the best we seem to be able to do is check that it's nonempty.
+    assert hlo_proto
+
+  def test_jax_to_hlo_with_constants(self):
+
+    def fn(a, b, x, y):
+      return a / b * x + y
+
+    _, hlo_text = jax_to_ir.jax_to_hlo(
+        fn,
+        input_shapes=[
+            ('x', jax_to_ir.parse_shape_str('f32[128]')),
+            ('y', jax_to_ir.parse_shape_str('f32[128]')),
+        ],
+        constants={
+            'a': 123456,
+            'b': 4,
+        })
+    # Because we passed `a` and `b` as constants, they get constant-folded away
+    # by Python/JAX to a/b = 30864.
+    self.assertIn('constant(30864)', hlo_text)
+    self.assertNotIn('123456', hlo_text)
+
+  def test_parse_shape_str_invalid(self):
+    with self.assertRaisesRegex(ValueError, 'Invalid shape.*foo'):
+      jax_to_ir.parse_shape_str('foo[]')
+
+  @unittest.skipIf(tf is None, 'TensorFlow not installed.')
+  def test_jax_to_tf_axpy(self):
+    tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
+        ('y', jax_to_ir.parse_shape_str('f32[128]')),
+        ('a', jax_to_ir.parse_shape_str('f32[]')),
+        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
+    ])
+
+    # Check that tf debug txt contains a broadcast, add, and multiply.
+    self.assertIn('name: "BroadcastTo"', tf_text)
+    self.assertIn('name: "AddV2"', tf_text)
+    self.assertIn('name: "Mul"', tf_text)
+
+    # Check that we can re-import our graphdef.
+    gdef = tf.compat.v1.GraphDef()
+    gdef.ParseFromString(tf_proto)
+    g = tf.Graph()
+    with g.as_default():
+      tf.import_graph_def(gdef, name='')
+
+    # Check that the HLO parameters are named as we specified.
+    ops = {o.name: o for o in g.get_operations()
+           if o.name in ('y', 'a', 'x', 'jax2tf_out')}
+    self.assertLen(ops, 4)
+    self.assertIdentityOp(ops['y'], [128], jnp.float32)
+    self.assertIdentityOp(ops['a'], [], jnp.float32)
+    self.assertIdentityOp(ops['x'], [128, 2], jnp.float32)
+    self.assertIdentityOp(ops['jax2tf_out'], [128, 2], jnp.float32)
+
+  def assertIdentityOp(self, op, expected_shape, expected_dtype):
+    self.assertEqual(op.type, 'Identity')
+    output, = op.outputs
+    self.assertEqual(output.shape, expected_shape)
+    self.assertEqual(output.dtype, expected_dtype)
+
+  def test_parse_shape_str(self):
+    self.assertParsedShape('f32[]', [], jnp.float32)
+    self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
+    self.assertParsedShape('pred[1]', [1], jnp.bool_)
+    self.assertParsedShape('s8[1]', [1], jnp.int8)
+    self.assertParsedShape('s16[1]', [1], jnp.int16)
+    self.assertParsedShape('s32[1]', [1], jnp.int32)
+    self.assertParsedShape('s64[1]', [1], jnp.int64)
+    self.assertParsedShape('u8[1]', [1], jnp.uint8)
+    self.assertParsedShape('u16[1]', [1], jnp.uint16)
+    self.assertParsedShape('u32[1]', [1], jnp.uint32)
+    self.assertParsedShape('u64[1]', [1], jnp.uint64)
+    self.assertParsedShape('f16[1]', [1], jnp.float16)
+    self.assertParsedShape('f32[1]', [1], jnp.float32)
+    self.assertParsedShape('f64[1]', [1], jnp.float64)
+    self.assertParsedShape('bf16[1]', [1], jnp.bfloat16)
+    self.assertParsedShape('c64[1]', [1], jnp.complex64)
+    self.assertParsedShape('c128[1]', [1], jnp.complex128)
+
+  def assertParsedShape(self, s: str, expected_shape, expected_dtype):
+    p = jax_to_ir.parse_shape_str(s)
+    self.assertEqual(p.shape, tuple(expected_shape))
+    self.assertEqual(p.dtype, expected_dtype)
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())