# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import core
from jax import numpy as jnp
from jax._src import device_array
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge

SUPPORTED_DTYPES = frozenset({
    jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
    jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
    jnp.float64, jnp.complex64, jnp.complex128})


def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False):
  """Returns a DLPack tensor that encapsulates a ``DeviceArray`` ``x``.

  If ``take_ownership`` is ``True``, takes ownership of the contents of ``x``
  and leaves ``x`` in an invalid/deleted state.

  Args:
    x: a ``DeviceArray``, on either CPU or GPU.
    take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
      and the consumer is free to mutate the buffer; the JAX buffer acts as if
      it were deleted. If ``False``, JAX retains ownership of the buffer; it is
      undefined behavior if the DLPack consumer writes to a buffer that JAX
      owns.
  """
  if not isinstance(x, device_array.DeviceArray):
    raise TypeError("Argument to to_dlpack must be a DeviceArray, got {}"
                    .format(type(x)))
  return xla_client._xla.buffer_to_dlpack_managed_tensor(
      x.device_buffer, take_ownership=take_ownership)


def from_dlpack(dlpack):
  """Returns a ``DeviceArray`` representation of a DLPack tensor.

  The returned ``DeviceArray`` shares memory with ``dlpack``.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
  """
  cpu_backend = xla_bridge.get_backend("cpu")
  try:
    gpu_backend = xla_bridge.get_backend("cuda")
  except RuntimeError:
    # CPU-only builds of jaxlib have no "cuda" backend; fall back to CPU only.
    gpu_backend = None
  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, cpu_backend, gpu_backend)

  xla_shape = buf.xla_shape()
  assert not xla_shape.is_tuple()
  aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
  return device_array.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
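# A minimal round-trip sketch (editorial illustration; `_dlpack_roundtrip_example`
# is not part of the public jax.dlpack API). It shows how `to_dlpack` and
# `from_dlpack` compose: the capsule returned by `to_dlpack` can be consumed
# exactly once, and with the default ``take_ownership=False`` JAX keeps
# ownership, so the source array remains valid and shares memory with the
# result.
def _dlpack_roundtrip_example():
  """Round-trips a small array through DLPack; illustrative only."""
  import numpy as np  # Local import: used only by this example.
  x = jnp.arange(8, dtype=jnp.float32)
  capsule = to_dlpack(x)    # take_ownership defaults to False: x stays valid.
  y = from_dlpack(capsule)  # y shares memory with x's underlying buffer.
  np.testing.assert_array_equal(np.asarray(x), np.asarray(y))
  # Passing take_ownership=True to to_dlpack would instead hand the buffer to
  # the capsule and leave x in a deleted state.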