Add build time support for AOT compilation to TF graphs.

PiperOrigin-RevId: 417392920
This commit is contained in:
Tom Hennigan 2021-12-20 06:18:06 -08:00 committed by jax authors
parent 52cf360e4a
commit 2a6147af1b
5 changed files with 291 additions and 142 deletions

View File

@ -17,9 +17,23 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"])
py_library(
name = "jax_to_hlo",
srcs = ["jax_to_hlo.py"],
name = "jax_to_ir",
srcs = ["jax_to_ir.py"],
tags = [
"ignore_for_dep=third_party.py.jax.experimental.jax2tf",
"ignore_for_dep=third_party.py.tensorflow",
],
deps = [
"//third_party/py/jax",
],
)
# Variant of the jax_to_ir library that additionally links jax2tf and
# TensorFlow; it compiles the same jax_to_ir.py source.
py_library(
name = "jax_to_ir_with_tensorflow",
srcs = ["jax_to_ir.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax/experimental/jax2tf",
"//third_party/py/tensorflow",
],
)

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX tools."""
def _shell_quote(s):
"""Copy of bazel-skylib's shell.quote.
@ -28,7 +30,13 @@ def _shell_quote(s):
return "'" + s.replace("'", "'\\''") + "'"
def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
"""Creates a genrule that uses jax_to_hlo.py to make HLO from a JAX func.
jax_to_ir(name, deps, fn, input_shapes, constants = constants, format = "HLO")
def jax_to_tf(name, deps, fn, input_shapes, constants = None):
    """Invokes jax_to_ir with the output format fixed to "TF".

    Args:
      name: Forwarded to jax_to_ir.
      deps: Forwarded to jax_to_ir.
      fn: Forwarded to jax_to_ir.
      input_shapes: Forwarded to jax_to_ir.
      constants: Forwarded to jax_to_ir.
    """
    jax_to_ir(
        name = name,
        deps = deps,
        fn = fn,
        input_shapes = input_shapes,
        constants = constants,
        format = "TF",
    )
def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
"""Creates a genrule that uses jax_to_ir.py to make proto from a JAX func.
Suppose we have
@ -46,9 +54,9 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
Then we can invoke this macro as follows.
$ cat your/thing/BUILD
load("//jax/tools:build_defs.bzl", "jax_to_hlo")
load("//jax/tools:build_defs.bzl", "jax_to_ir")
jax_to_hlo(
jax_to_ir(
name = "prog_hlo",
deps = ["//your/thing:prog"],
fn = "your.thing.prog.fn", # Fully-qualified module name
@ -59,6 +67,7 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
constants = {
"z": "3.14159",
},
format = 'HLO',
)
This generates two build rules, named
@ -73,7 +82,7 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
That is, the above macro will create a program which accepts parameters y
and x, in that order.
Skylark doesn't support floating-point numbers without a special option, so
Starlark doesn't support floating-point numbers without a special option, so
we need special rules to be able to pass fp values in `constants`. Each
dict value `v` is transformed as follows.
@ -103,37 +112,43 @@ def jax_to_hlo(name, deps, fn, input_shapes, constants = None):
will be in the order specified here.
constants: Python dictionary mapping arg names to constant values they
should take on, e.g. {"z": "float(3.14159)"}.
format: Either HLO or TF.
"""
if not constants:
constants = {}
# Our goal here is to create a py_binary which depends on `deps` and
# invokes the main function defined in jax_to_hlo.py.
# invokes the main function defined in jax_to_ir.py.
#
# At first blush it seems that we can do this in a straightforward way:
#
# native.py_binary(main = "jax_to_hlo.py").
# native.py_binary(main = "jax_to_ir.py").
#
# The problem with this is that this py_binary lives in the user's package,
# whereas jax_to_hlo.py lives inside JAX's package. Bazel delivers a stern
# whereas jax_to_ir.py lives inside JAX's package. Bazel delivers a stern
# warning if you name a file in `main` that's outside of your package.
#
# To avoid the warning, we generate a simple "main file" in the user's
# package. It only complicates the rules a bit.
runner = name + "_jax_to_hlo_main"
runner = name + "_jax_to_ir_main"
native.genrule(
name = runner + "_gen",
outs = [runner + ".py"],
cmd = """cat <<EOF > '$(location {runner}.py)'
from absl import app
import jax.tools.jax_to_hlo as jax_to_hlo
import jax.tools.jax_to_ir as jax_to_ir
jax_to_hlo.set_up_flags()
app.run(jax_to_hlo.main)
jax_to_ir.set_up_flags()
app.run(jax_to_ir.main)
EOF
""".format(runner = runner),
)
if format == "TF":
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow"
else:
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir"
native.py_binary(
name = runner,
srcs = [
@ -142,7 +157,7 @@ EOF
python_version = "PY3",
deps = deps + [
"//third_party/py/jax/jaxlib",
"//third_party/py/jax/tools:jax_to_hlo",
jax_to_ir_rule,
],
)
@ -151,7 +166,7 @@ EOF
# Set JAX_PLATFORM_NAME to "cpu" to silence the "no GPU/TPU backend found,
# falling back to CPU" warning.
native.genrule(
name = name + "_jax_to_hlo_genrule",
name = name + "_jax_to_ir_genrule",
outs = [name + ".pb", name + ".txt"],
exec_tools = [runner],
cmd = """
@ -160,13 +175,15 @@ EOF
--fn {fn} \
--input_shapes {input_shapes} \
--evaled_constants {constants} \
--hlo_proto_dest '$(location {name}.pb)' \
--hlo_text_dest '$(location {name}.txt)' \
--ir_format {format} \
--ir_dest '$(location {name}.pb)' \
--ir_human_dest '$(location {name}.txt)' \
""".format(
name = name,
fn = fn,
input_shapes = _shell_quote(str(input_shapes)),
constants = _shell_quote(str(constants)),
runner = runner,
format = _shell_quote(format),
),
)

View File

@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tool to convert a JAX function to an HLO proto.
r"""Tool to convert a JAX function to serialized representations.
This script is meant to be used as part of a genrule that converts a JAX program
into an HLO proto. The HLO proto represents an XLA program, and can be run from
e.g. a C++ program, without involving any Python.
into an IR that can be consumed by another system (e.g. a compiler).
Convert to HLO
==============
For example, you can generate an HLO proto for the XLA compiler. The HLO proto
represents an XLA program, and can be run from e.g. a C++ program, without
involving any Python.
This lets you use JAX as a convenient frontend for writing "XLA programs". From
another perspective, this script lets you make JAX into an ahead-of-time JAX ->
@ -33,12 +39,13 @@ Usage:
def fn(x, y, z):
return jnp.dot(x, y) / z
$ python jax_to_hlo.py \
$ python jax_to_ir.py \
--fn prog.fn \
--input_shapes '[("y", "f32[128,32]"), ("x", "f32[8,128]")]' \
--constants '{"z": 3.14159}' \
--hlo_text_dest /tmp/fn_hlo.txt \
--hlo_proto_dest /tmp/fn_hlo.pb
--ir_format HLO \
--ir_human_dest /tmp/fn_hlo.txt \
--ir_dest /tmp/fn_hlo.pb
Alternatively, you can use this script via a genrule. This way bazel will
generate the hlo text/proto as part of compilation, and then e.g. a C++ program
@ -58,38 +65,45 @@ Note that XLA's backwards-compatibility guarantees for saved HLO are currently
and the XLA team won't (and in fact will be unable to) help. One way to be sure
it won't break is to use the same version of XLA to build the HLO as you use to
run it. The genrule above makes this easy.
Implementation note: This script must be python2 compatible for now, because
Google's genrules still run with python2, b/66712815.
"""
from ast import literal_eval
import importlib
import functools
import re
from absl import app
from absl import flags
import jax
import jax.numpy as jnp
from jax._src.lib import xla_client
try:
from jax.experimental import jax2tf
except ImportError:
jax2tf = None # type: ignore[assignment]
try:
import tensorflow as tf
except ImportError:
tf = None # type: ignore
FLAGS = flags.FLAGS
def jax_to_hlo(fn, input_shapes, constants=None):
"""Converts a JAX function to an HLO module.
def jax_to_ir(fn, input_shapes, *, constants=None, format):
"""Converts a JAX function to a serialized ir and a debug txt dump.
Args:
fn: Function to convert.
input_shapes: List of tuples (arg name, xla_client.Shape),
input_shapes: List of tuples (arg name, jax.ShapedArray),
indicating the shapes of the arguments to fn. The order of parameters in
the resulting XLA program will match the order in this list.
constants: Dict mapping function argument name to a Python value. Specified
arguments are bound to these values as compile-time constants.
format: Which IR format to use. Supported values are 'HLO' and 'TF'.
Returns:
A tuple (serialized_hlo_proto, hlo_text).
A tuple of (compiler_suitable_ir, human_readable_ir).
"""
if not constants:
constants = {}
@ -101,25 +115,8 @@ def jax_to_hlo(fn, input_shapes, constants=None):
'Arguments appear in both `input_shapes` and `constants`: %s' %
', '.join(sorted(overlapping_args)))
args = []
for arg_name, shape in input_shapes:
if not shape.is_array():
raise ValueError('Shape %s is not an array, but currently only arrays '
'are supported (i.e., no tuples, nor tokens).' % str(shape))
# Check that `shape` either doesn't have a layout or has the default layout.
#
# TODO(jlebar): This could be simpler if the Shape class exposed its layout,
# or if Shape exposed a function to unconditionally use the default layout.
shape_with_default_layout = xla_client.Shape.array_shape(
shape.xla_element_type(),
shape.dimensions()).with_major_to_minor_layout_if_absent()
if (shape.with_major_to_minor_layout_if_absent() !=
shape_with_default_layout):
raise ValueError('Shape %s has a non-default layout, but only '
'the default layout is allowed.' % str(shape))
args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))
# TODO(tomhennigan): Ideally we could avoid creating actual values here.
args = [jnp.zeros(s.shape, s.dtype) for _, s in input_shapes]
# Curry `constants` into the function.
fn_curried = functools.partial(fn, **constants)
@ -130,23 +127,51 @@ def jax_to_hlo(fn, input_shapes, constants=None):
arg_names = [arg_name for arg_name, _ in input_shapes]
return fn_curried(**dict(zip(arg_names, args)))
comp = jax.xla_computation(ordered_wrapper)(*args)
return (comp.as_serialized_hlo_module_proto(), comp.as_hlo_text())
if format == 'HLO':
comp = jax.xla_computation(ordered_wrapper)(*args)
serialized_proto = comp.as_serialized_hlo_module_proto()
debug_txt = comp.as_hlo_text()
else:
assert format == 'TF'
if tf is None:
raise ValueError(
'Conversion to TF graph requires TensorFlow to be installed.')
f = jax2tf.convert(ordered_wrapper)
f = tf_wrap_with_input_names(f, input_shapes)
f = tf.function(f, autograph=False)
g = f.get_concrete_function(*args).graph.as_graph_def()
serialized_proto = g.SerializeToString()
debug_txt = str(g)
return serialized_proto, debug_txt
def tf_wrap_with_input_names(f, input_shapes):
  """Wraps `f` so each positional argument passes through a named identity op.

  Args:
    f: Callable to wrap.
    input_shapes: Sequence of (arg name, shape) pairs. Only the names are used;
      the i-th name is attached to the i-th positional argument.

  Returns:
    A callable that routes each argument through `tf.identity(..., name=name)`
    and then calls `f` with the results. Arguments beyond the length of
    `input_shapes` are dropped, because the pairing is done with `zip`.
  """

  def wrapper(*args):
    named_args = []
    for arg, (arg_name, _) in zip(args, input_shapes):
      named_args.append(tf.identity(arg, name=arg_name))
    # NOTE: Output names already set via `jax2tf.convert(..)`.
    return f(*named_args)

  return wrapper
