2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
"""Interface and utility functions to XLA.
|
|
|
|
|
|
|
|
This module wraps the XLA client(s) and builders to standardize their interfaces
|
|
|
|
and provide some automatic type mapping logic for converting between Numpy and
|
|
|
|
XLA. There are also a handful of related casting utilities.
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
from functools import partial
|
2018-11-17 18:03:33 -08:00
|
|
|
import os
|
2020-06-04 15:27:48 -07:00
|
|
|
from typing import Callable, Dict, Tuple, Union
|
2018-11-17 18:03:33 -08:00
|
|
|
import warnings
|
|
|
|
|
2019-08-26 11:22:58 -07:00
|
|
|
from absl import logging
|
|
|
|
|
2018-11-29 12:30:34 -08:00
|
|
|
from ..config import flags
|
2019-07-22 17:24:10 -04:00
|
|
|
from .. import util
|
2019-11-15 10:02:51 -05:00
|
|
|
from .. import dtypes
|
2020-07-14 13:05:31 -07:00
|
|
|
import numpy as np
|
2019-08-09 13:38:08 -04:00
|
|
|
import threading
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-14 14:00:08 -08:00
|
|
|
try:
|
|
|
|
from . import tpu_client
|
|
|
|
except ImportError:
|
|
|
|
tpu_client = None
|
2019-07-30 21:01:48 -04:00
|
|
|
from . import xla_client
|
2019-03-31 16:19:59 -07:00
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
# Shorthand for the op namespace exposed by xla_client.
xops = xla_client.ops

FLAGS = flags.FLAGS

# Name of the registered backend factory that get_backend() looks up.
flags.DEFINE_string(
    'jax_xla_backend', 'xla',
    'Default is "xla" for the XLA service directly, '
    'or "tpu_driver" for using high-performance access to Cloud TPU hardware.')
# Read by _get_tpu_driver_backend() and passed as the TPU driver worker.
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a remote service target.')
# The default comes from the JAX_PLATFORM_NAME environment variable (empty
# string when unset); read by _get_local_backend() when no platform is given.
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', ''),
    'Platform name for XLA. The default is to attempt to use a GPU if '
    'available, but fall back to CPU otherwise. To set the platform manually, '
    'pass "cpu" for CPU or "gpu" for GPU.')
# Read by get_compile_options() to lower the XLA optimization settings.
flags.DEFINE_bool(
    'jax_disable_most_optimizations', False,
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')
|
2019-02-07 17:48:22 -05:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-01 13:04:27 -07:00
|
|
|
def get_compile_options(num_replicas, num_partitions, device_assignment=None,
                        use_spmd_partitioning=True):
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: int indicating the number of replicas for which to compile.
    num_partitions: int indicating the number of partitions for which to compile.
    device_assignment: Optional tuple of integers indicating the assignment of
      logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`. A 1D assignment is accepted only when `num_partitions`
      is 1; otherwise it must be 2D with shape (num_replicas, num_partitions).
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.

  Returns:
    The populated xla_client.CompileOptions object.

  Raises:
    ValueError: if `device_assignment` is not 2D after the 1D special case is
      applied, or if its shape disagrees with `num_replicas` or
      `num_partitions`.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  if device_assignment is not None:
    logging.vlog(
        2,
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Reject anything that is not (replicas x partitions) up front, so that the
    # shape checks below cannot fail with an opaque IndexError.
    if device_assignment.ndim != 2:
      msg = ('device_assignment must be 2D (num_replicas x num_partitions), '
             'or 1D when num_partitions is 1; got shape {}.')
      raise ValueError(msg.format(device_assignment.shape))

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    # Internal sanity checks on the object built by xla_client.
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if FLAGS.jax_disable_most_optimizations:
    debug_options = compile_options.executable_build_options.debug_options
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  return compile_options
|
|
|
|
|
2019-03-29 11:09:56 -04:00
|
|
|
# Maps a backend name (a value of the jax_xla_backend flag) to its factory.
_backends = {}


def register_backend(name, factory):
  """Registers `factory` under `name`, replacing any earlier registration.

  Args:
    name: key under which the factory is stored.
    factory: callable stored for later lookup; get_backend() calls it with a
      single platform argument.
  """
  _backends.update({name: factory})
|
2018-11-25 18:53:48 -08:00
|
|
|
|
2019-08-19 23:45:36 -07:00
|
|
|
def _get_local_backend(platform=None):
  """Returns the local XLA backend for `platform`.

  Args:
    platform: optional platform name. When falsy, the jax_platform_name flag is
      used instead, and an empty flag value is treated as None.

  Returns:
    The backend object returned by xla_client.get_local_backend.

  Raises:
    RuntimeError: if xla_client returns no backend.
  """
  requested = platform if platform else (FLAGS.jax_platform_name or None)

  local_backend = xla_client.get_local_backend(requested)
  if local_backend is None:
    raise RuntimeError("No local XLA backends found.")

  # Warn whenever we ended up on CPU without CPU having been asked for.
  landed_on_cpu = local_backend.platform == 'cpu'
  if landed_on_cpu and requested != 'cpu':
    warnings.warn('No GPU/TPU found, falling back to CPU.')

  return local_backend


register_backend('xla', _get_local_backend)
|
|
|
|
|
|
|
|
# Cache the TPU driver backend so that it is created at most once, to be
# consistent with xla_client behavior.
_tpu_backend = None


def _get_tpu_driver_backend(platform):
  """Returns the cached TPU driver backend, creating it on first use.

  Args:
    platform: ignored.

  Returns:
    The object returned by tpu_client.TpuBackend.create for the worker named
    by the jax_backend_target flag.

  Raises:
    ValueError: if the jax_backend_target flag is None.
  """
  del platform
  global _tpu_backend
  if _tpu_backend is not None:
    return _tpu_backend

  worker = FLAGS.jax_backend_target
  # NOTE(review): the flag is defined with default 'local', so this guard
  # presumably only fires if the flag is explicitly set to None - confirm.
  if worker is None:
    raise ValueError('When using TPU Driver as the backend, you must specify '
                     '--jax_backend_target=<hostname>:8470.')
  _tpu_backend = tpu_client.TpuBackend.create(worker=worker)
  return _tpu_backend


# tpu_client is None when its import failed; only register when available.
if tpu_client:
  register_backend('tpu_driver', _get_tpu_driver_backend)
|
2019-03-31 10:31:47 -07:00
|
|
|
|
2019-12-02 15:02:27 -08:00
|
|
|
|
2019-08-09 13:38:08 -04:00
|
|
|
# Held while a backend factory is looked up and invoked.
_backend_lock = threading.Lock()


@util.memoize
def get_backend(platform=None):
  """Returns the backend selected by the jax_xla_backend flag.

  Args:
    platform: None, a platform name string, or any other object. An object that
      is neither None nor a string is returned unchanged.

  Returns:
    The result of calling the registered backend factory with `platform`, or
    `platform` itself when it is not None/str.

  Raises:
    ValueError: if no factory is registered for the jax_xla_backend value.
  """
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if platform is not None and not isinstance(platform, str):
    return platform

  with _backend_lock:
    factory = _backends.get(FLAGS.jax_xla_backend)
    if factory is not None:
      return factory(platform)
    msg = 'Unknown jax_xla_backend value "{}".'
    raise ValueError(msg.format(FLAGS.jax_xla_backend))
|
2018-11-25 18:53:48 -08:00
|
|
|
|
|
|
|
|
2019-11-25 16:23:40 -08:00
|
|
|
def get_device_backend(device=None):
  """Returns the Backend associated with `device`, or the default Backend.

  Args:
    device: optional object with a `platform` attribute. A falsy value selects
      the default backend.
  """
  if device:
    return get_backend(device.platform)
  return get_backend(None)
|
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def device_count(backend: str = None):
  """Returns the total number of devices.

  On most platforms, this is the same as :py:func:`jax.local_device_count`.
  However, on multi-host platforms, this will return the total number of devices
  across all hosts.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Number of devices, converted to a Python int.
  """
  selected = get_backend(backend)
  total = selected.device_count()
  return int(total)
|
2019-03-06 17:28:28 -08:00
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def local_device_count(backend: str =None):
  """Returns the number of devices on this host.

  Args:
    backend: optional backend name, forwarded to get_backend.

  Returns:
    The backend's local device count, converted to a Python int.
  """
  selected = get_backend(backend)
  local_total = selected.local_device_count()
  return int(local_total)
|
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def devices(backend: str = None):
  """Returns a list of all devices for a given backend.

  Each device is represented by a subclass of :class:`Device` (e.g.
  :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
  equal to ``device_count(backend)``. Local devices can be identified by comparing
  :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.

  If ``backend`` is ``None``, returns all the devices from the default backend.
  The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
  otherwise ``'cpu'``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  selected = get_backend(backend)
  return selected.devices()
|
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def local_devices(host_id: int = None, backend: str = None):
  """Like :py:func:`jax.devices`, but only returns devices local to a given host.

  If ``host_id`` is ``None``, returns devices local to this host.

  Args:
    host_id: the integer ID of the host. Host IDs can be retrieved via
      :py:func:`jax.host_ids`.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.

  Raises:
    ValueError: if ``host_id`` is not one of the host IDs of ``backend``.
  """
  if host_id is None:
    host_id = get_backend(backend).host_id()
  # Validate against the same backend whose devices are listed below; checking
  # the default backend would give wrong answers for a non-default `backend`.
  if host_id not in host_ids(backend):
    raise ValueError(f"Unknown host_id {host_id}")
  return [d for d in devices(backend) if d.host_id == host_id]
|
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def host_id(backend: str = None):
  """Returns the integer host ID of this host.

  On most platforms, this will always be 0. This will vary on multi-host
  platforms though.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Integer host ID.
  """
  selected = get_backend(backend)
  return selected.host_id()
|
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def host_ids(backend: str = None):
  """Returns a sorted list of all host IDs.

  Args:
    backend: optional backend name, forwarded to devices().

  Returns:
    Sorted list of the distinct `host_id` values found on the backend's devices.
  """
  distinct_ids = set()
  for device in devices(backend):
    distinct_ids.add(device.host_id)
  return sorted(distinct_ids)
|
2019-09-27 14:38:16 -07:00
|
|
|
|
|
|
|
|
2020-05-20 14:40:28 -07:00
|
|
|
def host_count(backend: str = None):
  """Returns the number of hosts.

  Args:
    backend: optional backend name, forwarded to host_ids().
  """
  all_host_ids = host_ids(backend)
  return len(all_host_ids)
|
Add support for multihost pmaps.
All participating hosts are assumed to be running the same pmap
code. Conceptually, this can be considered a single pmap over an array
sharded on its leading pmapped dimension across the hosts. Each host
passes its input shard to its pmapped function call, which returns the
corresponding output shard (i.e. an array of the same leading
dimension size). However, any collective operations will be run across
the entire "global" array.
If the `devices` argument to pmap is None, the pmap is assumed to be
running across all hosts visible to XLA (as returned by
jax.host_count()). Each host can pass in an input array of leading
dimension size equal to or less than the number of devices local to
that host. Note that this doesn't change the current behavior for
single-host platforms. If `devices` are specified, the participating
hosts are dictated by the devices' host_ids, and each host must pass
in an input array of leading dim size equal to the number of local
participating devices.
Implementation-wise, each host independently compiles the computation,
which we assume yields the same executable on all hosts (follow-up
work will add more error checking). The hosts must know the global
axis size of the sharded array, e.g. to provide the correct replica
count to XLA. This is equal to the length of `devices` if specified,
but if not, pmap is recursively called (with `devices` specified) to
use `psum` to compute the global axis size.
2019-09-19 11:02:34 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
### utility functions
|
|
|
|
|
2019-08-09 13:12:44 -04:00
|
|
|
@util.memoize
def dtype_to_etype(dtype):
  """Converts `dtype` to its canonical XLA element type.

  The dtype is first canonicalized with dtypes.canonicalize_dtype and the
  result is then mapped through xla_client.dtype_to_etype.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  return xla_client.dtype_to_etype(canonical_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-08-09 13:12:44 -04:00
|
|
|
@util.memoize
def supported_numpy_dtypes():
  """Returns the set of canonicalized dtypes that XLA element types map to."""
  xla_dtypes = xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()
  return set(map(dtypes.canonicalize_dtype, xla_dtypes))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
# TODO(mattjj,frostig): try to remove this function
|
|
|
|
def normalize_to_xla_dtypes(val):
  """Normalizes dtypes in a value.

  Args:
    val: an object with an `__array__` attribute, a NumPy scalar (per
      np.isscalar), or a tuple/list of such values (nested as needed).

  Returns:
    An ndarray with the canonicalized dtype for array-like/scalar input, or a
    tuple of normalized elements for tuple/list input.

  Raises:
    TypeError: if `val` is none of the supported kinds.
  """
  is_array_like = hasattr(val, '__array__') or np.isscalar(val)
  if is_array_like:
    canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(val))
    return np.asarray(val, dtype=canonical_dtype)
  if isinstance(val, (tuple, list)):
    return tuple(map(normalize_to_xla_dtypes, val))
  raise TypeError('Can\'t convert to XLA: {}'.format(val))
|
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
def _numpy_array_constant(builder, value, canonicalize_types=True):
  """Stages `value` into `builder` as a constant literal.

  Args:
    builder: the builder passed through to xops.ConstantLiteral.
    value: the value to stage.
    canonicalize_types: when true, `value` is first passed through
      normalize_to_xla_dtypes.

  Returns:
    The result of xops.ConstantLiteral.
  """
  literal = normalize_to_xla_dtypes(value) if canonicalize_types else value
  return xops.ConstantLiteral(builder, literal)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-29 11:31:36 -07:00
|
|
|
def parameter(builder, num, shape, name=None, replicated=None):
  """Adds a parameter op to `builder` via xops.Parameter.

  Args:
    builder: the builder to add the parameter to.
    num: the parameter number, passed through unchanged.
    shape: the parameter shape; a major-to-minor layout is applied if the shape
      has none.
    name: optional parameter name; None is replaced by the empty string.
    replicated: None (treated as an empty list), a bool (repeated once per leaf
      of `shape`), or a value passed through unchanged.

  Returns:
    The result of xops.Parameter.
  """
  param_name = '' if name is None else name

  if replicated is None:
    replicated_flags = []
  elif isinstance(replicated, bool):
    # A single bool applies to every leaf of the (possibly tuple) shape.
    replicated_flags = [replicated] * shape.leaf_count()
  else:
    replicated_flags = replicated

  laid_out_shape = shape.with_major_to_minor_layout_if_absent()
  return xops.Parameter(builder, num, laid_out_shape, param_name,
                        replicated_flags)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
def constant(builder, py_val, canonicalize_types=True):
  """Translate constant `py_val` to a constant, canonicalizing its dtype.

  The handler is selected by the exact type of `py_val` among the handlers
  registered with register_constant_handler.

  Args:
    builder: the builder passed to the handler.
    py_val: a Python value to be translated to a constant.
    canonicalize_types: passed to the handler as its third argument.

  Returns:
    A representation of the constant, either a ComputationDataHandle or None

  Raises:
    TypeError: if no handler is registered for `type(py_val)`.
  """
  py_type = type(py_val)
  handler = _constant_handlers.get(py_type)
  if handler is None:
    raise TypeError("No constant handler for type: {}".format(py_type))
  return handler(builder, py_val, canonicalize_types)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-05-14 11:38:08 -07:00
|
|
|
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the
# total number of shards is the product), or None to indicate the array should
# be replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
                        None,
                        Tuple[Union[Tuple[int, ...], None], ...]]
|
|
|
|
|
|
|
|
def _sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.

  Args:
    sharding: None (replicated), a tuple of ints (one partition count per
      dimension), or a tuple whose elements are None or tuples of ints (one
      sharding per tuple element).

  Returns:
    An xla_client.OpSharding of type REPLICATED, OTHER or TUPLE respectively.
  """
  proto = xla_client.OpSharding()
  # A tuple whose first element is not an int is a tuple of per-element
  # shardings rather than a per-dimension partition count.
  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    sub_protos = [_sharding_to_proto(s) for s in sharding]  # type: ignore
    proto.type = xla_client.OpSharding.Type.TUPLE
    proto.tuple_shardings = sub_protos
    return proto

  if sharding is None:
    proto.type = xla_client.OpSharding.Type.REPLICATED
  else:
    proto.type = xla_client.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = list(sharding)
    # The shard count is the product of the per-dimension partition counts.
    # np.prod replaces the deprecated np.product alias (removed in NumPy 2.0).
    proto.tile_assignment_devices = list(range(np.prod(sharding)))
  return proto
|
|
|
|
|
|
|
|
def set_sharding(builder, op, sharding: SpatialSharding):
  """Uses CustomCall to annotate a value as sharded.

  Args:
    builder: the builder that owns `op`.
    op: the op whose value is annotated.
    sharding: the SpatialSharding to attach.

  Returns:
    The CustomCall op built under the sharding annotation.
  """
  # "Sharding" is a built-in custom call target that acts like an identity
  # function, and is used to attach an OpSharding to.
  op_shape = builder.get_shape(op)
  return with_sharding(builder, sharding, xops.CustomCall,
                       builder, b"Sharding", [op], op_shape)
|
2020-05-14 11:38:08 -07:00
|
|
|
|
|
|
|
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation.

  The sharding is set on `builder` before `op_fn` is called and is always
  cleared afterwards, including when `op_fn` raises.

  Returns:
    Whatever `op_fn` returns.
  """
  sharding_proto = _sharding_to_proto(sharding)
  builder.set_sharding(sharding_proto)
  try:
    built_op = op_fn(*args, **kwargs)
  finally:
    builder.clear_sharding()
  return built_op
|
2020-05-14 11:38:08 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def make_computation_builder(name):
  """Returns a new xla_client.XlaBuilder constructed with `name`."""
  builder = xla_client.XlaBuilder(name)
  return builder
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
# Maps a Python/NumPy type to the handler that stages values of exactly that
# type as XLA constants; consulted by constant().
_constant_handlers: Dict[type, Callable] = {}


def register_constant_handler(type_, handler_fun):
  """Registers `handler_fun` as the constant handler for `type_`.

  Args:
    type_: the type whose instances the handler translates.
    handler_fun: callable invoked by constant() as
      handler_fun(builder, value, canonicalize_types).
  """
  _constant_handlers[type_] = handler_fun
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-01-28 18:41:27 -05:00
|
|
|
def _ndarray_constant_handler(c, val, canonicalize_types=True):
  """Constant handler for ndarray literals, handling zero-size strides.

  This function essentially calls _numpy_array_constant(val) except it has
  special handling of arrays with any strides of size zero: for those, it
  stages only the collapsed data as a literal and rebuilds the full shape with
  Broadcast and Transpose, to avoid staging in large literals that might arise
  from np.zeros or np.ones or the output of lax.broadcast (which uses
  np.broadcast_to which in turn uses size-zero strides).

  Args:
    c: an XlaBuilder
    val: an ndarray.
    canonicalize_types: forwarded unchanged to _numpy_array_constant.

  Returns:
    An XLA ComputationDataHandle / XlaOp representing the constant ndarray
    staged into the XLA Computation.
  """
  # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
  if np.any(np.equal(0, val.strides)) and val.size > 0:
    # Split the axes into those with a zero stride and all the others.
    zero_stride_axes, = np.where(np.equal(0, val.strides))
    other_axes, = np.where(np.not_equal(0, val.strides))
    # Keep a single element (index 0) along every zero-stride axis.
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                              for ax in range(val.ndim))]
    # Broadcast the collapsed literal by the sizes of the zero-stride axes.
    # NOTE(review): the permutation below relies on Broadcast placing these new
    # dimensions ahead of collapsed_val's own dimensions - confirm against the
    # XLA Broadcast semantics.
    xla_val = xops.Broadcast(
        _numpy_array_constant(c, collapsed_val, canonicalize_types),
        np.take(val.shape, zero_stride_axes))
    # argsort of (zero-stride axes, then other axes) yields the permutation
    # that moves each dimension back to its original axis position.
    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
    return xops.Transpose(xla_val, permutation)
  else:
    return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-01-28 18:41:27 -05:00
|
|
|
def _scalar_constant_handler(c, val, canonicalize_types=True):
  """Constant handler for NumPy scalar values.

  Args:
    c: the builder to stage the constant into.
    val: the scalar value.
    canonicalize_types: forwarded unchanged to _numpy_array_constant.

  Returns:
    The result of _numpy_array_constant.
  """
  return _numpy_array_constant(c, val, canonicalize_types)

# np.float128 is a platform-dependent alias that NumPy does not define on every
# platform, so it is registered only when present; referencing it
# unconditionally would make this module fail to import with AttributeError.
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
                    np.uint8, np.uint16, np.uint32, np.uint64,
                    np.float16, np.float32, np.float64,
                    *([np.float128] if hasattr(np, 'float128') else []),
                    np.bool_, np.longlong,
                    xla_client.bfloat16]:
  register_constant_handler(scalar_type, _scalar_constant_handler)
|
2019-05-13 16:48:19 -04:00
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
  """Constant handler for Python scalars.

  Args:
    dtype: the dtype whose scalar type `val` is converted to; bound with
      functools.partial at registration time.
    c: the builder to stage the constant into.
    val: the Python scalar.
    canonicalize_dtypes: accepted for handler-signature compatibility; it is not
      forwarded, so _numpy_array_constant runs with its own default.

  Returns:
    The result of _numpy_array_constant.
  """
  numpy_scalar = dtype.type(val)
  return _numpy_array_constant(c, numpy_scalar)
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
|
|
|
# Register a constant handler for every Python scalar type listed in
# dtypes.python_scalar_dtypes, binding the paired dtype as the handler's first
# argument.
for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|