Fix stale notebook.

The notebook fails with type errors, due to changes in XLA.
Use numpy.dtype for the xla_client.Shape.array_shape calls.
This commit is contained in:
George Necula 2019-07-09 12:14:01 +02:00
parent 498d501a35
commit 63fe05883d

View File

@ -57,7 +57,7 @@
"source": [
"## References \n",
"\n",
"__xla__ the doc that defines what's in HLO - but note that the doc is incomplete and omits some ops.\n",
"__xla__: the doc that defines what's in HLO - but note that the doc is incomplete and omits some ops.\n",
"\n",
"https://www.tensorflow.org/xla/operation_semantics\n",
"\n",
@ -65,13 +65,13 @@
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
"\n",
"__python xla client__ this is the XLA python client for JAX, and what we're using here.\n",
"__python xla client__: this is the XLA python client for JAX, and what we're using here.\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
"\n",
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
"\n",
"__jax__ you can see how jax interacts with the XLA compute layer for execution and JITing in these files.\n",
"__jax__: you can see how jax interacts with the XLA compute layer for execution and JITing in these files.\n",
"\n",
"https://github.com/google/jax/blob/master/jax/lax.py\n",
"\n",
@ -171,7 +171,7 @@
" str(onp.dtype('complex128')): onp.dtype('complex64'),\n",
" }\n",
" dtype = onp.dtype(dtype)\n",
" return str(_dtype_to_32bit_dtype.get(str(dtype), dtype))\n",
" return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
"\n",
"def shape_of(value):\n",
" \"\"\"Given a Python or XLA value, return its canonicalized XLA Shape.\"\"\"\n",
@ -246,7 +246,7 @@
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
"\n",
"# define a parameter shape and parameter\n",
"param_shape = xla_client.Shape.array_shape(onp.float32, ())\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), ())\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# define computation graph\n",
@ -309,7 +309,7 @@
"# same as above with vector type:\n",
"\n",
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
"param_shape = xla_client.Shape.array_shape(onp.float32, (3,))\n",
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), (3,))\n",
"x = c.ParameterWithShape(param_shape)\n",
"\n",
"# can also use this function to define a shape from an example:\n",
@ -380,20 +380,24 @@
],
"source": [
"# trivial while loop, decrement until 0\n",
"in_shape = shape_of(1)\n",
"# x = 5\n",
"# while x > 0:\n",
"# x = x - 1\n",
"#\n",
"in_shape = shape_of(5)\n",
"\n",
"# body computation:\n",
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
"x = bcb.ParameterWithShape(in_shape)\n",
"const = bcb.Constant(onp.int32(1))\n",
"y = bcb.Sub(x, const)\n",
"const1 = bcb.Constant(onp.int32(1))\n",
"y = bcb.Sub(x, const1)\n",
"body_computation = bcb.Build()\n",
"\n",
"# test computation:\n",
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
"x = tcb.ParameterWithShape(in_shape)\n",
"const = tcb.Constant(onp.int32(0))\n",
"y = tcb.Gt(x, const)\n",
"const0 = tcb.Constant(onp.int32(0))\n",
"y = tcb.Gt(x, const0)\n",
"test_computation = tcb.Build()\n",
"\n",
"# while computation:\n",
@ -588,8 +592,8 @@
")\n",
"# NB: in_shape is the same as the manually constructed:\n",
"# xla_client.Shape.tuple_shape(\n",
"# (xla_client.Shape.array_shape(onp.float32, matrix_shape), \n",
"# xla_client.Shape.array_shape(onp.int32, ()))\n",
"# (xla_client.Shape.array_shape(onp.dtype(onp.float32), matrix_shape), \n",
"# xla_client.Shape.array_shape(onp.dtype(onp.int32), ()))\n",
"# )\n",
"\n",
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",