rocm_jax/jax/abstract_arrays.py
2020-05-12 20:37:05 -04:00

75 lines
2.3 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
from . import ad_util
from . import core
from . import dtypes
# Module-level aliases for names defined in `core`, so they can also be
# reached through this module.
_DIMENSION_TYPES = core._DIMENSION_TYPES

# Array abstract-value types from core.
UnshapedArray, ShapedArray, ConcreteArray = (
    core.UnshapedArray, core.ShapedArray, core.ConcreteArray)

# Token abstract-value type and the `abstract_token` object from core.
AbstractToken, abstract_token = core.AbstractToken, core.abstract_token

# Helpers from core.
canonicalize_shape, raise_to_shaped = (
    core.canonicalize_shape, core.raise_to_shaped)
def make_shaped_array(x):
  """Return a ``ShapedArray`` describing the shape and dtype of ``x``.

  The dtype is ``dtypes.result_type(x)`` passed through
  ``dtypes.canonicalize_dtype``; the shape is ``np.shape(x)``.
  """
  result_dtype = dtypes.result_type(x)
  canonical = dtypes.canonicalize_dtype(result_dtype)
  shape = np.shape(x)
  return ShapedArray(shape, canonical)
def zeros_like_array(x):
  """Return zeros with the shape of ``x`` and its canonicalized dtype.

  Instead of allocating a full array, a single scalar zero is broadcast to
  ``np.shape(x)``. NumPy documents the result of ``np.broadcast_to`` as a
  read-only view, so callers must not write into it.
  """
  zero = np.array(0, dtypes.canonicalize_dtype(dtypes.result_type(x)))
  target_shape = np.shape(x)
  return np.broadcast_to(zero, target_shape)
# NumPy array/scalar types (plus bfloat16) registered with core and ad_util
# below.
# NOTE(review): np.longlong is listed alongside np.int64, presumably because it
# can be a distinct type object on some platforms; when it is the same object
# the set simply holds it once -- confirm.
array_types = {
    np.ndarray, np.bool_,
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    dtypes.bfloat16, np.float16, np.float32, np.float64,
    np.complex64, np.complex128,
    np.longlong,
}

# Every type above is mapped to ``ConcreteArray`` in core's pytype->aval table
# and to ``zeros_like_array`` in ad_util's zeros-builder table.
core.pytype_aval_mappings.update(dict.fromkeys(array_types, ConcreteArray))
ad_util.jaxval_zeros_likers.update(dict.fromkeys(array_types, zeros_like_array))
def zeros_like_shaped_array(aval):
  """Return a newly allocated zero array matching ``aval``'s shape and dtype.

  Args:
    aval: expected to be a ``ShapedArray``. This is checked with ``assert``,
      so the check is skipped when Python runs with ``-O``.
  """
  assert isinstance(aval, ShapedArray)
  shape, dtype = aval.shape, aval.dtype
  return np.zeros(shape, dtype=dtype)


# Register the zero-builder for ``ShapedArray`` avals, then mark every type in
# ``array_types`` as literalable in core.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
core.literalable_types.update(array_types)
def _zeros_like_python_scalar(t, x):
  """Return a 0-d zero array with the dtype that Python type ``t`` maps to.

  The dtype comes from ``dtypes.python_scalar_dtypes[t]``. ``x`` is ignored.
  It exists so that, once ``t`` is bound with ``partial``, this has the same
  one-argument form as ``zeros_like_array``.
  """
  del x  # Unused; only the type ``t`` determines the result.
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  return np.array(0, scalar_dtype)
def _make_concrete_python_scalar(t, x):
  """Wrap Python scalar ``x`` (of type ``t``) in a ``ConcreteArray``.

  ``x`` is converted to a NumPy array whose dtype is looked up in
  ``dtypes.python_scalar_dtypes[t]``, and the result is created with
  ``weak_type=True``.
  """
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  value = np.array(x, dtype=scalar_dtype)
  return ConcreteArray(value, weak_type=True)
# Register handlers for each Python scalar type (the keys of
# ``dtypes.python_scalar_dtypes``): ``partial`` binds the type as the first
# argument, so each handler is called with just the value.
# Iterating the mapping directly yields its keys; ``.keys()`` is redundant.
for t in dtypes.python_scalar_dtypes:
  core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
  ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)

# Python scalar types are also literalable in core.
core.literalable_types.update(dtypes.python_scalar_dtypes)