# Convenience entry points: `jax_to_ir` with the `format` keyword pre-bound.
jax_to_hlo = functools.partial(jax_to_ir, format='HLO')
jax_to_tf = functools.partial(jax_to_ir, format='TF')
def main(argv):
if len(argv) != 1:
raise app.UsageError('No positional arguments are accepted.')
if not FLAGS.hlo_proto_dest and not FLAGS.hlo_text_dest:
raise app.Error('At least one of --hlo_proto_dest and '
'--hlo_text_dest is required.')
if not FLAGS.ir_dest and not FLAGS.ir_human_dest:
raise app.Error('At least one of --ir_dest and '
'--ir_human_dest is required.')
module_name, fn_name = FLAGS.fn.rsplit('.', 1)
module = importlib.import_module(module_name)
fn = getattr(module, fn_name)
input_shapes = [(name, xla_client.Shape(shape_str))
input_shapes = [(name, parse_shape_str(shape_str))
for name, shape_str in literal_eval(FLAGS.input_shapes)]
# Parse --constants and --evaled_constants.
@ -166,15 +191,38 @@ def main(argv):
'Argument appears in both --constants and --evaled_constants: %s' % k)
constants[k] = v
hlo_proto, hlo_text = jax_to_hlo(fn, input_shapes, constants)
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
format=FLAGS.ir_format)
if FLAGS.hlo_proto_dest:
with open(FLAGS.hlo_proto_dest, 'wb') as f:
f.write(hlo_proto)
if FLAGS.ir_dest:
with open(FLAGS.ir_dest, 'wb') as f:
f.write(ir)
if FLAGS.ir_human_dest:
with open(FLAGS.ir_human_dest, 'w') as f:
f.write(debug_ir)
def parse_shape_str(s):
  """Parses a shape string such as "f32[1,2,3]" into a `jax.ShapedArray`.

  Args:
    s: A dtype name from `_DT` immediately followed by a bracketed,
      comma-separated list of non-negative integer dimensions. Whitespace is
      allowed around dimensions. Empty brackets (e.g. "f32[]") denote a scalar.

  Returns:
    A `jax.ShapedArray` with the parsed shape and dtype.

  Raises:
    ValueError: If `s` is not a well-formed shape string (unknown dtype,
      missing brackets, or a malformed dimension list).
  """
  match = _SHAPE_RE.match(s)
  if not match:
    raise ValueError(f'Invalid shape {s}. Valid example: "f32[1,2,3]". '
                     f'Note that dtype must be one of {list(_DT)}')
  dtype = _DT[match.group(1)]
  # Group 2 is None for a scalar; int() tolerates the surrounding whitespace
  # that the pattern allows around each dimension.
  dims = match.group(2)
  shape = tuple(int(d) for d in dims.split(',')) if dims else ()
  return jax.ShapedArray(shape, dtype)


