update docs to remove stale reference to laziness optimization

This commit is contained in:
Matthew Johnson 2023-03-13 10:15:55 -07:00
parent 233911c001
commit 3eb9c7a6e7

View File

@ -372,10 +372,12 @@ device.
Jitted functions behave like any other primitive operations—they will follow the
data and will show errors if invoked on data committed on more than one device.
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the
array of zeros on ``jax.devices()[1]``, instead of creating the array on the default
device then moving it. This is thanks to some laziness in array creation, which holds
for all the constant creation operations (``ones``, ``full``, ``eye``, etc).
(Before `PR #6002 <https://github.com/google/jax/pull/6002>`_ in March 2021
there was some laziness in creation of array constants, so that
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually
create the array of zeros on ``jax.devices()[1]``, instead of creating the
array on the default device then moving it. But this optimization was removed
so as to simplify the implementation.)
(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device
placement. That parameter is experimental, is likely to be removed or changed,