{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "6umP1IKf4Dg6"
},
"source": [
"# Autobatching for Bayesian Inference\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [Open in Kaggle](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching with `jax.vmap` makes user code easier to write, easier to read, and less likely to include bugs.\n",
"\n",
"Inspired by a notebook by @davmre."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "8RZDkfbV3zdR"
},
"outputs": [],
"source": [
"import functools\n",
"import itertools\n",
"import re\n",
"import sys\n",
"import time\n",
"\n",
"from matplotlib.pyplot import *\n",
"\n",
"import jax\n",
"\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"import numpy as np\n",
"import scipy as sp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p2VcZS1d34C6"
},
"source": [
"## Generate a fake binary classification dataset"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "pq41hMvn4c_i"
},
"outputs": [],
"source": [
"np.random.seed(10009)\n",
"\n",
"num_features = 10\n",
"num_points = 100\n",
"\n",
"true_beta = np.random.randn(num_features).astype(jnp.float32)\n",
"all_x = np.random.randn(num_points, num_features).astype(jnp.float32)\n",
"y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "O0nVumAw7IlT",
"outputId": "751a3290-a81b-4538-9183-16cd685fbaf9"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,\n",
" 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n",
" 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,\n",
" 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,\n",
" 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)"
]
},
"execution_count": 11,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DZRVvhpn5aB1"
},
"source": [
"## Write the log-joint function for the model\n",
"\n",
"We'll write a non-batched version, a manually batched version, and an autobatched version."
]
},
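{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reading the first `log_joint` below, the model is Bayesian logistic regression with a standard normal prior on each of the $D$ coefficients. Writing $z_i = x_i^\\top \\beta$ and $s_i = 2 y_i - 1 \\in \\{-1, +1\\}$, it computes\n",
"\n",
"$$\n",
"\\log p(\\beta, y) = \\sum_{j=1}^{D} \\log \\mathcal{N}(\\beta_j; 0, 1) - \\sum_{i=1}^{N} \\log\\left(1 + e^{-s_i z_i}\\right).\n",
"$$\n",
"\n",
"The second sum is the Bernoulli log-likelihood $\\sum_i \\left[y_i \\log \\sigma(z_i) + (1 - y_i) \\log(1 - \\sigma(z_i))\\right]$ in compact form: since $1 - \\sigma(z) = \\sigma(-z)$ and $\\log \\sigma(u) = -\\log(1 + e^{-u})$, each term equals $\\log \\sigma(s_i z_i)$."
]
},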
{
"cell_type": "markdown",
"metadata": {
"id": "C_mDXInL7nsP"
},
"source": [
"### Non-batched"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "ZHyL2sJh5ajG"
},
"outputs": [],
"source": [
"def log_joint(beta):\n",
" result = 0.\n",
" # Note that no `axis` parameter is provided to `jnp.sum`.\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
" return result"
]
},
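{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check of the compact likelihood term, sketched on made-up logits and labels (the names `demo_logits`, `demo_labels`, and `demo_signs` are illustrative): the expression used in `log_joint` should agree with the textbook Bernoulli log-likelihood and with `jax.nn.log_sigmoid`, which computes the same quantity in a numerically stable way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Made-up logits z = x . beta and binary labels.\n",
"demo_logits = jnp.array([-3.0, -0.5, 0.0, 0.7, 4.0])\n",
"demo_labels = jnp.array([0, 1, 1, 0, 1])\n",
"demo_signs = 2 * demo_labels - 1  # maps {0, 1} to {-1, +1}\n",
"\n",
"# The compact form used in `log_joint`.\n",
"sign_trick = -jnp.log(1 + jnp.exp(-demo_signs * demo_logits))\n",
"# Textbook Bernoulli log-likelihood with p = sigmoid(z).\n",
"demo_p = jax.nn.sigmoid(demo_logits)\n",
"bernoulli = demo_labels * jnp.log(demo_p) + (1 - demo_labels) * jnp.log(1 - demo_p)\n",
"# Numerically stable equivalent.\n",
"stable = jax.nn.log_sigmoid(demo_signs * demo_logits)\n",
"\n",
"print(jnp.allclose(sign_trick, bernoulli, atol=1e-5), jnp.allclose(sign_trick, stable, atol=1e-5))"
]
},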
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "e51qW0ro6J7C",
"outputId": "2ec6bbbd-12ee-45bc-af76-5111c53e4d5a"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(-213.23558, dtype=float32)"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"log_joint(np.random.randn(num_features))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "fglQXK1Y6wnm",
"outputId": "2b934336-08ad-4776-9a58-aa575bf601eb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))\n"
]
}
],
"source": [
"# This doesn't work, because we didn't write `log_joint()` to handle batching.\n",
"try:\n",
" batch_size = 10\n",
" batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
" log_joint(np.random.randn(batch_size, num_features))\n",
"except ValueError as e:\n",
" print(\"Caught expected exception \" + str(e))"
]
},
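{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the obvious fallback is a Python loop that calls a non-batched function once per row and stacks the results. The sketch below uses a hypothetical stand-in, `toy_log_density`, with the same contract as `log_joint` (a length-`D` vector in, a scalar out). The loop is correct, but it issues one call per batch element; `jax.vmap` produces the same values from a single batched call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"# Hypothetical stand-in with the same shape contract as `log_joint`.\n",
"def toy_log_density(beta):\n",
"  return -0.5 * jnp.sum(beta ** 2)\n",
"\n",
"toy_betas = random.normal(random.PRNGKey(0), (10, 4))  # a batch of 10 vectors\n",
"\n",
"# Loop and stack: one call per row.\n",
"looped = jnp.stack([toy_log_density(b) for b in toy_betas])\n",
"# Autobatched: one call over the whole batch.\n",
"vmapped = jax.vmap(toy_log_density)(toy_betas)\n",
"\n",
"print(looped.shape, jnp.allclose(looped, vmapped, atol=1e-6))"
]
},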
{
"cell_type": "markdown",
"metadata": {
"id": "_lQ8MnKq7sLU"
},
"source": [
"### Manually batched"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "2g5-4bQE7gRA"
},
"outputs": [],
"source": [
"def batched_log_joint(beta):\n",
" result = 0.\n",
" # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis\n",
" # or setting it incorrectly yields an error; at worst, it silently changes the\n",
" # semantics of the model.\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),\n",
" axis=-1)\n",
" # Note the multiple transposes. Getting this right is not rocket science,\n",
" # but it's also not totally mindless. (I didn't get it right on the first\n",
" # try.)\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),\n",
" axis=-1)\n",
" return result"
]
},
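{
"cell_type": "markdown",
"metadata": {},
"source": [
"The transposes in `batched_log_joint` produce a `(batch, N)` array whose row `b` holds the logits `all_x @ beta[b]`. A sketch with random stand-in arrays (shapes chosen to match this notebook's `all_x` and a batch of 5) shows three equivalent spellings; the `jnp.einsum` version names every axis, which makes a misplaced transpose easier to spot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"key_x, key_beta = random.split(random.PRNGKey(0))\n",
"demo_x = random.normal(key_x, (100, 10))      # (N, D), stand-in for `all_x`\n",
"demo_beta = random.normal(key_beta, (5, 10))  # (batch, D), a batch of betas\n",
"\n",
"via_transposes = jnp.dot(demo_x, demo_beta.T).T          # as in `batched_log_joint`\n",
"via_matmul = demo_beta @ demo_x.T                        # a single transpose\n",
"via_einsum = jnp.einsum('nd,bd->bn', demo_x, demo_beta)  # axes named explicitly\n",
"\n",
"print(via_transposes.shape)"
]
},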
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "KdDMr-Gy85CO",
"outputId": "db746654-68e9-43b8-ce3b-6e5682e22eb5"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"batch_size = 10\n",
"batched_test_beta = np.random.randn(batch_size, num_features)\n",
"\n",
"batched_log_joint(batched_test_beta)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-uuGlHQ_85kd"
},
"source": [
"### Autobatched with vmap\n",
"\n",
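"It just works: `jax.vmap` maps the non-batched `log_joint` over the leading axis of `batched_test_beta`, with no `axis` arguments or transposes to get wrong, and returns the same values as `batched_log_joint`."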
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "SU20bouH8-Za",
"outputId": "ee450298-982f-4b9a-bed9-a6f9b8f63d92"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
]
},
"execution_count": 17,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"vmap_batched_log_joint = jax.vmap(log_joint)\n",
"vmap_batched_log_joint(batched_test_beta)"
]
},
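{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jax.vmap` also accepts an `in_axes` argument that says which inputs carry a batch axis. As a sketch, assume a hypothetical variant `log_joint_with_data(beta, x, labels)` that takes the data as arguments instead of reading the globals `all_x` and `y` (its likelihood uses `jax.nn.log_sigmoid`, which equals the `-jnp.log(1 + jnp.exp(...))` term above). Passing `in_axes=(0, None, None)` maps over the rows of `beta` while every batch element sees the same data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"# Hypothetical variant of `log_joint` that takes the data explicitly.\n",
"def log_joint_with_data(beta, x, labels):\n",
"  prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))\n",
"  likelihood = jnp.sum(jax.nn.log_sigmoid((2 * labels - 1) * jnp.dot(x, beta)))\n",
"  return prior + likelihood\n",
"\n",
"k1, k2, k3 = random.split(random.PRNGKey(0), 3)\n",
"demo_x = random.normal(k1, (100, 10))\n",
"demo_labels = random.bernoulli(k2, 0.5, (100,)).astype(jnp.int32)\n",
"demo_betas = random.normal(k3, (8, 10))\n",
"\n",
"# Map over axis 0 of `beta`; pass `x` and `labels` through unbatched.\n",
"batched_with_data = jax.vmap(log_joint_with_data, in_axes=(0, None, None))\n",
"values = batched_with_data(demo_betas, demo_x, demo_labels)\n",
"\n",
"# Reference: one call per row.\n",
"reference = jnp.stack([log_joint_with_data(b, demo_x, demo_labels) for b in demo_betas])\n",
"print(values.shape, jnp.allclose(values, reference, rtol=1e-4))"
]
},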
{
"cell_type": "markdown",
"metadata": {
"id": "L1KNBo9y_yZJ"
},
"source": [
"## Self-contained variational inference example\n",
"\n",
"A little code is copied from above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lQTPaaQMJh8Y"
},
"source": [
"### Set up the (batched) log-joint function"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "AITXbaofA3Pm"
},
"outputs": [],
"source": [
"@jax.jit\n",
"def log_joint(beta):\n",
" result = 0.\n",
" # Note that no `axis` parameter is provided to `jnp.sum`.\n",
" result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
" result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))\n",
" return result\n",
"\n",
"batched_log_joint = jax.jit(jax.vmap(log_joint))"
]
},
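{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transformations compose beyond `jax.jit(jax.vmap(...))`. For example, `jax.vmap(jax.grad(f))` returns one gradient per batch element. The sketch below uses a hypothetical stand-in, `prior_log_density`, containing only the $\\mathcal{N}(0, 10^2)$ prior term from this section, so the result can be checked against the closed form $\\nabla_\\beta \\log p(\\beta) = -\\beta / 100$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"# Hypothetical stand-in: only the N(0, 10^2) prior term of the log joint.\n",
"def prior_log_density(beta):\n",
"  return jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))\n",
"\n",
"demo_betas = random.normal(random.PRNGKey(0), (4, 10))\n",
"\n",
"# One gradient vector per row of `demo_betas`, from a single jitted call.\n",
"per_example_grads = jax.jit(jax.vmap(jax.grad(prior_log_density)))(demo_betas)\n",
"\n",
"# For a N(0, s^2) density, the gradient of log p(beta) is -beta / s^2.\n",
"print(per_example_grads.shape, jnp.allclose(per_example_grads, -demo_betas / 100., atol=1e-6))"
]
},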
{
"cell_type": "markdown",
"metadata": {
"id": "UmmFMQ8LJk6a"
},
"source": [
"### Define the ELBO and its gradient"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "MJtnskL6BKwV"
},
"outputs": [],
"source": [
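"def elbo(beta_loc, beta_log_scale, epsilon):\n",
"  # Reparameterization trick: beta ~ N(beta_loc, exp(beta_log_scale)**2).\n",
"  beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
"  # Monte Carlo estimate of E_q[log p(beta, y)], plus the entropy of q up to an\n",
"  # additive constant (the exact entropy adds 0.5 * (1 + log(2*pi)) per coordinate\n",
"  # rather than subtracting 0.5 * log(2*pi)). Constants do not affect the gradient.\n",
"  return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
"\n",
"elbo = jax.jit(elbo)\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"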
]
},
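{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second term in `elbo` stands in for the entropy of the diagonal Gaussian $q(\\beta)$ with means `beta_loc` and standard deviations $\\sigma_j = e^{\\ell_j}$, where $\\ell$ is `beta_log_scale`. The exact entropy is $\\sum_j \\left[\\log \\sigma_j + \\tfrac{1}{2}(1 + \\log 2\\pi)\\right]$, whereas `elbo` adds $-\\tfrac{1}{2}\\log 2\\pi$ per coordinate; the two differ by a constant, so the printed ELBO values are shifted but the gradients are not. A sketch with made-up parameter values checks both points (the helpers `gaussian_entropy` and `entropy_as_in_elbo` are illustrative)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"def gaussian_entropy(log_scale):\n",
"  # Exact entropy of a diagonal Gaussian with standard deviations exp(log_scale).\n",
"  return jnp.sum(log_scale + 0.5 * (1. + jnp.log(2 * jnp.pi)))\n",
"\n",
"demo_loc = jnp.array([0.5, -1.0, 2.0])\n",
"demo_log_scale = jnp.array([0.0, -0.5, 0.3])\n",
"\n",
"# Monte Carlo estimate of the entropy, -E_q[log q(beta)].\n",
"demo_eps = random.normal(random.PRNGKey(0), (200_000, 3))\n",
"demo_samples = demo_loc + jnp.exp(demo_log_scale) * demo_eps\n",
"log_q = jnp.sum(jsp.stats.norm.logpdf(demo_samples, loc=demo_loc, scale=jnp.exp(demo_log_scale)), axis=-1)\n",
"mc_entropy = -jnp.mean(log_q)\n",
"\n",
"# The entropy term as written in `elbo`: same gradient, different constant.\n",
"def entropy_as_in_elbo(log_scale):\n",
"  return jnp.sum(log_scale - 0.5 * jnp.log(2 * jnp.pi))\n",
"\n",
"grad_exact = jax.grad(gaussian_entropy)(demo_log_scale)\n",
"grad_as_in_elbo = jax.grad(entropy_as_in_elbo)(demo_log_scale)\n",
"\n",
"print(gaussian_entropy(demo_log_scale), mc_entropy)\n",
"print(grad_exact, grad_as_in_elbo)"
]
},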
{
"cell_type": "markdown",
"metadata": {
"id": "oQC7xKYnJrp5"
},
"source": [
"### Optimize the ELBO using stochastic gradient ascent"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "9JrD5nNgH715",
"outputId": "80bf62d8-821a-45c4-885c-528b2e449e97"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\t-180.85391235351562\n",
"10\t-113.06047058105469\n",
"20\t-102.73725891113281\n",
"30\t-99.78732299804688\n",
"40\t-98.90898895263672\n",
"50\t-98.29743957519531\n",
"60\t-98.18630981445312\n",
"70\t-97.5797348022461\n",
"80\t-97.28599548339844\n",
"90\t-97.46998596191406\n",
"100\t-97.47715759277344\n",
"110\t-97.5806884765625\n",
"120\t-97.49433898925781\n",
"130\t-97.50270080566406\n",
"140\t-96.86398315429688\n",
"150\t-97.44197082519531\n",
"160\t-97.06938934326172\n",
"170\t-96.84031677246094\n",
"180\t-97.21339416503906\n",
"190\t-97.56500244140625\n",
"200\t-97.26395416259766\n",
"210\t-97.11984252929688\n",
"220\t-97.39595794677734\n",
"230\t-97.16830444335938\n",
"240\t-97.11840057373047\n",
"250\t-97.24346160888672\n",
"260\t-97.29786682128906\n",
"270\t-96.69286346435547\n",
"280\t-96.96443176269531\n",
"290\t-97.3005599975586\n",
"300\t-96.63589477539062\n",
"310\t-97.0351791381836\n",
"320\t-97.52906799316406\n",
"330\t-97.2880630493164\n",
"340\t-97.07324981689453\n",
"350\t-97.15620422363281\n",
"360\t-97.25880432128906\n",
"370\t-97.19515228271484\n",
"380\t-97.13092803955078\n",
"390\t-97.11730194091797\n",
"400\t-96.93872833251953\n",
"410\t-97.26676940917969\n",
"420\t-97.35321044921875\n",
"430\t-97.2100830078125\n",
"440\t-97.28434753417969\n",
"450\t-97.16310119628906\n",
"460\t-97.26123809814453\n",
"470\t-97.21342468261719\n",
"480\t-97.23995971679688\n",
"490\t-97.1491470336914\n",
"500\t-97.23527526855469\n",
"510\t-96.93415832519531\n",
"520\t-97.21208190917969\n",
"530\t-96.82577514648438\n",
"540\t-97.01283264160156\n",
"550\t-96.9417724609375\n",
"560\t-97.16526794433594\n",
"570\t-97.29165649414062\n",
"580\t-97.42940521240234\n",
"590\t-97.24371337890625\n",
"600\t-97.15219116210938\n",
"610\t-97.4984359741211\n",
"620\t-96.99072265625\n",
"630\t-96.88955688476562\n",
"640\t-96.89968872070312\n",
"650\t-97.13794708251953\n",
"660\t-97.43705749511719\n",
"670\t-96.99232482910156\n",
"680\t-97.15624237060547\n",
"690\t-97.1869125366211\n",
"700\t-97.1115951538086\n",
"710\t-97.78104400634766\n",
"720\t-97.23224639892578\n",
"730\t-97.16204071044922\n",
"740\t-96.99580383300781\n",
"750\t-96.66720581054688\n",
"760\t-97.16795349121094\n",
"770\t-97.51432037353516\n",
"780\t-97.28899383544922\n",
"790\t-96.91226959228516\n",
"800\t-97.17100524902344\n",
"810\t-97.29046630859375\n",
"820\t-97.16242980957031\n",
"830\t-97.19109344482422\n",
"840\t-97.5638427734375\n",
"850\t-97.00192260742188\n",
"860\t-96.86555480957031\n",
"870\t-96.76338195800781\n",
"880\t-96.83660125732422\n",
"890\t-97.121826171875\n",
"900\t-97.09553527832031\n",
"910\t-97.06825256347656\n",
"920\t-97.1194839477539\n",
"930\t-96.87931823730469\n",
"940\t-97.45622253417969\n",
"950\t-96.69277954101562\n",
"960\t-97.29376983642578\n",
"970\t-97.33528137207031\n",
"980\t-97.349609375\n",
"990\t-97.09675598144531\n"
]
}
],
"source": [
"def normal_sample(key, shape):\n",
" \"\"\"Convenience function for quasi-stateful RNG.\"\"\"\n",
" new_key, sub_key = random.split(key)\n",
" return new_key, random.normal(sub_key, shape)\n",
"\n",
"normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n",
"\n",
"key = random.key(10003)\n",
"\n",
"beta_loc = jnp.zeros(num_features, jnp.float32)\n",
"beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",
"\n",
"step_size = 0.01\n",
"batch_size = 128\n",
"epsilon_shape = (batch_size, num_features)\n",
"for i in range(1000):\n",
" key, epsilon = normal_sample(key, epsilon_shape)\n",
" elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(\n",
" beta_loc, beta_log_scale, epsilon)\n",
" # Gradient ascent: the ELBO is maximized, so step along its gradient.\n",
" beta_loc += step_size * beta_loc_grad\n",
" beta_log_scale += step_size * beta_log_scale_grad\n",
" if i % 10 == 0:\n",
" print('{}\\t{}'.format(i, elbo_val))"
]
},
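{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above dispatches the RNG step, the ELBO gradient, and the parameter updates as separate operations. One alternative, sketched here on a hypothetical toy target (independent $\\mathcal{N}(2, 0.5^2)$ coordinates, so the optimal $q$ is known exactly), folds all three into a single jitted step. The constants (batch size 256, step size 0.05, 500 steps) are assumptions chosen for the toy problem, not tuned values from this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
"\n",
"TOY_BATCH_SIZE = 256\n",
"TOY_STEP_SIZE = 0.05\n",
"\n",
"# Hypothetical toy target: independent N(2, 0.5^2) coordinates.\n",
"def toy_target_log_joint(beta):\n",
"  return jnp.sum(jsp.stats.norm.logpdf(beta, loc=2., scale=0.5))\n",
"\n",
"def toy_elbo(loc, log_scale, epsilon):\n",
"  beta = loc + jnp.exp(log_scale) * epsilon  # reparameterized samples\n",
"  entropy = jnp.sum(log_scale + 0.5 * (1. + jnp.log(2 * jnp.pi)))\n",
"  return jnp.mean(jax.vmap(toy_target_log_joint)(beta)) + entropy\n",
"\n",
"@jax.jit\n",
"def toy_ascent_step(key, loc, log_scale):\n",
"  key, sub_key = random.split(key)\n",
"  epsilon = random.normal(sub_key, (TOY_BATCH_SIZE, loc.shape[0]))\n",
"  value, (g_loc, g_log_scale) = jax.value_and_grad(toy_elbo, argnums=(0, 1))(\n",
"      loc, log_scale, epsilon)\n",
"  # Gradient ascent: move in the direction that increases the ELBO.\n",
"  return key, loc + TOY_STEP_SIZE * g_loc, log_scale + TOY_STEP_SIZE * g_log_scale, value\n",
"\n",
"toy_key = random.PRNGKey(0)\n",
"toy_loc = jnp.zeros(3)\n",
"toy_log_scale = jnp.zeros(3)\n",
"for _ in range(500):\n",
"  toy_key, toy_loc, toy_log_scale, toy_elbo_value = toy_ascent_step(toy_key, toy_loc, toy_log_scale)\n",
"\n",
"# The exact posterior here is N(2, 0.5^2) in every coordinate.\n",
"print(toy_loc, jnp.exp(toy_log_scale))"
]
},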
{
"cell_type": "markdown",
"metadata": {
"id": "b3ZAe5fJJ2KM"
},
"source": [
"### Display the results\n",
"\n",
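"The plot compares each true coefficient with its approximate posterior mean and $\\pm 2\\sigma$ error bars. Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact."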
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "zt1NBLoVHtOG",
"outputId": "fb159795-e6e7-497c-e501-9933ec761af4"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f90aed84860>"
]
},
"execution_count": 24,
"metadata": {
"tags": []
},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 504x504 with 1 Axes>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"figure(figsize=(7, 7))\n",
"plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
"plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
"plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
"plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
"xlabel('True beta')\n",
"ylabel('Estimated beta')\n",
"legend(loc='best')"
]
},
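{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put a number on the coverage seen in the plot, one option is to count how many true coefficients fall inside their approximate $\\pm 2\\sigma$ intervals (for a Gaussian, $\\pm 2\\sigma$ holds about 95% of the probability mass). The helper `interval_coverage` is a hypothetical addition, demonstrated here on made-up values; in this notebook it could be called as `interval_coverage(true_beta, beta_loc, beta_log_scale)`. With only 10 coefficients, the resulting fraction is a rough signal at best."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"def interval_coverage(true_values, loc, log_scale, num_sigmas=2.):\n",
"  # Fraction of true values inside loc +/- num_sigmas * exp(log_scale).\n",
"  half_width = num_sigmas * jnp.exp(log_scale)\n",
"  inside = jnp.abs(true_values - loc) <= half_width\n",
"  return jnp.mean(inside)\n",
"\n",
"# Made-up values: only the third true value lies outside its interval.\n",
"demo_true = jnp.array([0.5, -1.0, 2.0, 0.0])\n",
"demo_loc = jnp.array([0.4, -0.8, 1.0, 0.1])\n",
"demo_log_scale = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.1]))\n",
"\n",
"print(interval_coverage(demo_true, demo_loc, demo_log_scale))"
]
},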
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"id": "_bXdOlvUEJl0"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "vmapped_log_probs.ipynb",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}