# Maps the dtype name accepted in shape strings to the jax.numpy scalar type.
_DT = {'pred': jnp.bool_,
       'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
       's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
       'bf16': jnp.bfloat16,
       'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
       'c64': jnp.complex64, 'c128': jnp.complex128}

# Matches "<dtype>[d0, d1, ...]". Group 1 is the dtype name; group 2 is the
# comma-separated dimension list, or None when the brackets are empty.
_SHAPE_RE = re.compile(
    r'^(' + '|'.join(_DT) + r')\[\s*(\d+(?:\s*,\s*\d+)*)?\s*\]$')
if FLAGS.hlo_text_dest:
with open(FLAGS.hlo_text_dest, 'w') as f:
f.write(hlo_text)
def set_up_flags():
flags.DEFINE_string(
@ -188,8 +236,10 @@ def set_up_flags():
'Python dict giving constant values for some params. '
'Values in this dict that are of type str are evaluated '
'using ast.literal_eval.')
flags.DEFINE_string('hlo_proto_dest', None, 'File to write HLO proto')
flags.DEFINE_string('hlo_text_dest', None, 'File to write HLO text')
flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
flags.DEFINE_string('ir_dest', None, 'File to write IR to')
flags.DEFINE_string('ir_human_dest', None,
'File to write human readable debug output')
flags.mark_flag_as_required('fn')
flags.mark_flag_as_required('input_shapes')

