mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
document setting default device and array creation
This commit is contained in:
parent
5b3cbc5e18
commit
c4a08aff45
@ -67,7 +67,9 @@ and 2) whether it is **committed** to the device or not (the data is sometimes
|
||||
referred to as being *sticky* to the device).
|
||||
|
||||
By default, JAX arrays are placed uncommitted on the default device
|
||||
(``jax.devices()[0]``).
|
||||
(``jax.devices()[0]``), which is the first GPU by default. The default
|
||||
device can be set to "cpu" or "gpu" manually by setting the environment
|
||||
variable ``JAX_PLATFORM_NAME`` or the absl flag ``--jax_platform_name``.
|
||||
|
||||
>>> from jax import numpy as jnp
|
||||
>>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP
|
||||
@ -97,6 +99,11 @@ device.
|
||||
Jitted functions behave like any other primitive operations—they will follow the
|
||||
data and will show errors if invoked on data committed on more than one device.
|
||||
|
||||
``jnp.device_put(jnp.zeros(...), jax.devices()[1])`` or similar will actually create the
|
||||
array of zeros on ``jax.devices()[1]``, instead of creating the array on the default
|
||||
device then moving it. This is thanks to some laziness in array creation, which holds
|
||||
for all the constant creation operations (``ones``, ``full``, ``eye``, etc).
|
||||
|
||||
(As of April 2020, :func:`jax.jit` has a `device` parameter that affects the device
|
||||
placement. That parameter is experimental, is likely to be removed or changed,
|
||||
and its use is not recommended.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user