Add preliminary support for np.complex128.

Only lightly tested.
This commit is contained in:
Peter Hawkins 2019-01-11 18:22:43 -05:00
parent 385ab96206
commit d43c65dcd8
6 changed files with 19 additions and 20 deletions

View File

@ -162,8 +162,8 @@ def zeros_like_array(x):
return onp.broadcast_to(onp.array(0, dtype), onp.shape(x))
array_types = [onp.ndarray, onp.float64, onp.float32, onp.complex64,
onp.int64, onp.int32, onp.bool_, onp.uint64, onp.uint32,
complex, float, int, bool]
onp.complex128, onp.int64, onp.int32, onp.bool_, onp.uint64,
onp.uint32, complex, float, int, bool]
for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray

View File

@ -706,7 +706,6 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
@ -819,7 +818,8 @@ def _brcast_to(x, shape):
_f32 = {onp.float32}
_float = {onp.floating}
_complex = {onp.complex64}
_complex = {onp.complex}
_complex_elem_types = {onp.float32, onp.float64}
_int = {onp.integer}
_bool = {onp.bool_}
@ -885,16 +885,18 @@ erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, onp.sqrt(onp.pi) / 2.),
mul(g, exp(square(ans)))))
real_p = unop(_fixed_dtype(onp.float32), _complex, 'real')
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), onp.float32))])
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), _dtype(t)))])
imag_p = unop(_fixed_dtype(onp.float32), _complex, 'imag')
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), onp.float32), t)])
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), _dtype(t)), t)])
complex_p = binop(_fixed_dtype(onp.complex64), [_f32, _f32], 'complex')
_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
complex_p = binop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
'complex')
ad.deflinear(complex_p, lambda t: [real(t), imag(t)])
conj_p = unop(_fixed_dtype(onp.complex64), _float | _complex, 'conj')
conj_p = unop(_complex_dtype, _float | _complex, 'conj')
def conj_transpose_rule(t, x, input_dtype):
assert x is None

View File

@ -193,6 +193,7 @@ _etype_to_dtype = {
xla_data_pb2.F32: onp.dtype('float32'),
xla_data_pb2.F64: onp.dtype('float64'),
xla_data_pb2.C64: onp.dtype('complex64'),
xla_data_pb2.C128: onp.dtype('complex128'),
}
# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
@ -220,10 +221,6 @@ def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
dtype = onp.dtype(dtype)
# special rule for complex128, which XLA doesn't support
if dtype == onp.complex128:
dtype = onp.dtype('complex64')
if FLAGS.jax_enable_x64:
return str(dtype)
else:

View File

@ -99,10 +99,10 @@ int16 = onp.int16
int32 = onp.int32
int64 = onp.int64
float16 = onp.float16
float32 = onp.float32
float64 = onp.float64
complex64 = onp.complex64
complex128 = onp.complex128
float32 = single = onp.float32
float64 = double = onp.float64
complex64 = csingle = onp.complex64
complex128 = cdouble = onp.complex128
flexible = onp.flexible
character = onp.character

View File

@ -45,7 +45,7 @@ nonempty_shapes = scalar_shapes + nonempty_array_shapes
all_shapes = scalar_shapes + array_shapes
float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
complex_dtypes = [onp.complex64, onp.complex128]
int_dtypes = [onp.int32, onp.int64]
unsigned_dtypes = [onp.uint32, onp.uint64]
bool_dtypes = [onp.bool_]

View File

@ -52,7 +52,7 @@ def num_float_bits(dtype):
# arguments of appropriate shapes and dtypes using the following table.
float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64]
complex_dtypes = [onp.complex64, onp.complex128]
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes