2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as onp
|
|
|
|
|
|
|
|
from . import core
|
|
|
|
from . import ad_util
|
2019-11-15 10:02:51 -05:00
|
|
|
from . import dtypes
|
2019-11-22 10:53:11 -08:00
|
|
|
from . util import prod, partialmethod
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-01-27 15:44:33 -08:00
|
|
|
def concretization_err_msg(fun, context=None):
  """Build the message used when an abstract value needs to be concrete.

  Args:
    fun: the function that required a concrete value. Its ``__name__`` is
      shown in the message; if it has none, ``fun`` itself is formatted.
    context: optional trailing hint appended to the message. When ``None``,
      a generic hint about ``jit`` and ``static_argnums`` is used.

  Returns:
    The formatted error message string.
  """
  if context is None:
    context = ("The function to be transformed can't be traced at the required level "
               "of abstraction. If using `jit`, try using `static_argnums` or "
               "applying `jit` to smaller subfunctions instead.")
  # Fall back to the object itself when it carries no `__name__`.
  fname = getattr(fun, "__name__", fun)
  template = "Abstract value passed to `{}`, which requires a concrete value. {}"
  return template.format(fname, context)
|
|
|
|
|
|
|
|
def concretization_function_error(fun, context=None):
  """Return a method-shaped callable that always raises ``TypeError``.

  The returned function accepts ``self`` plus arbitrary positional arguments
  and ignores them all; the error message is produced lazily, at call time,
  by ``concretization_err_msg(fun, context)``.

  Args:
    fun: the conversion function (e.g. ``bool``) named in the error message.
    context: optional hint forwarded to ``concretization_err_msg``.

  Returns:
    A function ``error(self, *args)`` that raises ``TypeError``.
  """
  def error(self, *args):
    message = concretization_err_msg(fun, context)
    raise TypeError(message)
  return error
|
|
|
|
|
|
|
|
|
|
|
|
class UnshapedArray(core.AbstractValue):
  """Abstract array value described only by its dtype and a weak-type flag.

  ``weak_type`` marks avals whose dtype should lose out to non-weak dtypes
  during type promotion (the notion introduced so that array types are
  preferred over the default types given to Python scalars).
  """
  __slots__ = ['dtype', 'weak_type']
  # ShapedArray uses 1 and ConcreteArray uses 0 for this attribute.
  array_abstraction_level = 2

  def __init__(self, dtype, weak_type=False):
    """Store the canonicalized dtype and the weak-type flag."""
    canonical = dtypes.canonicalize_dtype(dtype)
    self.dtype = onp.dtype(canonical)
    self.weak_type = weak_type

  def __eq__(self, other):
    """Equal only to an instance of the exact same class, dtype and flag."""
    if type(self) is not type(other):
      return False
    return self.dtype == other.dtype and self.weak_type == other.weak_type

  def __ne__(self, other):
    return not (self == other)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    key = (self.dtype, self.weak_type)
    return hash(key)

  def __repr__(self):
    """Render as ``ClassName(<str_short()>[, weak_type=True])``."""
    cls_name = self.__class__.__name__
    short = self.str_short()
    suffix = ", weak_type=True" if self.weak_type else ""
    return '{}({}{})'.format(cls_name, short, suffix)

  # Conversions to Python scalars need a concrete value; each of these
  # attributes is a function that raises TypeError when called.
  _bool = _nonzero = concretization_function_error(bool)
  _float = concretization_function_error(
      float, "Try using `value.astype(float)` instead.")
  _int = concretization_function_error(
      int, "Try using `value.astype(int)` instead.")
  _complex = concretization_function_error(
      complex, "Try using `value.astype(complex)` instead.")
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)

  def at_least_vspace(self):
    """Return this aval unchanged."""
    return self

  def join(self, other):
    """Combine with ``other``; dtypes must match or ``TypeError`` is raised.

    Returns ``self`` when the weak-type flags agree, otherwise a non-weak
    ``UnshapedArray`` of the same dtype.
    """
    if self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      # Flags disagree: the result is not weakly typed.
      return UnshapedArray(self.dtype, weak_type=False)
    raise TypeError(self, other)

  def str_short(self):
    """Return the dtype name."""
    return self.dtype.name

  def strip_weak_type(self):
    """Returns a copy of the aval with weak_type=False."""
    if self.weak_type:
      return UnshapedArray(self.dtype)
    # Already non-weak: no copy is made.
    return self
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
class ShapedArray(UnshapedArray):
  """Abstract array value described by shape, dtype and weak-type flag."""
  __slots__ = ['shape']
  array_abstraction_level = 1

  def __init__(self, shape, dtype, weak_type=False):
    """Initialise dtype/weak_type via the parent, then store ``shape`` as given."""
    super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
    self.shape = shape

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def size(self):
    return prod(self.shape)

  def __eq__(self, other):
    """Equal only to the exact same class with equal dtype, shape and flag."""
    if type(self) is not type(other):
      return False
    return (self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    key = (self.shape, self.dtype, self.weak_type)
    return hash(key)

  def at_least_vspace(self):
    """Return this aval unchanged."""
    return self

  def join(self, other):
    """Combine with ``other``.

    Same shape and dtype: ``self`` if the weak-type flags agree, else a
    non-weak ``ShapedArray``. Same dtype only: an ``UnshapedArray`` of that
    dtype. Otherwise ``TypeError`` is raised.
    """
    if self.shape == other.shape and self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      return ShapedArray(self.shape, self.dtype, weak_type=False)
    if self.dtype == other.dtype:
      # Shapes differ: fall back to the dtype-only abstraction.
      return UnshapedArray(self.dtype)
    raise TypeError(self, other)

  def str_short(self):
    """Return ``<dtype name>[d0,d1,...]``."""
    dims = ','.join(str(d) for d in self.shape)
    return '{}[{}]'.format(self.dtype.name, dims)

  def __len__(self):
    """Return the leading dimension; ``TypeError`` for an empty shape."""
    try:
      leading = self.shape[0]
    except IndexError:
      raise TypeError("len() of unsized object")  # same as numpy error
    return leading

  def _len(self, ignored_tracer):
    # The extra argument is accepted and ignored.
    return len(self)

  def strip_weak_type(self):
    """Return a non-weak ``ShapedArray``, or ``self`` if already non-weak."""
    if self.weak_type:
      return ShapedArray(self.shape, self.dtype)
    return self
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
def _forward_to_value(self, fun, ignored_tracer, *args):
|
|
|
|
return fun(self.val, *args)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class ConcreteArray(ShapedArray):
|
|
|
|
__slots__ = ['val']
|
|
|
|
array_abstraction_level = 0
|
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
def __init__(self, val, weak_type=False):
|
|
|
|
super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val),
|
|
|
|
weak_type=weak_type)
|
|
|
|
# Note: canonicalized self.dtype doesn't necessarily match self.val
|
2018-11-17 18:03:33 -08:00
|
|
|
self.val = val
|
|
|
|
assert self.dtype != onp.dtype('O')
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
return (type(self) is type(other) and self.dtype == other.dtype
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
and self.shape == other.shape and self.weak_type == other.weak_type
|
|
|
|
and onp.all(self.val == other.val))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return id(self.val)
|
|
|
|
|
|
|
|
def at_least_vspace(self):
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
return ShapedArray(self.shape, self.dtype, weak_type=self.weak_type)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def join(self, other):
|
|
|
|
if self == other:
|
|
|
|
return self
|
|
|
|
elif self.shape == other.shape and self.dtype == other.dtype:
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
return ShapedArray(self.shape, self.dtype,
|
|
|
|
weak_type=self.weak_type and other.weak_type)
|
2018-11-17 18:03:33 -08:00
|
|
|
elif self.dtype == other.dtype:
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
return UnshapedArray(self.dtype,
|
|
|
|
weak_type=self.weak_type and other.weak_type)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
raise TypeError(self, other)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def str_short(self):
  """Return the ``str()`` conversion of the wrapped value ``self.val``."""
  value = self.val
  return str(value)
|
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
def strip_weak_type(self):
  """Return a counterpart of this value that is not flagged as weakly typed.

  If ``self.weak_type`` is falsy, ``self`` is returned unchanged. Otherwise a
  new ``ConcreteArray`` is constructed from ``self.val`` alone, i.e. without
  passing a ``weak_type`` argument.
  """
  if not self.weak_type:
    return self
  return ConcreteArray(self.val)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
# Conversion hooks. Each is a partialmethod of _forward_to_value with the
# matching Python builtin supplied as its argument.
_bool = partialmethod(_forward_to_value, bool)
_nonzero = _bool  # same object as _bool
_float = partialmethod(_forward_to_value, float)
_int = partialmethod(_forward_to_value, int)
_complex = partialmethod(_forward_to_value, complex)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
|
|
|
|
|
2019-10-09 15:05:54 -04:00
|
|
|
class AbstractToken(core.AbstractValue):
  """Abstract value that adds nothing of its own to ``core.AbstractValue``."""
|
|
|
|
|
|
|
|
# Module-level AbstractToken instance; raise_to_shaped compares against it by
# identity (`aval is abstract_token`).
abstract_token = AbstractToken()
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def make_shaped_array(x):
  """Build the ``ShapedArray`` describing ``x``.

  The dtype is ``dtypes.result_type(x)`` passed through
  ``dtypes.canonicalize_dtype``; the shape is ``onp.shape(x)``.
  """
  canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
  shape = onp.shape(x)
  return ShapedArray(shape, canonical)
|
|
|
|
|
2019-01-06 11:59:33 -08:00
|
|
|
def zeros_like_array(x):
  """Return zeros matching the shape of ``x`` in its canonicalized dtype.

  A single scalar zero is created and broadcast to ``onp.shape(x)`` with
  ``onp.broadcast_to`` rather than allocating a full array of zeros.
  """
  canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
  zero = onp.array(0, canonical)
  return onp.broadcast_to(zero, onp.shape(x))
|
|
|
|
|
2019-11-20 22:43:46 -05:00
|
|
|
# Types registered below as concrete arrays: onp.ndarray, the NumPy scalar
# types listed here, and dtypes.bfloat16.
array_types = {
    onp.ndarray,
    onp.bool_,
    onp.int8,
    onp.int16,
    onp.int32,
    onp.int64,
    onp.uint8,
    onp.uint16,
    onp.uint32,
    onp.uint64,
    dtypes.bfloat16,
    onp.float16,
    onp.float32,
    onp.float64,
    onp.complex64,
    onp.complex128,
    onp.longlong,
}

# Map every array type to ConcreteArray in core.pytype_aval_mappings and to
# zeros_like_array in ad_util.jaxval_zeros_likers.
for t in array_types:
  core.pytype_aval_mappings[t] = ConcreteArray
  ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
2019-04-23 17:47:28 -07:00
|
|
|
|
|
|
|
|
2019-04-25 10:43:50 -07:00
|
|
|
def zeros_like_shaped_array(aval):
  """Return a NumPy array of zeros with the shape and dtype stored on ``aval``.

  ``aval`` is expected to be a ``ShapedArray``; this is checked with an
  ``assert``.
  """
  assert isinstance(aval, ShapedArray)
  shape, dtype = aval.shape, aval.dtype
  return onp.zeros(shape, dtype=dtype)
|
|
|
|
|
|
|
|
# Register zeros_like_shaped_array as the ad_util.aval_zeros_likers entry for
# ShapedArray.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
2019-05-08 10:04:32 -07:00
|
|
|
|
2019-12-11 20:46:05 -08:00
|
|
|
def raise_to_shaped(aval, weak_type=False):
  """Return the shape-level abstract value corresponding to ``aval``.

  Args:
    aval: a ``ShapedArray`` instance, ``core.abstract_unit`` or
      ``abstract_token``.
    weak_type: value of the ``weak_type`` flag on the returned ``ShapedArray``.
      Only used when ``aval`` is a ``ShapedArray``.

  Returns:
    A new ``ShapedArray`` with the shape and dtype of ``aval`` when ``aval`` is
    a ``ShapedArray`` instance; ``aval`` itself when it is the unit or token
    singleton.

  Raises:
    TypeError: for any other input, carrying ``type(aval)``.
  """
  if isinstance(aval, ShapedArray):
    return ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)
  # The unit and token singletons are matched by identity and passed through.
  if aval is core.abstract_unit or aval is abstract_token:
    return aval
  raise TypeError(type(aval))
|
2019-05-28 22:38:06 -07:00
|
|
|
|
2019-06-19 10:32:55 -07:00
|
|
|
# Add every member of array_types to core.literalable_types.
core.literalable_types.update(array_types)
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
|
|
|
|
def _zeros_like_python_scalar(x):
  """Return a 0-d NumPy zero whose dtype is the one mapped to ``type(x)``.

  The dtype is looked up in ``dtypes.python_scalar_dtypes`` by the exact type
  of ``x``.
  """
  scalar_dtype = dtypes.python_scalar_dtypes[type(x)]
  return onp.array(0, scalar_dtype)
|
|
|
|
|
|
|
|
def _make_concrete_python_scalar(x):
  """Wrap the Python scalar ``x`` in a weakly typed ``ConcreteArray``.

  ``x`` is converted to a NumPy array using the dtype that
  ``dtypes.python_scalar_dtypes`` maps to ``type(x)``, and the result is
  wrapped with ``weak_type=True``.
  """
  scalar_dtype = dtypes.python_scalar_dtypes[type(x)]
  value = onp.array(x, dtype=scalar_dtype)
  return ConcreteArray(value, weak_type=True)
|
|
|
|
|
|
|
|
# Register each Python scalar type keyed in dtypes.python_scalar_dtypes: map it
# to the weakly typed constructor and to the scalar zeros helper.
for t in dtypes.python_scalar_dtypes:
  core.pytype_aval_mappings[t] = _make_concrete_python_scalar
  ad_util.jaxval_zeros_likers[t] = _zeros_like_python_scalar

# Add the same Python scalar types to core.literalable_types.
core.literalable_types.update(dtypes.python_scalar_dtypes)
|