# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy.random as npr

import jax.numpy as jnp
from jax.example_libraries import optimizers
from jax import grad, jit, make_jaxpr, vmap, lax


def gram(kernel, xs):
  '''Compute a Gram matrix from a kernel and an array of data points.

  Args:
    kernel: callable, maps pairs of data points to scalars.
    xs: array of data points, stacked along the leading dimension.

  Returns:
    A 2d array `a` such that `a[i, j] = kernel(xs[i], xs[j])`.
  '''
  return vmap(lambda x: vmap(lambda y: kernel(x, y))(xs))(xs)
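
# Usage sketch (added; values are illustrative): with a plain dot-product
# kernel, gram is just the matrix of pairwise inner products, i.e. xs @ xs.T:
#   k = lambda x, y: jnp.dot(x, y)
#   a = gram(k, jnp.ones((3, 2)))  # every entry a[i, j] == 2.0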


def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  '''Minimize f from initial point x by gradient descent with momentum.'''
  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    x = get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in range(num_steps):
    opt_state = update(i, opt_state)
  return get_params(opt_state)
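
# Usage sketch (added; the hyperparameters are illustrative, not tuned): on a
# convex quadratic the iterates should approach the minimum at 3.0:
#   x_min = minimize(lambda v: jnp.sum((v - 3.) ** 2), jnp.zeros(3),
#                    num_steps=1000, step_size=0.1)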


def train(kernel, xs, ys, regularization=0.01):
  '''Fit ridge-regularized kernel least squares and return a predictor.'''
  gram_ = jit(partial(gram, kernel))
  gram_mat = gram_(xs)
  n = xs.shape[0]

  # Ridge-regularized least-squares objective over the dual coefficients v.
  def objective(v):
    risk = .5 * jnp.sum((jnp.dot(gram_mat, v) - ys) ** 2.0)
    reg = regularization * jnp.sum(v ** 2.0)
    return risk + reg

  v = minimize(objective, jnp.zeros(n))

  # A prediction is a v-weighted sum of kernel evaluations against the
  # training points.
  def predict(x):
    prods = vmap(lambda x_: kernel(x, x_))(xs)
    return jnp.sum(v * prods)

  return jit(vmap(predict))
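
# Note (added): this dual ridge objective also has a closed form one could
# check against, since grad(objective)(v) == K @ (K @ v - ys) + 2*reg*v for
# the symmetric Gram matrix K = gram(kernel, xs):
#   v = jnp.linalg.solve(K @ K + 2 * regularization * jnp.eye(len(xs)), K @ ys)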


if __name__ == "__main__":
  n = 100
  d = 20

  # Linear kernel on synthetic data generated by an exactly linear map.
  linear_kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
  truth = npr.randn(d)
  xs = npr.randn(n, d)
  ys = jnp.dot(xs, truth)

  predict = train(linear_kernel, xs, ys)

  # Sum of squared errors on the training set (the data are noiseless, so
  # this should be small).
  print('SSE:', jnp.sum((predict(xs) - ys) ** 2.))
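
  # Added sanity check (not in the original script): solve the same ridge
  # objective in closed form and report its training error for comparison;
  # 0.01 mirrors the default `regularization` in train above.
  K = gram(linear_kernel, xs)
  v_closed = jnp.linalg.solve(K @ K + 2 * 0.01 * jnp.eye(n), K @ ys)
  print('closed-form SSE:', jnp.sum((jnp.dot(K, v_closed) - ys) ** 2.))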

  # Print the jaxprs that tracing gram builds for two different kernels.
  def gram_jaxpr(kernel):
    return make_jaxpr(partial(gram, kernel))(xs)

  rbf_kernel = lambda x, y: jnp.exp(-jnp.sum((x - y) ** 2))

  print()
  print('jaxpr of gram(linear_kernel):')
  print(gram_jaxpr(linear_kernel))
  print()
  print('jaxpr of gram(rbf_kernel):')
  print(gram_jaxpr(rbf_kernel))
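
  # The RBF kernel trains the same way (added sketch, not run here):
  #   rbf_predict = train(rbf_kernel, xs, ys)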