This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.
Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment
if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed the LaztTest.test_zeros_ones_compilation
stated to fail because the `sticky` parameter to lazy_force_computation
was changing. Fixed this by removing stickyness from the compilation key.
* Expanded docstring for jax.device_put; expanded the
device placement FAQ entry.
* Fixed a few places where device sitckyness was lost. Added FAQ for device
placement.
I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorted than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.
* Fix mypy annotations and updates based on comments
* Undid some changes, will make another PR
* Added clearer error message for tracers in numpy.split
Now we print:
ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>
* Fixed tests, slight change to the error message
* Expanded the FAQ entry about abstract tracers for higher-order primitives
* Added clarification for tracers inside jit of grad
* Updated FAQ language in response to reviews