View File

@ -1,76 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax._src.lib import xla_client
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from jax._src import test_util as jtu
class JaxToHloTest(absltest.TestCase):
def test_convert_axpy(self):
def axpy(a, x, y):
return a * x + y[:, jnp.newaxis]
hlo_proto, hlo_text = jax_to_hlo(
axpy, [
('y', xla_client.Shape('f32[128]')),
('a', xla_client.Shape('f32[]')),
('x', xla_client.Shape('f32[128,2]')),
])
# Check that hlo_text contains a broadcast, add, and multiply.
self.assertIn('broadcast', hlo_text)
self.assertIn('add', hlo_text)
self.assertIn('multiply', hlo_text)
# Check that the HLO parameters are in the order we specified in the
# jax_to_hlo call.
self.assertIn('f32[128]{0} parameter(0)', hlo_text)
self.assertIn('f32[] parameter(1)', hlo_text)
self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)
# Check that the parameters are in the expected order.
# TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
# valid HLO proto, but we don't seem to have access to hlo_pb2 at the
# moment, so the best we seem to be able to do is check that it's nonempty.
assert hlo_proto
def test_convert_with_constants(self):
def fn(a, b, x, y):
return a / b * x + y
_, hlo_text = jax_to_hlo(
fn,
input_shapes=[
('x', xla_client.Shape('f32[128]')),
('y', xla_client.Shape('f32[128]')),
],
constants={
'a': 123456,
'b': 4,
})
# Because we passed `a` and `b` as constants, they get constant-folded away
# by Python/JAX to a/b = 30864.
self.assertIn('constant(30864)', hlo_text)
self.assertNotIn('123456', hlo_text)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

144
tests/jax_to_ir_test.py Normal file
View File

@ -0,0 +1,144 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
import jax.numpy as jnp
from jax.tools import jax_to_ir
from jax._src import test_util as jtu
try:
import tensorflow as tf
except ImportError:
tf = None # type: ignore
def axpy(a, x, y):
  """Returns `a * x + y`, with `y` given a trailing axis before the add."""
  y_column = y[:, jnp.newaxis]
  scaled = a * x
  return scaled + y_column
