mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add preliminary support for np.complex128.
Only lightly tested.
This commit is contained in:
parent
385ab96206
commit
d43c65dcd8
@ -162,8 +162,8 @@ def zeros_like_array(x):
|
||||
return onp.broadcast_to(onp.array(0, dtype), onp.shape(x))
|
||||
|
||||
array_types = [onp.ndarray, onp.float64, onp.float32, onp.complex64,
|
||||
onp.int64, onp.int32, onp.bool_, onp.uint64, onp.uint32,
|
||||
complex, float, int, bool]
|
||||
onp.complex128, onp.int64, onp.int32, onp.bool_, onp.uint64,
|
||||
onp.uint32, complex, float, int, bool]
|
||||
|
||||
for t in array_types:
|
||||
core.pytype_aval_mappings[t] = ConcreteArray
|
||||
|
18
jax/lax.py
18
jax/lax.py
@ -706,7 +706,6 @@ _input_dtype = lambda *args, **_: xla_bridge.canonicalize_dtype(args[0].dtype)
|
||||
_fixed_dtype = lambda dtype: lambda *args, **kwargs: xla_bridge.canonicalize_dtype(dtype)
|
||||
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
|
||||
|
||||
|
||||
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
||||
prim = Primitive(name)
|
||||
prim.def_impl(partial(xla.apply_primitive, prim))
|
||||
@ -819,7 +818,8 @@ def _brcast_to(x, shape):
|
||||
|
||||
_f32 = {onp.float32}
|
||||
_float = {onp.floating}
|
||||
_complex = {onp.complex64}
|
||||
_complex = {onp.complex}
|
||||
_complex_elem_types = {onp.float32, onp.float64}
|
||||
_int = {onp.integer}
|
||||
_bool = {onp.bool_}
|
||||
|
||||
@ -885,16 +885,18 @@ erf_inv_p = standard_unop(_float, 'erf_inv')
|
||||
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, onp.sqrt(onp.pi) / 2.),
|
||||
mul(g, exp(square(ans)))))
|
||||
|
||||
real_p = unop(_fixed_dtype(onp.float32), _complex, 'real')
|
||||
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), onp.float32))])
|
||||
real_p = unop(_complex_basetype, _complex, 'real')
|
||||
ad.deflinear(real_p, lambda t: [complex(t, onp.zeros((), _dtype(t)))])
|
||||
|
||||
imag_p = unop(_fixed_dtype(onp.float32), _complex, 'imag')
|
||||
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), onp.float32), t)])
|
||||
imag_p = unop(_complex_basetype, _complex, 'imag')
|
||||
ad.deflinear(imag_p, lambda t: [complex(onp.zeros((), _dtype(t)), t)])
|
||||
|
||||
complex_p = binop(_fixed_dtype(onp.complex64), [_f32, _f32], 'complex')
|
||||
_complex_dtype = lambda dtype, *args: (onp.zeros((), dtype) + onp.zeros((), onp.complex64)).dtype
|
||||
complex_p = binop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
||||
'complex')
|
||||
ad.deflinear(complex_p, lambda t: [real(t), imag(t)])
|
||||
|
||||
conj_p = unop(_fixed_dtype(onp.complex64), _float | _complex, 'conj')
|
||||
conj_p = unop(_complex_dtype, _float | _complex, 'conj')
|
||||
|
||||
def conj_transpose_rule(t, x, input_dtype):
|
||||
assert x is None
|
||||
|
@ -193,6 +193,7 @@ _etype_to_dtype = {
|
||||
xla_data_pb2.F32: onp.dtype('float32'),
|
||||
xla_data_pb2.F64: onp.dtype('float64'),
|
||||
xla_data_pb2.C64: onp.dtype('complex64'),
|
||||
xla_data_pb2.C128: onp.dtype('complex128'),
|
||||
}
|
||||
|
||||
# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
|
||||
@ -220,10 +221,6 @@ def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
|
||||
dtype = onp.dtype(dtype)
|
||||
|
||||
# special rule for complex128, which XLA doesn't support
|
||||
if dtype == onp.complex128:
|
||||
dtype = onp.dtype('complex64')
|
||||
|
||||
if FLAGS.jax_enable_x64:
|
||||
return str(dtype)
|
||||
else:
|
||||
|
@ -99,10 +99,10 @@ int16 = onp.int16
|
||||
int32 = onp.int32
|
||||
int64 = onp.int64
|
||||
float16 = onp.float16
|
||||
float32 = onp.float32
|
||||
float64 = onp.float64
|
||||
complex64 = onp.complex64
|
||||
complex128 = onp.complex128
|
||||
float32 = single = onp.float32
|
||||
float64 = double = onp.float64
|
||||
complex64 = csingle = onp.complex64
|
||||
complex128 = cdouble = onp.complex128
|
||||
|
||||
flexible = onp.flexible
|
||||
character = onp.character
|
||||
|
@ -45,7 +45,7 @@ nonempty_shapes = scalar_shapes + nonempty_array_shapes
|
||||
all_shapes = scalar_shapes + array_shapes
|
||||
|
||||
float_dtypes = [onp.float32, onp.float64]
|
||||
complex_dtypes = [onp.complex64]
|
||||
complex_dtypes = [onp.complex64, onp.complex128]
|
||||
int_dtypes = [onp.int32, onp.int64]
|
||||
unsigned_dtypes = [onp.uint32, onp.uint64]
|
||||
bool_dtypes = [onp.bool_]
|
||||
|
@ -52,7 +52,7 @@ def num_float_bits(dtype):
|
||||
# arguments of appropriate shapes and dtypes using the following table.
|
||||
|
||||
float_dtypes = [onp.float32, onp.float64]
|
||||
complex_dtypes = [onp.complex64]
|
||||
complex_dtypes = [onp.complex64, onp.complex128]
|
||||
int_dtypes = [onp.int32, onp.int64]
|
||||
bool_dtypes = [onp.bool_]
|
||||
default_dtypes = float_dtypes + int_dtypes
|
||||
|
Loading…
x
Reference in New Issue
Block a user