2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-05-07 01:46:13 -04:00
|
|
|
from functools import partial
|
|
|
|
|
2020-05-12 21:37:05 -03:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from . import ad_util
|
2020-03-09 09:14:23 +00:00
|
|
|
from . import core
|
2019-11-15 10:02:51 -05:00
|
|
|
from . import dtypes
|
2020-01-29 14:24:11 -08:00
|
|
|
|
2020-11-04 09:01:18 -08:00
|
|
|
from ._src import traceback_util
|
2020-10-26 10:03:06 -07:00
|
|
|
# Register this source file with traceback_util's exclusion mechanism.
# NOTE(review): presumably this hides frames from this file in filtered
# tracebacks -- confirm against traceback_util.register_exclusion.
traceback_util.register_exclusion(__file__)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
# Module-level aliases for names that are defined in `core`, so that they can
# also be reached through this module.
_DIMENSION_TYPES = core._DIMENSION_TYPES

# Abstract value classes.
UnshapedArray = core.UnshapedArray
ShapedArray = core.ShapedArray
ConcreteArray = core.ConcreteArray

# Token abstract value class and the `abstract_token` object from `core`.
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token

# Helper functions.
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped
|
2019-10-09 15:05:54 -04:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def make_shaped_array(x):
  """Return a ShapedArray with the shape and canonicalized dtype of `x`.

  The dtype is `dtypes.result_type(x)` passed through
  `dtypes.canonicalize_dtype`; the shape is `np.shape(x)`.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  shape = np.shape(x)
  return ShapedArray(shape, canonical_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-06 11:59:33 -08:00
|
|
|
def zeros_like_array(x):
  """Return zeros matching the shape and canonicalized dtype of `x`.

  An abstract ShapedArray is built from `np.shape(x)` and the canonicalized
  result dtype of `x`, and then handed to `zeros_like_shaped_array`.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  aval = ShapedArray(np.shape(x), canonical_dtype)
  return zeros_like_shaped_array(aval)
|
2019-01-06 11:59:33 -08:00
|
|
|
|
2020-05-12 21:37:05 -03:00
|
|
|
# Python types that are registered below as array-like values: the NumPy
# ndarray class itself, NumPy scalar types, and `dtypes.bfloat16`.
array_types = {
    np.ndarray,
    np.bool_,
    # Signed integers.
    np.int8, np.int16, np.int32, np.int64,
    # Unsigned integers.
    np.uint8, np.uint16, np.uint32, np.uint64,
    # Floating point.
    dtypes.bfloat16, np.float16, np.float32, np.float64,
    # Complex.
    np.complex64, np.complex128,
    # C-named integer scalar types.
    np.longlong, np.intc,
}
|
2019-05-08 20:32:24 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# For every array type, register ConcreteArray as the abstract-value
# constructor and zeros_like_array as the zeros-like handler.
for t in array_types:
  core.pytype_aval_mappings[t] = ConcreteArray
  ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
2019-04-23 17:47:28 -07:00
|
|
|
|
|
|
|
|
2019-04-25 10:43:50 -07:00
|
|
|
def zeros_like_shaped_array(aval):
  """Return a NumPy array of zeros with the shape and dtype of `aval`.

  `aval` must be a ShapedArray (checked with an assert). For dtype
  `dtypes.float0` the result comes from `np.zeros`; for every other dtype a
  zero scalar of that dtype is broadcast to `aval.shape` with
  `np.broadcast_to`.
  """
  assert isinstance(aval, ShapedArray)
  if aval.dtype == dtypes.float0:
    return np.zeros(aval.shape, dtypes.float0)
  zero_scalar = np.array(0, aval.dtype)
  return np.broadcast_to(zero_scalar, aval.shape)
|
2019-04-25 10:43:50 -07:00
|
|
|
|
|
|
|
# Register zeros_like_shaped_array in ad_util.aval_zeros_likers under the
# ShapedArray key.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
2019-05-08 10:04:32 -07:00
|
|
|
|
2019-06-19 10:32:55 -07:00
|
|
|
# Add every type in array_types to core.literalable_types.
core.literalable_types.update(array_types)
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
2020-05-07 01:46:13 -04:00
|
|
|
def _zeros_like_python_scalar(t, x):
  """Return a zero NumPy scalar array for the Python scalar type `t`.

  The dtype is looked up in `dtypes.python_scalar_dtypes[t]`. The value `x` is
  not inspected; only its type `t` determines the result.
  """
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  return np.array(0, scalar_dtype)
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
2020-05-07 01:46:13 -04:00
|
|
|
def _make_concrete_python_scalar(t, x):
  """Return a weakly-typed ConcreteArray holding the Python scalar `x`.

  `x` is converted to a NumPy array whose dtype is
  `dtypes.python_scalar_dtypes[t]`, and the resulting ConcreteArray is
  created with `weak_type=True`.
  """
  value = np.array(x, dtype=dtypes.python_scalar_dtypes[t])
  return ConcreteArray(value, weak_type=True)
|
|
|
|
|
2020-09-17 21:51:18 +05:30
|
|
|
# For every Python scalar type known to dtypes.python_scalar_dtypes, register
# an abstract-value constructor and a zeros-like handler. `partial` binds the
# scalar type as the first argument of each helper.
for t in dtypes.python_scalar_dtypes:
  core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
  ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
|
|
|
# Add the Python scalar types (the keys of dtypes.python_scalar_dtypes) to
# core.literalable_types.
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
|