class JaxToIRTest(absltest.TestCase):
  """Tests for `jax_to_ir.jax_to_hlo`, `jax_to_tf` and `parse_shape_str`."""

  def test_jax_to_hlo_axpy(self):
    hlo_proto, hlo_text = jax_to_ir.jax_to_hlo(axpy, [
        ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ('a', jax_to_ir.parse_shape_str('f32[]')),
        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
    ])

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are in the order we specified in the
    # jax_to_hlo call.
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # Check that the serialized proto is non-empty.
    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    # Use assertTrue rather than a bare assert so the check survives `-O`.
    self.assertTrue(hlo_proto)

  def test_jax_to_hlo_with_constants(self):

    def fn(a, b, x, y):
      return a / b * x + y

    _, hlo_text = jax_to_ir.jax_to_hlo(
        fn,
        input_shapes=[
            ('x', jax_to_ir.parse_shape_str('f32[128]')),
            ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ],
        constants={
            'a': 123456,
            'b': 4,
        })

    # Because we passed `a` and `b` as constants, they get constant-folded away
    # by Python/JAX to a/b = 30864.
    self.assertIn('constant(30864)', hlo_text)
    self.assertNotIn('123456', hlo_text)

  def test_parse_shape_str_invalid(self):
    with self.assertRaisesRegex(ValueError, 'Invalid shape.*foo'):
      jax_to_ir.parse_shape_str('foo[]')

  @unittest.skipIf(tf is None, 'TensorFlow not installed.')
  def test_jax_to_tf_axpy(self):
    tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
        ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ('a', jax_to_ir.parse_shape_str('f32[]')),
        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
    ])

    # Check that tf debug txt contains a broadcast, add, and multiply.
    self.assertIn('name: "BroadcastTo"', tf_text)
    self.assertIn('name: "AddV2"', tf_text)
    self.assertIn('name: "Mul"', tf_text)

    # Check that we can re-import our graphdef.
    gdef = tf.compat.v1.GraphDef()
    gdef.ParseFromString(tf_proto)
    g = tf.Graph()
    with g.as_default():
      tf.import_graph_def(gdef, name='')

    # Check that the graph inputs and output are named as we specified.
    ops = {o.name: o for o in g.get_operations()
           if o.name in ('y', 'a', 'x', 'jax2tf_out')}
    self.assertLen(ops, 4)
    self.assertIdentityOp(ops['y'], [128], jnp.float32)
    self.assertIdentityOp(ops['a'], [], jnp.float32)
    self.assertIdentityOp(ops['x'], [128, 2], jnp.float32)
    self.assertIdentityOp(ops['jax2tf_out'], [128, 2], jnp.float32)

  def assertIdentityOp(self, op, expected_shape, expected_dtype):
    """Asserts `op` is an Identity op whose single output has the given type."""
    self.assertEqual(op.type, 'Identity')
    output, = op.outputs
    self.assertEqual(output.shape, expected_shape)
    self.assertEqual(output.dtype, expected_dtype)

  def test_parse_shape_str(self):
    # (shape string, expected shape, expected dtype)
    cases = [
        ('f32[]', [], jnp.float32),
        ('f32[1,2,3]', [1, 2, 3], jnp.float32),
        ('pred[1]', [1], jnp.bool_),
        ('s8[1]', [1], jnp.int8),
        ('s16[1]', [1], jnp.int16),
        ('s32[1]', [1], jnp.int32),
        ('s64[1]', [1], jnp.int64),
        ('u8[1]', [1], jnp.uint8),
        ('u16[1]', [1], jnp.uint16),
        ('u32[1]', [1], jnp.uint32),
        ('u64[1]', [1], jnp.uint64),
        ('f16[1]', [1], jnp.float16),
        ('f32[1]', [1], jnp.float32),
        ('f64[1]', [1], jnp.float64),
        ('bf16[1]', [1], jnp.bfloat16),
        ('c64[1]', [1], jnp.complex64),
        ('c128[1]', [1], jnp.complex128),
    ]
    for shape_str, expected_shape, expected_dtype in cases:
      # subTest keeps going after a failure and reports which string failed.
      with self.subTest(shape_str=shape_str):
        self.assertParsedShape(shape_str, expected_shape, expected_dtype)

  def assertParsedShape(self, s: str, expected_shape, expected_dtype):
    """Asserts that parsing `s` yields the expected shape and dtype."""
    p = jax_to_ir.parse_shape_str(s)
    self.assertEqual(p.shape, tuple(expected_shape))
    self.assertEqual(p.dtype, expected_dtype)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())