mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 13:46:08 +00:00
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
|
|
from . import ad_util
|
|
from . import core
|
|
from . import dtypes
|
|
|
|
# Re-export abstract-value classes and helpers from `core` so callers can
# reach them through this module's namespace.
_DIMENSION_TYPES = core._DIMENSION_TYPES

(UnshapedArray, ShapedArray, ConcreteArray, AbstractToken) = (
    core.UnshapedArray, core.ShapedArray, core.ConcreteArray,
    core.AbstractToken)

(abstract_token, canonicalize_shape, raise_to_shaped) = (
    core.abstract_token, core.canonicalize_shape, core.raise_to_shaped)
|
|
|
|
|
|
def make_shaped_array(x):
  """Returns a ShapedArray with ``x``'s shape and canonicalized dtype.

  The dtype is ``dtypes.result_type(x)`` passed through
  ``dtypes.canonicalize_dtype``; the shape comes from ``np.shape(x)``.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  shape = np.shape(x)
  return ShapedArray(shape, canonical_dtype)
|
|
|
|
def zeros_like_array(x):
  """Returns zeros with ``x``'s shape and canonicalized dtype.

  The result is ``np.broadcast_to`` applied to a single scalar zero. It is
  therefore a read-only broadcast view rather than a freshly allocated
  buffer of the full size.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  scalar_zero = np.array(0, canonical_dtype)
  return np.broadcast_to(scalar_zero, np.shape(x))
|
|
|
|
# NumPy array and scalar types that are registered as concrete values below.
array_types = {
    np.ndarray, np.bool_,
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    dtypes.bfloat16, np.float16, np.float32, np.float64,
    np.complex64, np.complex128,
    # NOTE(review): presumably listed because np.longlong is not always the
    # same type object as np.int64 on every platform; confirm.
    np.longlong,
}

# Every array type maps to ConcreteArray as its abstract value and uses
# zeros_like_array to build its zero.
core.pytype_aval_mappings.update(dict.fromkeys(array_types, ConcreteArray))
ad_util.jaxval_zeros_likers.update(
    dict.fromkeys(array_types, zeros_like_array))
|
|
|
|
|
|
def zeros_like_shaped_array(aval):
  """Returns a newly allocated zero array matching ``aval``'s shape and dtype.

  ``aval`` must be a ShapedArray; this is checked with an assert.
  """
  assert isinstance(aval, ShapedArray)
  shape, dtype = aval.shape, aval.dtype
  return np.zeros(shape, dtype=dtype)


# ShapedArray abstract values obtain their zeros from the function above.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
|
|
|
# Mark every registered array type as literalable (in-place set union).
core.literalable_types |= array_types
|
|
|
|
def _zeros_like_python_scalar(t, x):
  """Returns a 0-d zero array whose dtype is ``dtypes.python_scalar_dtypes[t]``.

  Args:
    t: a Python scalar type; a key of ``dtypes.python_scalar_dtypes``.
    x: the scalar value itself. It is ignored, because only its type (``t``)
      determines the result.
  """
  del x  # Unused: the zero depends only on the type.
  return np.array(0, dtype=dtypes.python_scalar_dtypes[t])
|
|
|
|
def _make_concrete_python_scalar(t, x):
  """Wraps Python scalar ``x`` of type ``t`` in a ConcreteArray.

  The value is converted with the dtype given by
  ``dtypes.python_scalar_dtypes[t]``, and the resulting ConcreteArray is
  created with ``weak_type=True``.
  """
  value = np.array(x, dtype=dtypes.python_scalar_dtypes[t])
  return ConcreteArray(value, weak_type=True)
|
|
|
|
# Register each Python scalar type (the keys of dtypes.python_scalar_dtypes):
# its abstract value is built by _make_concrete_python_scalar, its zero by
# _zeros_like_python_scalar, and it is marked as literalable.
core.pytype_aval_mappings.update({
    scalar_type: partial(_make_concrete_python_scalar, scalar_type)
    for scalar_type in dtypes.python_scalar_dtypes
})
ad_util.jaxval_zeros_likers.update({
    scalar_type: partial(_zeros_like_python_scalar, scalar_type)
    for scalar_type in dtypes.python_scalar_dtypes
})
core.literalable_types.update(dtypes.python_scalar_dtypes)
|