mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix stale notebook.
The notebook fails with type errors, due to changes in XLA. Use numpy.dtype for the xla_client.Shape.array_shape calls.
This commit is contained in:
parent
498d501a35
commit
63fe05883d
@ -57,7 +57,7 @@
|
||||
"source": [
|
||||
"## References \n",
|
||||
"\n",
|
||||
"__xla__ the doc that defines what's in HLO - but note that the doc is incomplete and omits some ops.\n",
|
||||
"__xla__: the doc that defines what's in HLO - but note that the doc is incomplete and omits some ops.\n",
|
||||
"\n",
|
||||
"https://www.tensorflow.org/xla/operation_semantics\n",
|
||||
"\n",
|
||||
@ -65,13 +65,13 @@
|
||||
"\n",
|
||||
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
|
||||
"\n",
|
||||
"__python xla client__ this is the XLA python client for JAX, and what we're using here.\n",
|
||||
"__python xla client__: this is the XLA python client for JAX, and what we're using here.\n",
|
||||
"\n",
|
||||
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n",
|
||||
"\n",
|
||||
"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n",
|
||||
"\n",
|
||||
"__jax__ you can see how jax interacts with the XLA compute layer for execution and JITing in these files.\n",
|
||||
"__jax__: you can see how jax interacts with the XLA compute layer for execution and JITing in these files.\n",
|
||||
"\n",
|
||||
"https://github.com/google/jax/blob/master/jax/lax.py\n",
|
||||
"\n",
|
||||
@ -171,7 +171,7 @@
|
||||
" str(onp.dtype('complex128')): onp.dtype('complex64'),\n",
|
||||
" }\n",
|
||||
" dtype = onp.dtype(dtype)\n",
|
||||
" return str(_dtype_to_32bit_dtype.get(str(dtype), dtype))\n",
|
||||
" return _dtype_to_32bit_dtype.get(str(dtype), dtype)\n",
|
||||
"\n",
|
||||
"def shape_of(value):\n",
|
||||
" \"\"\"Given a Python or XLA value, return its canonicalized XLA Shape.\"\"\"\n",
|
||||
@ -246,7 +246,7 @@
|
||||
"c = xla_client.ComputationBuilder(\"simple_scalar\")\n",
|
||||
"\n",
|
||||
"# define a parameter shape and parameter\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.float32, ())\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), ())\n",
|
||||
"x = c.ParameterWithShape(param_shape)\n",
|
||||
"\n",
|
||||
"# define computation graph\n",
|
||||
@ -309,7 +309,7 @@
|
||||
"# same as above with vector type:\n",
|
||||
"\n",
|
||||
"c = xla_client.ComputationBuilder(\"simple_vector\")\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.float32, (3,))\n",
|
||||
"param_shape = xla_client.Shape.array_shape(onp.dtype(onp.float32), (3,))\n",
|
||||
"x = c.ParameterWithShape(param_shape)\n",
|
||||
"\n",
|
||||
"# can also use this function to define a shape from an example:\n",
|
||||
@ -380,20 +380,24 @@
|
||||
],
|
||||
"source": [
|
||||
"# trivial while loop, decrement until 0\n",
|
||||
"in_shape = shape_of(1)\n",
|
||||
"# x = 5\n",
|
||||
"# while x > 0:\n",
|
||||
"# x = x - 1\n",
|
||||
"#\n",
|
||||
"in_shape = shape_of(5)\n",
|
||||
"\n",
|
||||
"# body computation:\n",
|
||||
"bcb = xla_client.ComputationBuilder(\"bodycomp\")\n",
|
||||
"x = bcb.ParameterWithShape(in_shape)\n",
|
||||
"const = bcb.Constant(onp.int32(1))\n",
|
||||
"y = bcb.Sub(x, const)\n",
|
||||
"const1 = bcb.Constant(onp.int32(1))\n",
|
||||
"y = bcb.Sub(x, const1)\n",
|
||||
"body_computation = bcb.Build()\n",
|
||||
"\n",
|
||||
"# test computation:\n",
|
||||
"tcb = xla_client.ComputationBuilder(\"testcomp\")\n",
|
||||
"x = tcb.ParameterWithShape(in_shape)\n",
|
||||
"const = tcb.Constant(onp.int32(0))\n",
|
||||
"y = tcb.Gt(x, const)\n",
|
||||
"const0 = tcb.Constant(onp.int32(0))\n",
|
||||
"y = tcb.Gt(x, const0)\n",
|
||||
"test_computation = tcb.Build()\n",
|
||||
"\n",
|
||||
"# while computation:\n",
|
||||
@ -588,8 +592,8 @@
|
||||
")\n",
|
||||
"# NB: in_shape is the same as the manually constructed:\n",
|
||||
"# xla_client.Shape.tuple_shape(\n",
|
||||
"# (xla_client.Shape.array_shape(onp.float32, matrix_shape), \n",
|
||||
"# xla_client.Shape.array_shape(onp.int32, ()))\n",
|
||||
"# (xla_client.Shape.array_shape(onp.dtype(onp.float32), matrix_shape), \n",
|
||||
"# xla_client.Shape.array_shape(onp.dtype(onp.int32), ()))\n",
|
||||
"# )\n",
|
||||
"\n",
|
||||
"# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n",
|
||||
|
Loading…
x
Reference in New Issue
Block a user