# rocm_jax/examples/onnx2xla.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An ONNX to XLA compiler by JAX-tracing a Numpy-backed ONNX interpreter."""
from io import BytesIO
import hashlib
import urllib.request
import sys
import numpy as np
import onnx
from onnx import numpy_helper
import jax.numpy as jnp
from jax import jit, grad
from jax import lax


def _asarray(proto):
  return numpy_helper.to_array(proto).reshape(tuple(proto.dims))
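

# ONNX attribute protos are tagged unions: only the payload field named by
# `type` is populated. These handlers pull out that payload (converting
# tensor attributes to arrays) so node attributes can be passed to the op
# implementations below as plain Python keyword arguments.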
attr_types = dict(onnx.AttributeProto.AttributeType.items())
attribute_handlers = {
    attr_types['FLOAT']: lambda a: a.f,
    attr_types['INT']: lambda a: a.i,
    attr_types['STRING']: lambda a: a.s,
    attr_types['TENSOR']: lambda a: _asarray(a.t),
    attr_types['FLOATS']: lambda a: a.floats,
    attr_types['INTS']: lambda a: a.ints,
    attr_types['STRINGS']: lambda a: a.strings,
    attr_types['TENSORS']: lambda a: [_asarray(x) for x in a.tensors],
}
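

# MaxPool over the trailing spatial dims maps onto lax.reduce_window: a max
# reduction with -inf as the identity, with the window shape and strides
# padded by leading 1s so the batch and channel dims are left untouched.
# (The pads attribute is accepted but unused here; the window is applied
# with 'VALID' padding.)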
def onnx_maxpool(x, kernel_shape, pads=None, strides=None):
  """Numpy-backed implementation of ONNX MaxPool op."""
  prefix = (1,) * (x.ndim - len(kernel_shape))
  dims = prefix + tuple(kernel_shape)
  pads = tuple(pads) if pads else [0] * len(kernel_shape)
  # reduce_window needs one stride per dim of x, so default to x.ndim ones.
  strides = (prefix + tuple(strides)) if strides else [1] * len(x.shape)
  return [lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, 'VALID')]
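

# Conv assumes ONNX's NCHW/OIHW layouts, which match the defaults of
# lax.conv_with_general_padding. The auto_pad attribute arrives as bytes
# (see the STRING handler above), hence the b'SAME' prefix check.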
def onnx_conv(x, w, b=0, group=1, kernel_shape=None, pads=None, strides=None,
              dilations=None, auto_pad=None):
  """Numpy-backed implementation of ONNX Conv op."""
  assert group == 1
  kernel_shape = kernel_shape or w.shape
  strides = strides or [1] * (w.ndim - 2)
  if auto_pad:
    auto_pad = 'SAME' if auto_pad.startswith(b'SAME') else 'VALID'
    pads = lax.padtype_to_pads(x.shape[2:], w.shape[2:], strides, auto_pad)
  else:
    pads = pads or [0] * (w.ndim - 2)
  lhs_dilation = [1] * (w.ndim - 2)
  rhs_dilation = dilations or [1] * (w.ndim - 2)
  return [lax.conv_with_general_padding(x, w, strides, pads,
                                        lhs_dilation, rhs_dilation) + b]
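

# Older ONNX opsets express Add broadcasting with an explicit `axis`: b's
# shape is aligned with a's starting at that axis. Reshaping b with size-1
# dims elsewhere reduces this to ordinary NumPy-style broadcasting.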
def onnx_add(a, b, axis=None, broadcast=True):
  """Numpy-backed implementation of ONNX Add op."""
  if broadcast:
    axis = (a.ndim - b.ndim) if axis is None else axis % a.ndim
    assert a.shape[axis:][:b.ndim] == b.shape
    b_shape = np.ones(a.ndim, dtype='int64')
    b_shape[axis:axis + b.ndim] = b.shape
    b = jnp.reshape(b, b_shape)
  return [a + b]
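

# Dispatch table from ONNX op_type to handler. Every handler returns a list
# of outputs, since an ONNX node can produce more than one output tensor.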
onnx_ops = {
    'Add': onnx_add,
    'Constant': lambda value: [value],
    'Conv': onnx_conv,
    'MatMul': lambda x, y: [jnp.matmul(x, y)],
    'MaxPool': onnx_maxpool,
    'Relu': lambda x: [jnp.maximum(x, 0)],
    'Reshape': lambda x, shape: [jnp.reshape(x, shape)],
}
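
# Supporting another op is a one-line registration against the table above.
# As a hypothetical example (the MNIST model below doesn't use it), ONNX
# Sigmoid takes a single input and no attributes:
onnx_ops['Sigmoid'] = lambda x: [1 / (1 + jnp.exp(-x))]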


def interpret_onnx(graph, *args):
  # The environment maps tensor names to arrays, seeded with the graph
  # inputs and the initializers (i.e. the model weights).
  vals = dict({n.name: a for n, a in zip(graph.input, args)},
              **{n.name: _asarray(n) for n in graph.initializer})
  for node in graph.node:
    # Look up inputs by name, decode attributes, and run the op's handler.
    args = (vals[name] for name in node.input)
    attrs = {a.name: attribute_handlers[a.type](a) for a in node.attribute}
    outputs = onnx_ops[node.op_type](*args, **attrs)
    for name, output in zip(node.output, outputs):
      vals[name] = output
  return [vals[n.name] for n in graph.output]
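

# interpret_onnx only applies pure jax.numpy/lax functions to its arguments,
# so JAX can trace straight through it: jit and grad below treat the whole
# interpreter as just another traceable Python function.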
if __name__ == "__main__":
  # It seems that there are several ONNX proto versions (you had one job!) but
  # this implementation works with at least this one mnist example file.
  url = ('https://github.com/onnx/models/blob/'
         '81c4779096d1205edd0b809e191a924c58c38fef/'
         'mnist/model.onnx?raw=true')
  download = urllib.request.urlopen(url).read()
  if hashlib.md5(download).hexdigest() != 'bc8ad9bd19c5a058055dc18d0f089dad':
    print("onnx file checksum mismatch")
    sys.exit(1)
  model = onnx.load(BytesIO(download))

  predict = lambda inputs: interpret_onnx(model.graph, inputs)[0]

  # Run inference in Numpy-backed interpreter
  print("interpreted:")
  print(predict(jnp.ones((1, 1, 28, 28))))

  # JIT compile to XLA device, run inference on device
  compiled_predict = jit(predict)
  print("compiled:")
  print(compiled_predict(jnp.ones((1, 1, 28, 28))))

  # The interpreter is differentiable too! Even the compiled one:
  fun = lambda inputs: jnp.sum(compiled_predict(inputs))
  print("a derivative with respect to inputs:")
  print(grad(fun)(jnp.ones((1, 1, 28, 28)))[..., :3, :3])
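
  # One way to inspect what tracing produced (a sketch, not part of the
  # original example): dump the jaxpr JAX records for the interpreted model.
  from jax import make_jaxpr
  print("jaxpr of the interpreter:")
  print(make_jaxpr(predict)(jnp.ones((1, 1, 28, 28))))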