mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add PEP484-compatible export for jax and its subpackages
This commit is contained in:
parent
836bfb297e
commit
245581411e
127
jax/__init__.py
127
jax/__init__.py
@ -38,58 +38,65 @@ from . import config as _config_module
|
||||
del _config_module
|
||||
|
||||
from ._src.config import (
|
||||
config, enable_checks, check_tracer_leaks, checking_leaks, enable_custom_prng,
|
||||
debug_nans, debug_infs, log_compiles, default_matmul_precision,
|
||||
numpy_rank_promotion
|
||||
config as config,
|
||||
enable_checks as enable_checks,
|
||||
check_tracer_leaks as check_tracer_leaks,
|
||||
checking_leaks as checking_leaks,
|
||||
enable_custom_prng as enable_custom_prng,
|
||||
debug_nans as debug_nans,
|
||||
debug_infs as debug_infs,
|
||||
log_compiles as log_compiles,
|
||||
default_matmul_precision as default_matmul_precision,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
)
|
||||
from ._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
checkpoint,
|
||||
checkpoint_policies,
|
||||
closure_convert,
|
||||
checkpoint as checkpoint,
|
||||
checkpoint_policies as checkpoint_policies,
|
||||
closure_convert as closure_convert,
|
||||
curry, # TODO(phawkins): update users to avoid this.
|
||||
custom_ivjp,
|
||||
custom_gradient,
|
||||
custom_jvp,
|
||||
custom_vjp,
|
||||
default_backend,
|
||||
device_count,
|
||||
device_get,
|
||||
device_put,
|
||||
device_put_sharded,
|
||||
device_put_replicated,
|
||||
devices,
|
||||
disable_jit,
|
||||
eval_shape,
|
||||
custom_ivjp as custom_ivjp,
|
||||
custom_gradient as custom_gradient,
|
||||
custom_jvp as custom_jvp,
|
||||
custom_vjp as custom_vjp,
|
||||
default_backend as default_backend,
|
||||
device_count as device_count,
|
||||
device_get as device_get,
|
||||
device_put as device_put,
|
||||
device_put_sharded as device_put_sharded,
|
||||
device_put_replicated as device_put_replicated,
|
||||
devices as devices,
|
||||
disable_jit as disable_jit,
|
||||
eval_shape as eval_shape,
|
||||
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
|
||||
float0,
|
||||
grad,
|
||||
hessian,
|
||||
host_count,
|
||||
host_id,
|
||||
host_ids,
|
||||
invertible,
|
||||
jacobian,
|
||||
jacfwd,
|
||||
jacrev,
|
||||
jit,
|
||||
jvp,
|
||||
local_device_count,
|
||||
local_devices,
|
||||
linearize,
|
||||
linear_transpose,
|
||||
make_jaxpr,
|
||||
mask,
|
||||
named_call,
|
||||
float0 as float0,
|
||||
grad as grad,
|
||||
hessian as hessian,
|
||||
host_count as host_count,
|
||||
host_id as host_id,
|
||||
host_ids as host_ids,
|
||||
invertible as invertible,
|
||||
jacobian as jacobian,
|
||||
jacfwd as jacfwd,
|
||||
jacrev as jacrev,
|
||||
jit as jit,
|
||||
jvp as jvp,
|
||||
local_device_count as local_device_count,
|
||||
local_devices as local_devices,
|
||||
linearize as linearize,
|
||||
linear_transpose as linear_transpose,
|
||||
make_jaxpr as make_jaxpr,
|
||||
mask as mask,
|
||||
named_call as named_call,
|
||||
partial, # TODO(phawkins): update callers to use functools.partial.
|
||||
pmap,
|
||||
process_count,
|
||||
process_index,
|
||||
pmap as pmap,
|
||||
process_count as process_count,
|
||||
process_index as process_index,
|
||||
pxla, # TODO(phawkins): update users to avoid this.
|
||||
remat,
|
||||
shapecheck,
|
||||
ShapedArray,
|
||||
ShapeDtypeStruct,
|
||||
remat as remat,
|
||||
shapecheck as shapecheck,
|
||||
ShapedArray as ShapedArray,
|
||||
ShapeDtypeStruct as ShapeDtypeStruct,
|
||||
# TODO(phawkins): hide tree* functions from jax, update callers to use
|
||||
# jax.tree_util.
|
||||
treedef_is_leaf,
|
||||
@ -106,26 +113,26 @@ from ._src.api import (
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation,
|
||||
)
|
||||
from .experimental.maps import soft_pmap
|
||||
from .version import __version__
|
||||
from .experimental.maps import soft_pmap as soft_pmap
|
||||
from .version import __version__ as __version__
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
# jax and rely on the names imported above.
|
||||
from . import abstract_arrays
|
||||
from . import api
|
||||
from . import api_util
|
||||
from . import dtypes
|
||||
from . import errors
|
||||
from . import image
|
||||
from . import lax
|
||||
from . import nn
|
||||
from . import profiler
|
||||
from . import random
|
||||
from . import tree_util
|
||||
from . import util
|
||||
from . import abstract_arrays as abstract_arrays
|
||||
from . import api as api
|
||||
from . import api_util as api_util
|
||||
from . import dtypes as dtypes
|
||||
from . import errors as errors
|
||||
from . import image as image
|
||||
from . import lax as lax
|
||||
from . import nn as nn
|
||||
from . import profiler as profiler
|
||||
from . import random as random
|
||||
from . import tree_util as tree_util
|
||||
from . import util as util
|
||||
|
||||
def _init():
|
||||
from . import numpy # side-effecting import sets up operator overloads
|
||||
from . import numpy as numpy # side-effecting import sets up operator overloads
|
||||
|
||||
_init()
|
||||
del _init
|
||||
|
@ -17,13 +17,13 @@ from jax._src.custom_derivatives import (
|
||||
_initial_style_jaxpr,
|
||||
_sum_tangents,
|
||||
_zeros_like_pytree,
|
||||
closure_convert,
|
||||
custom_gradient,
|
||||
custom_jvp,
|
||||
custom_jvp_call_p,
|
||||
custom_jvp_call_jaxpr_p,
|
||||
custom_vjp,
|
||||
custom_vjp_call_p,
|
||||
custom_vjp_call_jaxpr_p,
|
||||
linear_call,
|
||||
closure_convert as closure_convert,
|
||||
custom_gradient as custom_gradient,
|
||||
custom_jvp as custom_jvp,
|
||||
custom_jvp_call_p as custom_jvp_call_p,
|
||||
custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p,
|
||||
custom_vjp as custom_vjp,
|
||||
custom_vjp_call_p as custom_vjp_call_p,
|
||||
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
|
||||
linear_call as linear_call,
|
||||
)
|
||||
|
@ -15,12 +15,12 @@
|
||||
# flake8: noqa: F401
|
||||
from jax._src.dtypes import (
|
||||
_jax_types, # TODO(phawkins): fix users and remove?
|
||||
bfloat16,
|
||||
canonicalize_dtype,
|
||||
bfloat16 as bfloat16,
|
||||
canonicalize_dtype as canonicalize_dtype,
|
||||
finfo, # TODO(phawkins): switch callers to jnp.finfo?
|
||||
float0,
|
||||
float0 as float0,
|
||||
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
|
||||
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
|
||||
result_type,
|
||||
scalar_type_of,
|
||||
result_type as result_type,
|
||||
scalar_type_of as scalar_type_of,
|
||||
)
|
||||
|
@ -13,10 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from ._src.errors import (JAXTypeError,
|
||||
JAXIndexError,
|
||||
ConcretizationTypeError,
|
||||
NonConcreteBooleanIndexError,
|
||||
TracerArrayConversionError,
|
||||
TracerIntegerConversionError,
|
||||
UnexpectedTracerError)
|
||||
from ._src.errors import (
|
||||
JAXTypeError as JAXTypeError,
|
||||
JAXIndexError as JAXIndexError,
|
||||
ConcretizationTypeError as ConcretizationTypeError,
|
||||
NonConcreteBooleanIndexError as NonConcreteBooleanIndexError,
|
||||
TracerArrayConversionError as TracerArrayConversionError,
|
||||
TracerIntegerConversionError as TracerIntegerConversionError,
|
||||
UnexpectedTracerError as UnexpectedTracerError,
|
||||
)
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.image.scale import (
|
||||
resize,
|
||||
ResizeMethod,
|
||||
scale_and_translate,
|
||||
resize as resize,
|
||||
ResizeMethod as ResizeMethod,
|
||||
scale_and_translate as scale_and_translate,
|
||||
)
|
||||
|
@ -14,287 +14,287 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.lax.lax import (
|
||||
ConvDimensionNumbers,
|
||||
ConvGeneralDilatedDimensionNumbers,
|
||||
DotDimensionNumbers,
|
||||
GatherDimensionNumbers,
|
||||
Precision,
|
||||
RandomAlgorithm,
|
||||
RoundingMethod,
|
||||
ScatterDimensionNumbers,
|
||||
abs,
|
||||
abs_p,
|
||||
acos,
|
||||
acos_p,
|
||||
acosh,
|
||||
acosh_p,
|
||||
abs,
|
||||
abs_p,
|
||||
acos,
|
||||
acosh,
|
||||
acosh_p,
|
||||
add,
|
||||
add_p,
|
||||
after_all,
|
||||
after_all_p,
|
||||
and_p,
|
||||
argmax,
|
||||
argmax_p,
|
||||
argmin,
|
||||
argmin_p,
|
||||
asin,
|
||||
asin_p,
|
||||
asinh,
|
||||
asinh_p,
|
||||
atan,
|
||||
atan_p,
|
||||
atan2,
|
||||
atan2_p,
|
||||
atanh,
|
||||
atanh_p,
|
||||
batch_matmul,
|
||||
bessel_i0e,
|
||||
bessel_i0e_p,
|
||||
bessel_i1e,
|
||||
bessel_i1e_p,
|
||||
betainc,
|
||||
bitcast_convert_type,
|
||||
bitcast_convert_type_p,
|
||||
bitwise_and,
|
||||
bitwise_not,
|
||||
bitwise_or,
|
||||
bitwise_xor,
|
||||
broadcast,
|
||||
broadcast_in_dim,
|
||||
broadcast_in_dim_p,
|
||||
broadcast_shapes,
|
||||
broadcast_to_rank,
|
||||
broadcasted_iota,
|
||||
cbrt,
|
||||
cbrt_p,
|
||||
ceil,
|
||||
ceil_p,
|
||||
clamp,
|
||||
clamp_p,
|
||||
clz,
|
||||
clz_p,
|
||||
collapse,
|
||||
complex,
|
||||
complex_p,
|
||||
concatenate,
|
||||
concatenate_p,
|
||||
conj,
|
||||
conj_p,
|
||||
conv,
|
||||
conv_dimension_numbers,
|
||||
conv_general_dilated,
|
||||
conv_general_dilated_p,
|
||||
conv_general_permutations,
|
||||
conv_general_shape_tuple,
|
||||
conv_shape_tuple,
|
||||
conv_transpose,
|
||||
conv_transpose_shape_tuple,
|
||||
conv_with_general_padding,
|
||||
convert_element_type,
|
||||
_convert_element_type,
|
||||
convert_element_type_p,
|
||||
cos,
|
||||
cos_p,
|
||||
cosh,
|
||||
cosh_p,
|
||||
create_token,
|
||||
create_token_p,
|
||||
digamma,
|
||||
digamma_p,
|
||||
div,
|
||||
div_p,
|
||||
dot,
|
||||
dot_general,
|
||||
dot_general_p,
|
||||
dtype,
|
||||
dtypes,
|
||||
dynamic_index_in_dim,
|
||||
dynamic_slice,
|
||||
dynamic_slice_in_dim,
|
||||
dynamic_slice_p,
|
||||
dynamic_update_index_in_dim,
|
||||
dynamic_update_slice,
|
||||
dynamic_update_slice_in_dim,
|
||||
dynamic_update_slice_p,
|
||||
eq,
|
||||
eq_p,
|
||||
erf,
|
||||
erf_inv,
|
||||
erf_inv_p,
|
||||
erf_p,
|
||||
erfc,
|
||||
erfc_p,
|
||||
exp,
|
||||
exp_p,
|
||||
expand_dims,
|
||||
expm1,
|
||||
expm1_p,
|
||||
floor,
|
||||
floor_p,
|
||||
full,
|
||||
full_like,
|
||||
gather,
|
||||
gather_p,
|
||||
ge,
|
||||
ge_p,
|
||||
gt,
|
||||
gt_p,
|
||||
igamma,
|
||||
igamma_grad_a,
|
||||
igamma_grad_a_p,
|
||||
igamma_p,
|
||||
igammac,
|
||||
igammac_p,
|
||||
imag,
|
||||
imag_p,
|
||||
index_in_dim,
|
||||
index_take,
|
||||
infeed,
|
||||
infeed_p,
|
||||
integer_pow,
|
||||
integer_pow_p,
|
||||
iota,
|
||||
iota_p,
|
||||
is_finite,
|
||||
is_finite_p,
|
||||
itertools,
|
||||
le,
|
||||
le_p,
|
||||
lgamma,
|
||||
lgamma_p,
|
||||
log,
|
||||
log1p,
|
||||
log1p_p,
|
||||
log_p,
|
||||
lt,
|
||||
lt_p,
|
||||
max,
|
||||
max_p,
|
||||
min,
|
||||
min_p,
|
||||
mul,
|
||||
mul_p,
|
||||
naryop,
|
||||
naryop_dtype_rule,
|
||||
ne,
|
||||
ne_p,
|
||||
neg,
|
||||
neg_p,
|
||||
nextafter,
|
||||
nextafter_p,
|
||||
not_p,
|
||||
or_p,
|
||||
outfeed,
|
||||
outfeed_p,
|
||||
pad,
|
||||
pad_p,
|
||||
padtype_to_pads,
|
||||
partial,
|
||||
population_count,
|
||||
population_count_p,
|
||||
pow,
|
||||
pow_p,
|
||||
prod,
|
||||
random_gamma_grad,
|
||||
random_gamma_grad_p,
|
||||
real,
|
||||
real_p,
|
||||
reciprocal,
|
||||
reduce,
|
||||
reduce_and_p,
|
||||
reduce_max_p,
|
||||
reduce_min_p,
|
||||
reduce_or_p,
|
||||
reduce_p,
|
||||
reduce_precision,
|
||||
reduce_precision_p,
|
||||
reduce_prod_p,
|
||||
reduce_sum_p,
|
||||
reduce_window,
|
||||
reduce_window_max_p,
|
||||
reduce_window_min_p,
|
||||
reduce_window_p,
|
||||
reduce_window_shape_tuple,
|
||||
reduce_window_sum_p,
|
||||
regularized_incomplete_beta_p,
|
||||
rem,
|
||||
rem_p,
|
||||
reshape,
|
||||
reshape_p,
|
||||
rev,
|
||||
rev_p,
|
||||
rng_bit_generator,
|
||||
rng_bit_generator_p,
|
||||
rng_uniform,
|
||||
rng_uniform_p,
|
||||
round,
|
||||
round_p,
|
||||
rsqrt,
|
||||
rsqrt_p,
|
||||
scatter,
|
||||
scatter_add,
|
||||
scatter_add_p,
|
||||
scatter_max,
|
||||
scatter_max_p,
|
||||
scatter_min,
|
||||
scatter_min_p,
|
||||
scatter_mul,
|
||||
scatter_mul_p,
|
||||
scatter_p,
|
||||
select,
|
||||
select_and_gather_add_p,
|
||||
select_and_scatter_add_p,
|
||||
select_and_scatter_p,
|
||||
select_p,
|
||||
shift_left,
|
||||
shift_left_p,
|
||||
shift_right_arithmetic,
|
||||
shift_right_arithmetic_p,
|
||||
shift_right_logical,
|
||||
shift_right_logical_p,
|
||||
sign,
|
||||
sign_p,
|
||||
sin,
|
||||
sin_p,
|
||||
sinh,
|
||||
sinh_p,
|
||||
slice,
|
||||
slice_in_dim,
|
||||
slice_p,
|
||||
sort,
|
||||
sort_key_val,
|
||||
sort_p,
|
||||
sqrt,
|
||||
sqrt_p,
|
||||
square,
|
||||
squeeze,
|
||||
squeeze_p,
|
||||
standard_abstract_eval,
|
||||
standard_naryop,
|
||||
standard_primitive,
|
||||
standard_translate,
|
||||
standard_unop,
|
||||
stop_gradient,
|
||||
sub,
|
||||
sub_p,
|
||||
tan,
|
||||
tan_p,
|
||||
tanh,
|
||||
tanh_p,
|
||||
tie_in,
|
||||
top_k,
|
||||
top_k_p,
|
||||
transpose,
|
||||
transpose_p,
|
||||
unop,
|
||||
unop_dtype_rule,
|
||||
xor_p,
|
||||
zeros_like_array,
|
||||
ConvDimensionNumbers as ConvDimensionNumbers,
|
||||
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
|
||||
DotDimensionNumbers as DotDimensionNumbers,
|
||||
GatherDimensionNumbers as GatherDimensionNumbers,
|
||||
Precision as Precision,
|
||||
RandomAlgorithm as RandomAlgorithm,
|
||||
RoundingMethod as RoundingMethod,
|
||||
ScatterDimensionNumbers as ScatterDimensionNumbers,
|
||||
abs as abs,
|
||||
abs_p as abs_p,
|
||||
acos as acos,
|
||||
acos_p as acos_p,
|
||||
acosh as acosh,
|
||||
acosh_p as acosh_p,
|
||||
abs as abs,
|
||||
abs_p as abs_p,
|
||||
acos as acos,
|
||||
acosh as acosh,
|
||||
acosh_p as acosh_p,
|
||||
add as add,
|
||||
add_p as add_p,
|
||||
after_all as after_all,
|
||||
after_all_p as after_all_p,
|
||||
and_p as and_p,
|
||||
argmax as argmax,
|
||||
argmax_p as argmax_p,
|
||||
argmin as argmin,
|
||||
argmin_p as argmin_p,
|
||||
asin as asin,
|
||||
asin_p as asin_p,
|
||||
asinh as asinh,
|
||||
asinh_p as asinh_p,
|
||||
atan as atan,
|
||||
atan_p as atan_p,
|
||||
atan2 as atan2,
|
||||
atan2_p as atan2_p,
|
||||
atanh as atanh,
|
||||
atanh_p as atanh_p,
|
||||
batch_matmul as batch_matmul,
|
||||
bessel_i0e as bessel_i0e,
|
||||
bessel_i0e_p as bessel_i0e_p,
|
||||
bessel_i1e as bessel_i1e,
|
||||
bessel_i1e_p as bessel_i1e_p,
|
||||
betainc as betainc,
|
||||
bitcast_convert_type as bitcast_convert_type,
|
||||
bitcast_convert_type_p as bitcast_convert_type_p,
|
||||
bitwise_and as bitwise_and,
|
||||
bitwise_not as bitwise_not,
|
||||
bitwise_or as bitwise_or,
|
||||
bitwise_xor as bitwise_xor,
|
||||
broadcast as broadcast,
|
||||
broadcast_in_dim as broadcast_in_dim,
|
||||
broadcast_in_dim_p as broadcast_in_dim_p,
|
||||
broadcast_shapes as broadcast_shapes,
|
||||
broadcast_to_rank as broadcast_to_rank,
|
||||
broadcasted_iota as broadcasted_iota,
|
||||
cbrt as cbrt,
|
||||
cbrt_p as cbrt_p,
|
||||
ceil as ceil,
|
||||
ceil_p as ceil_p,
|
||||
clamp as clamp,
|
||||
clamp_p as clamp_p,
|
||||
clz as clz,
|
||||
clz_p as clz_p,
|
||||
collapse as collapse,
|
||||
complex as complex,
|
||||
complex_p as complex_p,
|
||||
concatenate as concatenate,
|
||||
concatenate_p as concatenate_p,
|
||||
conj as conj,
|
||||
conj_p as conj_p,
|
||||
conv as conv,
|
||||
conv_dimension_numbers as conv_dimension_numbers,
|
||||
conv_general_dilated as conv_general_dilated,
|
||||
conv_general_dilated_p as conv_general_dilated_p,
|
||||
conv_general_permutations as conv_general_permutations,
|
||||
conv_general_shape_tuple as conv_general_shape_tuple,
|
||||
conv_shape_tuple as conv_shape_tuple,
|
||||
conv_transpose as conv_transpose,
|
||||
conv_transpose_shape_tuple as conv_transpose_shape_tuple,
|
||||
conv_with_general_padding as conv_with_general_padding,
|
||||
convert_element_type as convert_element_type,
|
||||
_convert_element_type as _convert_element_type,
|
||||
convert_element_type_p as convert_element_type_p,
|
||||
cos as cos,
|
||||
cos_p as cos_p,
|
||||
cosh as cosh,
|
||||
cosh_p as cosh_p,
|
||||
create_token as create_token,
|
||||
create_token_p as create_token_p,
|
||||
digamma as digamma,
|
||||
digamma_p as digamma_p,
|
||||
div as div,
|
||||
div_p as div_p,
|
||||
dot as dot,
|
||||
dot_general as dot_general,
|
||||
dot_general_p as dot_general_p,
|
||||
dtype as dtype,
|
||||
dtypes as dtypes,
|
||||
dynamic_index_in_dim as dynamic_index_in_dim,
|
||||
dynamic_slice as dynamic_slice,
|
||||
dynamic_slice_in_dim as dynamic_slice_in_dim,
|
||||
dynamic_slice_p as dynamic_slice_p,
|
||||
dynamic_update_index_in_dim as dynamic_update_index_in_dim,
|
||||
dynamic_update_slice as dynamic_update_slice,
|
||||
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
|
||||
dynamic_update_slice_p as dynamic_update_slice_p,
|
||||
eq as eq,
|
||||
eq_p as eq_p,
|
||||
erf as erf,
|
||||
erf_inv as erf_inv,
|
||||
erf_inv_p as erf_inv_p,
|
||||
erf_p as erf_p,
|
||||
erfc as erfc,
|
||||
erfc_p as erfc_p,
|
||||
exp as exp,
|
||||
exp_p as exp_p,
|
||||
expand_dims as expand_dims,
|
||||
expm1 as expm1,
|
||||
expm1_p as expm1_p,
|
||||
floor as floor,
|
||||
floor_p as floor_p,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
gather as gather,
|
||||
gather_p as gather_p,
|
||||
ge as ge,
|
||||
ge_p as ge_p,
|
||||
gt as gt,
|
||||
gt_p as gt_p,
|
||||
igamma as igamma,
|
||||
igamma_grad_a as igamma_grad_a,
|
||||
igamma_grad_a_p as igamma_grad_a_p,
|
||||
igamma_p as igamma_p,
|
||||
igammac as igammac,
|
||||
igammac_p as igammac_p,
|
||||
imag as imag,
|
||||
imag_p as imag_p,
|
||||
index_in_dim as index_in_dim,
|
||||
index_take as index_take,
|
||||
infeed as infeed,
|
||||
infeed_p as infeed_p,
|
||||
integer_pow as integer_pow,
|
||||
integer_pow_p as integer_pow_p,
|
||||
iota as iota,
|
||||
iota_p as iota_p,
|
||||
is_finite as is_finite,
|
||||
is_finite_p as is_finite_p,
|
||||
itertools as itertools,
|
||||
le as le,
|
||||
le_p as le_p,
|
||||
lgamma as lgamma,
|
||||
lgamma_p as lgamma_p,
|
||||
log as log,
|
||||
log1p as log1p,
|
||||
log1p_p as log1p_p,
|
||||
log_p as log_p,
|
||||
lt as lt,
|
||||
lt_p as lt_p,
|
||||
max as max,
|
||||
max_p as max_p,
|
||||
min as min,
|
||||
min_p as min_p,
|
||||
mul as mul,
|
||||
mul_p as mul_p,
|
||||
naryop as naryop,
|
||||
naryop_dtype_rule as naryop_dtype_rule,
|
||||
ne as ne,
|
||||
ne_p as ne_p,
|
||||
neg as neg,
|
||||
neg_p as neg_p,
|
||||
nextafter as nextafter,
|
||||
nextafter_p as nextafter_p,
|
||||
not_p as not_p,
|
||||
or_p as or_p,
|
||||
outfeed as outfeed,
|
||||
outfeed_p as outfeed_p,
|
||||
pad as pad,
|
||||
pad_p as pad_p,
|
||||
padtype_to_pads as padtype_to_pads,
|
||||
partial as partial,
|
||||
population_count as population_count,
|
||||
population_count_p as population_count_p,
|
||||
pow as pow,
|
||||
pow_p as pow_p,
|
||||
prod as prod,
|
||||
random_gamma_grad as random_gamma_grad,
|
||||
random_gamma_grad_p as random_gamma_grad_p,
|
||||
real as real,
|
||||
real_p as real_p,
|
||||
reciprocal as reciprocal,
|
||||
reduce as reduce,
|
||||
reduce_and_p as reduce_and_p,
|
||||
reduce_max_p as reduce_max_p,
|
||||
reduce_min_p as reduce_min_p,
|
||||
reduce_or_p as reduce_or_p,
|
||||
reduce_p as reduce_p,
|
||||
reduce_precision as reduce_precision,
|
||||
reduce_precision_p as reduce_precision_p,
|
||||
reduce_prod_p as reduce_prod_p,
|
||||
reduce_sum_p as reduce_sum_p,
|
||||
reduce_window as reduce_window,
|
||||
reduce_window_max_p as reduce_window_max_p,
|
||||
reduce_window_min_p as reduce_window_min_p,
|
||||
reduce_window_p as reduce_window_p,
|
||||
reduce_window_shape_tuple as reduce_window_shape_tuple,
|
||||
reduce_window_sum_p as reduce_window_sum_p,
|
||||
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
||||
rem as rem,
|
||||
rem_p as rem_p,
|
||||
reshape as reshape,
|
||||
reshape_p as reshape_p,
|
||||
rev as rev,
|
||||
rev_p as rev_p,
|
||||
rng_bit_generator as rng_bit_generator,
|
||||
rng_bit_generator_p as rng_bit_generator_p,
|
||||
rng_uniform as rng_uniform,
|
||||
rng_uniform_p as rng_uniform_p,
|
||||
round as round,
|
||||
round_p as round_p,
|
||||
rsqrt as rsqrt,
|
||||
rsqrt_p as rsqrt_p,
|
||||
scatter as scatter,
|
||||
scatter_add as scatter_add,
|
||||
scatter_add_p as scatter_add_p,
|
||||
scatter_max as scatter_max,
|
||||
scatter_max_p as scatter_max_p,
|
||||
scatter_min as scatter_min,
|
||||
scatter_min_p as scatter_min_p,
|
||||
scatter_mul as scatter_mul,
|
||||
scatter_mul_p as scatter_mul_p,
|
||||
scatter_p as scatter_p,
|
||||
select as select,
|
||||
select_and_gather_add_p as select_and_gather_add_p,
|
||||
select_and_scatter_add_p as select_and_scatter_add_p,
|
||||
select_and_scatter_p as select_and_scatter_p,
|
||||
select_p as select_p,
|
||||
shift_left as shift_left,
|
||||
shift_left_p as shift_left_p,
|
||||
shift_right_arithmetic as shift_right_arithmetic,
|
||||
shift_right_arithmetic_p as shift_right_arithmetic_p,
|
||||
shift_right_logical as shift_right_logical,
|
||||
shift_right_logical_p as shift_right_logical_p,
|
||||
sign as sign,
|
||||
sign_p as sign_p,
|
||||
sin as sin,
|
||||
sin_p as sin_p,
|
||||
sinh as sinh,
|
||||
sinh_p as sinh_p,
|
||||
slice as slice,
|
||||
slice_in_dim as slice_in_dim,
|
||||
slice_p as slice_p,
|
||||
sort as sort,
|
||||
sort_key_val as sort_key_val,
|
||||
sort_p as sort_p,
|
||||
sqrt as sqrt,
|
||||
sqrt_p as sqrt_p,
|
||||
square as square,
|
||||
squeeze as squeeze,
|
||||
squeeze_p as squeeze_p,
|
||||
standard_abstract_eval as standard_abstract_eval,
|
||||
standard_naryop as standard_naryop,
|
||||
standard_primitive as standard_primitive,
|
||||
standard_translate as standard_translate,
|
||||
standard_unop as standard_unop,
|
||||
stop_gradient as stop_gradient,
|
||||
sub as sub,
|
||||
sub_p as sub_p,
|
||||
tan as tan,
|
||||
tan_p as tan_p,
|
||||
tanh as tanh,
|
||||
tanh_p as tanh_p,
|
||||
tie_in as tie_in,
|
||||
top_k as top_k,
|
||||
top_k_p as top_k_p,
|
||||
transpose as transpose,
|
||||
transpose_p as transpose_p,
|
||||
unop as unop,
|
||||
unop_dtype_rule as unop_dtype_rule,
|
||||
xor_p as xor_p,
|
||||
zeros_like_array as zeros_like_array,
|
||||
)
|
||||
from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
|
||||
_reduce_and, _reduce_window_sum, _reduce_window_max,
|
||||
@ -306,56 +306,56 @@ from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
|
||||
_upcast_fp16_for_computation, _broadcasting_shape_rule,
|
||||
_eye, _tri, _delta, _ones, _zeros, _dilate_shape)
|
||||
from jax._src.lax.control_flow import (
|
||||
associative_scan,
|
||||
cond,
|
||||
cond_p,
|
||||
cummax,
|
||||
cummax_p,
|
||||
cummin,
|
||||
cummin_p,
|
||||
cumprod,
|
||||
cumprod_p,
|
||||
cumsum,
|
||||
cumsum_p,
|
||||
custom_linear_solve,
|
||||
custom_root,
|
||||
fori_loop,
|
||||
linear_solve_p,
|
||||
map,
|
||||
scan,
|
||||
scan_bind,
|
||||
scan_p,
|
||||
switch,
|
||||
while_loop,
|
||||
while_p,
|
||||
associative_scan as associative_scan,
|
||||
cond as cond,
|
||||
cond_p as cond_p,
|
||||
cummax as cummax,
|
||||
cummax_p as cummax_p,
|
||||
cummin as cummin,
|
||||
cummin_p as cummin_p,
|
||||
cumprod as cumprod,
|
||||
cumprod_p as cumprod_p,
|
||||
cumsum as cumsum,
|
||||
cumsum_p as cumsum_p,
|
||||
custom_linear_solve as custom_linear_solve,
|
||||
custom_root as custom_root,
|
||||
fori_loop as fori_loop,
|
||||
linear_solve_p as linear_solve_p,
|
||||
map as map,
|
||||
scan as scan,
|
||||
scan_bind as scan_bind,
|
||||
scan_p as scan_p,
|
||||
switch as switch,
|
||||
while_loop as while_loop,
|
||||
while_p as while_p,
|
||||
)
|
||||
from jax._src.lax.fft import (
|
||||
fft,
|
||||
fft_p,
|
||||
fft as fft,
|
||||
fft_p as fft_p,
|
||||
)
|
||||
from jax._src.lax.parallel import (
|
||||
all_gather,
|
||||
all_to_all,
|
||||
all_to_all_p,
|
||||
axis_index,
|
||||
axis_index_p,
|
||||
pmax,
|
||||
pmax_p,
|
||||
pmean,
|
||||
pmin,
|
||||
pmin_p,
|
||||
ppermute,
|
||||
ppermute_p,
|
||||
pshuffle,
|
||||
psum,
|
||||
psum_p,
|
||||
psum_scatter,
|
||||
pswapaxes,
|
||||
pdot,
|
||||
xeinsum,
|
||||
all_gather as all_gather,
|
||||
all_to_all as all_to_all,
|
||||
all_to_all_p as all_to_all_p,
|
||||
axis_index as axis_index,
|
||||
axis_index_p as axis_index_p,
|
||||
pmax as pmax,
|
||||
pmax_p as pmax_p,
|
||||
pmean as pmean,
|
||||
pmin as pmin,
|
||||
pmin_p as pmin_p,
|
||||
ppermute as ppermute,
|
||||
ppermute_p as ppermute_p,
|
||||
pshuffle as pshuffle,
|
||||
psum as psum,
|
||||
psum_p as psum_p,
|
||||
psum_scatter as psum_scatter,
|
||||
pswapaxes as pswapaxes,
|
||||
pdot as pdot,
|
||||
xeinsum as xeinsum,
|
||||
)
|
||||
from jax._src.lax.other import (
|
||||
conv_general_dilated_patches
|
||||
conv_general_dilated_patches as conv_general_dilated_patches
|
||||
)
|
||||
from jax._src.ad_util import stop_gradient_p
|
||||
from . import linalg
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
from . import linalg as linalg
|
||||
|
@ -22,7 +22,7 @@ from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
|
||||
'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
|
||||
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
|
||||
]
|
||||
|
||||
# First, before attempting to import jaxlib, warn about experimental machine
|
||||
|
@ -19,27 +19,27 @@
|
||||
from jax.numpy import tanh
|
||||
from . import initializers
|
||||
from jax._src.nn.functions import (
|
||||
celu,
|
||||
elu,
|
||||
gelu,
|
||||
glu,
|
||||
hard_sigmoid,
|
||||
hard_silu,
|
||||
hard_swish,
|
||||
hard_tanh,
|
||||
leaky_relu,
|
||||
log_sigmoid,
|
||||
log_softmax,
|
||||
logsumexp,
|
||||
normalize,
|
||||
one_hot,
|
||||
relu,
|
||||
relu6,
|
||||
selu,
|
||||
sigmoid,
|
||||
soft_sign,
|
||||
softmax,
|
||||
softplus,
|
||||
silu,
|
||||
swish,
|
||||
celu as celu,
|
||||
elu as elu,
|
||||
gelu as gelu,
|
||||
glu as glu,
|
||||
hard_sigmoid as hard_sigmoid,
|
||||
hard_silu as hard_silu,
|
||||
hard_swish as hard_swish,
|
||||
hard_tanh as hard_tanh,
|
||||
leaky_relu as leaky_relu,
|
||||
log_sigmoid as log_sigmoid,
|
||||
log_softmax as log_softmax,
|
||||
logsumexp as logsumexp,
|
||||
normalize as normalize,
|
||||
one_hot as one_hot,
|
||||
relu as relu,
|
||||
relu6 as relu6,
|
||||
selu as selu,
|
||||
sigmoid as sigmoid,
|
||||
soft_sign as soft_sign,
|
||||
softmax as softmax,
|
||||
softplus as softplus,
|
||||
silu as silu,
|
||||
swish as swish,
|
||||
)
|
||||
|
@ -19,21 +19,21 @@ used in Keras and Sonnet.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.nn.initializers import (
|
||||
delta_orthogonal,
|
||||
glorot_normal,
|
||||
glorot_uniform,
|
||||
he_normal,
|
||||
he_uniform,
|
||||
kaiming_normal,
|
||||
kaiming_uniform,
|
||||
lecun_normal,
|
||||
lecun_uniform,
|
||||
normal,
|
||||
ones,
|
||||
orthogonal,
|
||||
uniform,
|
||||
variance_scaling,
|
||||
xavier_normal,
|
||||
xavier_uniform,
|
||||
zeros,
|
||||
delta_orthogonal as delta_orthogonal,
|
||||
glorot_normal as glorot_normal,
|
||||
glorot_uniform as glorot_uniform,
|
||||
he_normal as he_normal,
|
||||
he_uniform as he_uniform,
|
||||
kaiming_normal as kaiming_normal,
|
||||
kaiming_uniform as kaiming_uniform,
|
||||
lecun_normal as lecun_normal,
|
||||
lecun_uniform as lecun_uniform,
|
||||
normal as normal,
|
||||
ones as ones,
|
||||
orthogonal as orthogonal,
|
||||
uniform as uniform,
|
||||
variance_scaling as variance_scaling,
|
||||
xavier_normal as xavier_normal,
|
||||
xavier_uniform as xavier_uniform,
|
||||
zeros as zeros,
|
||||
)
|
||||
|
@ -12,59 +12,379 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import fft
|
||||
from . import linalg
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax.interpreters.xla import DeviceArray
|
||||
# flake8: noqa: F401
|
||||
from . import fft as fft
|
||||
from . import linalg as linalg
|
||||
|
||||
from jax.interpreters.xla import DeviceArray as DeviceArray
|
||||
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
ComplexWarning, NINF, NZERO, PZERO, abs, absolute, add, all, allclose,
|
||||
alltrue, amax, amin, angle, any, append,
|
||||
apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
|
||||
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
|
||||
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
|
||||
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
|
||||
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays, broadcast_shapes,
|
||||
broadcast_to, c_, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
|
||||
complex128, complex64, complex_, complexfloating, compress, concatenate,
|
||||
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,
|
||||
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, degrees,
|
||||
delete, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
|
||||
double, dsplit, dstack, dtype, e, ediff1d, einsum, einsum_path, empty,
|
||||
empty_like, equal, euler_gamma, exp, exp2, expand_dims, expm1, extract, eye,
|
||||
fabs, finfo, fix, flatnonzero, flexible, flip, fliplr, flipud, float16, float32,
|
||||
float64, float_, float_power, floating, floor, floor_divide, fmax, fmin,
|
||||
fmod, frexp, full, full_like, gcd, geomspace, gradient, greater,
|
||||
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges, histogram2d, histogramdd,
|
||||
hsplit, hstack, hypot, i0, identity, iinfo, imag, index_exp,
|
||||
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
|
||||
interp, intersect1d, invert,
|
||||
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
|
||||
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
|
||||
ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, lexsort, linspace,
|
||||
load, log, log10, log1p, log2, logaddexp, logaddexp2, logical_and,
|
||||
logical_not, logical_or, logical_xor, logspace, mask_indices, matmul, max,
|
||||
maximum, mean, median, meshgrid, mgrid, min, minimum, mod, modf, moveaxis, msort,
|
||||
multiply, nan, nan_to_num, nanargmax, nanargmin, nancumprod, nancumsum,
|
||||
nanmedian, nanpercentile, nanquantile,
|
||||
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
|
||||
negative, newaxis, nextafter, nonzero, not_equal, number,
|
||||
object_, ogrid, ones, ones_like, outer, packbits, pad, percentile,
|
||||
pi, piecewise, poly, polyadd, polyder, polyfit, polyint, polymul, polysub, polyval, positive, power,
|
||||
prod, product, promote_types, ptp, quantile,
|
||||
r_, rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
|
||||
resize, result_type, right_shift, rint, roll, rollaxis, rot90, round, round_, row_stack,
|
||||
save, savez, searchsorted, select, set_printoptions, setdiff1d, setxor1d, shape, sign, signbit,
|
||||
s_, signedinteger, sin, sinc, single, sinh, size, sometrue, sort, sort_complex, split, sqrt,
|
||||
square, squeeze, stack, std, subtract, sum, swapaxes, take, take_along_axis,
|
||||
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
|
||||
trim_zeros, triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
|
||||
union1d, unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
|
||||
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
|
||||
ComplexWarning as ComplexWarning,
|
||||
NINF as NINF,
|
||||
NZERO as NZERO,
|
||||
PZERO as PZERO,
|
||||
abs as abs,
|
||||
absolute as absolute,
|
||||
add as add,
|
||||
all as all,
|
||||
allclose as allclose,
|
||||
alltrue as alltrue,
|
||||
amax as amax,
|
||||
amin as amin,
|
||||
angle as angle,
|
||||
any as any,
|
||||
append as append,
|
||||
apply_along_axis as apply_along_axis,
|
||||
apply_over_axes as apply_over_axes,
|
||||
arange as arange,
|
||||
arccos as arccos,
|
||||
arccosh as arccosh,
|
||||
arcsin as arcsin,
|
||||
arcsinh as arcsinh,
|
||||
arctan as arctan,
|
||||
arctan2 as arctan2,
|
||||
arctanh as arctanh,
|
||||
argmax as argmax,
|
||||
argmin as argmin,
|
||||
argsort as argsort,
|
||||
argwhere as argwhere,
|
||||
around as around,
|
||||
array as array,
|
||||
array_equal as array_equal,
|
||||
array_equiv as array_equiv,
|
||||
array_repr as array_repr,
|
||||
array_split as array_split,
|
||||
array_str as array_str,
|
||||
asarray as asarray,
|
||||
atleast_1d as atleast_1d,
|
||||
atleast_2d as atleast_2d,
|
||||
atleast_3d as atleast_3d,
|
||||
average as average,
|
||||
bartlett as bartlett,
|
||||
bfloat16 as bfloat16,
|
||||
bincount as bincount,
|
||||
bitwise_and as bitwise_and,
|
||||
bitwise_not as bitwise_not,
|
||||
bitwise_or as bitwise_or,
|
||||
bitwise_xor as bitwise_xor,
|
||||
blackman as blackman,
|
||||
block as block,
|
||||
bool_ as bool_,
|
||||
broadcast_arrays as broadcast_arrays,
|
||||
broadcast_shapes as broadcast_shapes,
|
||||
broadcast_to as broadcast_to,
|
||||
c_ as c_,
|
||||
can_cast as can_cast,
|
||||
cbrt as cbrt,
|
||||
cdouble as cdouble,
|
||||
ceil as ceil,
|
||||
character as character,
|
||||
choose as choose,
|
||||
clip as clip,
|
||||
column_stack as column_stack,
|
||||
complex128 as complex128,
|
||||
complex64 as complex64,
|
||||
complex_ as complex_,
|
||||
complexfloating as complexfloating,
|
||||
compress as compress,
|
||||
concatenate as concatenate,
|
||||
conj as conj,
|
||||
conjugate as conjugate,
|
||||
convolve as convolve,
|
||||
copysign as copysign,
|
||||
corrcoef as corrcoef,
|
||||
correlate as correlate,
|
||||
cos as cos,
|
||||
cosh as cosh,
|
||||
count_nonzero as count_nonzero,
|
||||
cov as cov,
|
||||
cross as cross,
|
||||
csingle as csingle,
|
||||
cumprod as cumprod,
|
||||
cumproduct as cumproduct,
|
||||
cumsum as cumsum,
|
||||
deg2rad as deg2rad,
|
||||
degrees as degrees,
|
||||
delete as delete,
|
||||
diag as diag,
|
||||
diagflat as diagflat,
|
||||
diag_indices as diag_indices,
|
||||
diag_indices_from as diag_indices_from,
|
||||
diagonal as diagonal,
|
||||
diff as diff,
|
||||
digitize as digitize,
|
||||
divide as divide,
|
||||
divmod as divmod,
|
||||
dot as dot,
|
||||
double as double,
|
||||
dsplit as dsplit,
|
||||
dstack as dstack,
|
||||
dtype as dtype,
|
||||
e as e,
|
||||
ediff1d as ediff1d,
|
||||
einsum as einsum,
|
||||
einsum_path as einsum_path,
|
||||
empty as empty,
|
||||
empty_like as empty_like,
|
||||
equal as equal,
|
||||
euler_gamma as euler_gamma,
|
||||
exp as exp,
|
||||
exp2 as exp2,
|
||||
expand_dims as expand_dims,
|
||||
expm1 as expm1,
|
||||
extract as extract,
|
||||
eye as eye,
|
||||
fabs as fabs,
|
||||
finfo as finfo,
|
||||
fix as fix,
|
||||
flatnonzero as flatnonzero,
|
||||
flexible as flexible,
|
||||
flip as flip,
|
||||
fliplr as fliplr,
|
||||
flipud as flipud,
|
||||
float16 as float16,
|
||||
float32 as float32,
|
||||
float64 as float64,
|
||||
float_ as float_,
|
||||
float_power as float_power,
|
||||
floating as floating,
|
||||
floor as floor,
|
||||
floor_divide as floor_divide,
|
||||
fmax as fmax,
|
||||
fmin as fmin,
|
||||
fmod as fmod,
|
||||
frexp as frexp,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
gcd as gcd,
|
||||
geomspace as geomspace,
|
||||
gradient as gradient,
|
||||
greater as greater,
|
||||
greater_equal as greater_equal,
|
||||
hamming as hamming,
|
||||
hanning as hanning,
|
||||
heaviside as heaviside,
|
||||
histogram as histogram,
|
||||
histogram_bin_edges as histogram_bin_edges,
|
||||
histogram2d as histogram2d,
|
||||
histogramdd as histogramdd,
|
||||
hsplit as hsplit,
|
||||
hstack as hstack,
|
||||
hypot as hypot,
|
||||
i0 as i0,
|
||||
identity as identity,
|
||||
iinfo as iinfo,
|
||||
imag as imag,
|
||||
index_exp as index_exp,
|
||||
indices as indices,
|
||||
inexact as inexact,
|
||||
in1d as in1d,
|
||||
inf as inf,
|
||||
inner as inner,
|
||||
int16 as int16,
|
||||
int32 as int32,
|
||||
int64 as int64,
|
||||
int8 as int8,
|
||||
int_ as int_,
|
||||
integer as integer,
|
||||
interp as interp,
|
||||
intersect1d as intersect1d,
|
||||
invert as invert,
|
||||
isclose as isclose,
|
||||
iscomplex as iscomplex,
|
||||
iscomplexobj as iscomplexobj,
|
||||
isfinite as isfinite,
|
||||
isin as isin,
|
||||
isinf as isinf,
|
||||
isnan as isnan,
|
||||
isneginf as isneginf,
|
||||
isposinf as isposinf,
|
||||
isreal as isreal,
|
||||
isrealobj as isrealobj,
|
||||
isscalar as isscalar,
|
||||
issubdtype as issubdtype,
|
||||
issubsctype as issubsctype,
|
||||
iterable as iterable,
|
||||
ix_ as ix_,
|
||||
kaiser as kaiser,
|
||||
kron as kron,
|
||||
lcm as lcm,
|
||||
ldexp as ldexp,
|
||||
left_shift as left_shift,
|
||||
less as less,
|
||||
less_equal as less_equal,
|
||||
lexsort as lexsort,
|
||||
linspace as linspace,
|
||||
load as load,
|
||||
log as log,
|
||||
log10 as log10,
|
||||
log1p as log1p,
|
||||
log2 as log2,
|
||||
logaddexp as logaddexp,
|
||||
logaddexp2 as logaddexp2,
|
||||
logical_and as logical_and,
|
||||
logical_not as logical_not,
|
||||
logical_or as logical_or,
|
||||
logical_xor as logical_xor,
|
||||
logspace as logspace,
|
||||
mask_indices as mask_indices,
|
||||
matmul as matmul,
|
||||
max as max,
|
||||
maximum as maximum,
|
||||
mean as mean,
|
||||
median as median,
|
||||
meshgrid as meshgrid,
|
||||
mgrid as mgrid,
|
||||
min as min,
|
||||
minimum as minimum,
|
||||
mod as mod,
|
||||
modf as modf,
|
||||
moveaxis as moveaxis,
|
||||
msort as msort,
|
||||
multiply as multiply,
|
||||
nan as nan,
|
||||
nan_to_num as nan_to_num,
|
||||
nanargmax as nanargmax,
|
||||
nanargmin as nanargmin,
|
||||
nancumprod as nancumprod,
|
||||
nancumsum as nancumsum,
|
||||
nanmedian as nanmedian,
|
||||
nanpercentile as nanpercentile,
|
||||
nanquantile as nanquantile,
|
||||
nanmax as nanmax,
|
||||
nanmean as nanmean,
|
||||
nanmin as nanmin,
|
||||
nanprod as nanprod,
|
||||
nanstd as nanstd,
|
||||
nansum as nansum,
|
||||
nanvar as nanvar,
|
||||
ndarray as ndarray,
|
||||
ndim as ndim,
|
||||
negative as negative,
|
||||
newaxis as newaxis,
|
||||
nextafter as nextafter,
|
||||
nonzero as nonzero,
|
||||
not_equal as not_equal,
|
||||
number as number,
|
||||
object_ as object_,
|
||||
ogrid as ogrid,
|
||||
ones as ones,
|
||||
ones_like as ones_like,
|
||||
outer as outer,
|
||||
packbits as packbits,
|
||||
pad as pad,
|
||||
percentile as percentile,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
poly as poly,
|
||||
polyadd as polyadd,
|
||||
polyder as polyder,
|
||||
polyfit as polyfit,
|
||||
polyint as polyint,
|
||||
polymul as polymul,
|
||||
polysub as polysub,
|
||||
polyval as polyval,
|
||||
positive as positive,
|
||||
power as power,
|
||||
prod as prod,
|
||||
product as product,
|
||||
promote_types as promote_types,
|
||||
ptp as ptp,
|
||||
quantile as quantile,
|
||||
r_ as r_,
|
||||
rad2deg as rad2deg,
|
||||
radians as radians,
|
||||
ravel as ravel,
|
||||
ravel_multi_index as ravel_multi_index,
|
||||
real as real,
|
||||
reciprocal as reciprocal,
|
||||
remainder as remainder,
|
||||
repeat as repeat,
|
||||
reshape as reshape,
|
||||
resize as resize,
|
||||
result_type as result_type,
|
||||
right_shift as right_shift,
|
||||
rint as rint,
|
||||
roll as roll,
|
||||
rollaxis as rollaxis,
|
||||
rot90 as rot90,
|
||||
round as round,
|
||||
round_ as round_,
|
||||
row_stack as row_stack,
|
||||
save as save,
|
||||
savez as savez,
|
||||
searchsorted as searchsorted,
|
||||
select as select,
|
||||
set_printoptions as set_printoptions,
|
||||
setdiff1d as setdiff1d,
|
||||
setxor1d as setxor1d,
|
||||
shape as shape,
|
||||
sign as sign,
|
||||
signbit as signbit,
|
||||
s_ as s_,
|
||||
signedinteger as signedinteger,
|
||||
sin as sin,
|
||||
sinc as sinc,
|
||||
single as single,
|
||||
sinh as sinh,
|
||||
size as size,
|
||||
sometrue as sometrue,
|
||||
sort as sort,
|
||||
sort_complex as sort_complex,
|
||||
split as split,
|
||||
sqrt as sqrt,
|
||||
square as square,
|
||||
squeeze as squeeze,
|
||||
stack as stack,
|
||||
std as std,
|
||||
subtract as subtract,
|
||||
sum as sum,
|
||||
swapaxes as swapaxes,
|
||||
take as take,
|
||||
take_along_axis as take_along_axis,
|
||||
tan as tan,
|
||||
tanh as tanh,
|
||||
tensordot as tensordot,
|
||||
tile as tile,
|
||||
trace as trace,
|
||||
trapz as trapz,
|
||||
transpose as transpose,
|
||||
tri as tri,
|
||||
tril as tril,
|
||||
tril_indices as tril_indices,
|
||||
tril_indices_from as tril_indices_from,
|
||||
trim_zeros as trim_zeros,
|
||||
triu as triu,
|
||||
triu_indices as triu_indices,
|
||||
triu_indices_from as triu_indices_from,
|
||||
true_divide as true_divide,
|
||||
trunc as trunc,
|
||||
uint16 as uint16,
|
||||
uint32 as uint32,
|
||||
uint64 as uint64,
|
||||
uint8 as uint8,
|
||||
unique as unique,
|
||||
union1d as union1d,
|
||||
unpackbits as unpackbits,
|
||||
unravel_index as unravel_index,
|
||||
unsignedinteger as unsignedinteger,
|
||||
unwrap as unwrap,
|
||||
vander as vander,
|
||||
var as var,
|
||||
vdot as vdot,
|
||||
vsplit as vsplit,
|
||||
vstack as vstack,
|
||||
where as where,
|
||||
zeros as zeros,
|
||||
zeros_like as zeros_like,
|
||||
_NOT_IMPLEMENTED,
|
||||
)
|
||||
|
||||
from jax._src.numpy.polynomial import roots
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.numpy.polynomial import roots as roots
|
||||
from jax._src.numpy.vectorize import vectorize as vectorize
|
||||
|
||||
# TODO(phawkins): remove this import after fixing users.
|
||||
from jax._src.numpy import lax_numpy
|
||||
|
@ -15,24 +15,24 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.numpy.fft import (
|
||||
ifft,
|
||||
ifft2,
|
||||
ifftn,
|
||||
ifftshift,
|
||||
ihfft,
|
||||
irfft,
|
||||
irfft2,
|
||||
irfftn,
|
||||
fft,
|
||||
fft2,
|
||||
fftfreq,
|
||||
fftn,
|
||||
fftshift,
|
||||
hfft,
|
||||
rfft,
|
||||
rfft2,
|
||||
rfftfreq,
|
||||
rfftn,
|
||||
ifft as ifft,
|
||||
ifft2 as ifft2,
|
||||
ifftn as ifftn,
|
||||
ifftshift as ifftshift,
|
||||
ihfft as ihfft,
|
||||
irfft as irfft,
|
||||
irfft2 as irfft2,
|
||||
irfftn as irfftn,
|
||||
fft as fft,
|
||||
fft2 as fft2,
|
||||
fftfreq as fftfreq,
|
||||
fftn as fftn,
|
||||
fftshift as fftshift,
|
||||
hfft as hfft,
|
||||
rfft as rfft,
|
||||
rfft2 as rfft2,
|
||||
rfftfreq as rfftfreq,
|
||||
rfftn as rfftn,
|
||||
)
|
||||
|
||||
# Module initialization is encapsulated in a function to avoid accidental
|
||||
|
@ -15,28 +15,28 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.numpy.linalg import (
|
||||
cholesky,
|
||||
det,
|
||||
eig,
|
||||
eigh,
|
||||
eigvals,
|
||||
eigvalsh,
|
||||
inv,
|
||||
lstsq,
|
||||
matrix_power,
|
||||
matrix_rank,
|
||||
norm,
|
||||
pinv,
|
||||
qr,
|
||||
slogdet,
|
||||
solve,
|
||||
svd,
|
||||
cholesky as cholesky,
|
||||
det as det,
|
||||
eig as eig,
|
||||
eigh as eigh,
|
||||
eigvals as eigvals,
|
||||
eigvalsh as eigvalsh,
|
||||
inv as inv,
|
||||
lstsq as lstsq,
|
||||
matrix_power as matrix_power,
|
||||
matrix_rank as matrix_rank,
|
||||
norm as norm,
|
||||
pinv as pinv,
|
||||
qr as qr,
|
||||
slogdet as slogdet,
|
||||
solve as solve,
|
||||
svd as svd,
|
||||
)
|
||||
from jax._src.third_party.numpy.linalg import (
|
||||
cond,
|
||||
multi_dot,
|
||||
tensorinv,
|
||||
tensorsolve
|
||||
cond as cond,
|
||||
multi_dot as multi_dot,
|
||||
tensorinv as tensorinv,
|
||||
tensorsolve as tensorsolve,
|
||||
)
|
||||
|
||||
# Module initialization is encapsulated in a function to avoid accidental
|
||||
|
@ -14,6 +14,14 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.ops.scatter import (
|
||||
index, index_add, index_mul, index_update, index_min, index_max,
|
||||
segment_sum, segment_prod, segment_min, segment_max,
|
||||
index as index,
|
||||
index_add as index_add,
|
||||
index_mul as index_mul,
|
||||
index_update as index_update,
|
||||
index_min as index_min,
|
||||
index_max as index_max,
|
||||
segment_sum as segment_sum,
|
||||
segment_prod as segment_prod,
|
||||
segment_min as segment_min,
|
||||
segment_max as segment_max,
|
||||
)
|
||||
|
10
jax/prng.py
10
jax/prng.py
@ -15,9 +15,9 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.prng import (
|
||||
PRNGImpl,
|
||||
seed_with_impl,
|
||||
threefry2x32_p,
|
||||
threefry_2x32,
|
||||
threefry_prng_impl,
|
||||
PRNGImpl as PRNGImpl,
|
||||
seed_with_impl as seed_with_impl,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry_prng_impl as threefry_prng_impl,
|
||||
)
|
||||
|
@ -14,16 +14,16 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.profiler import (
|
||||
StepTraceAnnotation,
|
||||
StepTraceContext,
|
||||
TraceAnnotation,
|
||||
TraceContext,
|
||||
device_memory_profile,
|
||||
save_device_memory_profile,
|
||||
start_server,
|
||||
start_trace,
|
||||
stop_trace,
|
||||
trace,
|
||||
annotate_function,
|
||||
trace_function,
|
||||
StepTraceAnnotation as StepTraceAnnotation,
|
||||
StepTraceContext as StepTraceContext,
|
||||
TraceAnnotation as TraceAnnotation,
|
||||
TraceContext as TraceContext,
|
||||
device_memory_profile as device_memory_profile,
|
||||
save_device_memory_profile as save_device_memory_profile,
|
||||
start_server as start_server,
|
||||
start_trace as start_trace,
|
||||
stop_trace as stop_trace,
|
||||
trace as trace,
|
||||
annotate_function as annotate_function,
|
||||
trace_function as trace_function,
|
||||
)
|
||||
|
@ -79,35 +79,35 @@ for the design and its motivation.
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.random import (
|
||||
PRNGKey,
|
||||
bernoulli,
|
||||
beta,
|
||||
categorical,
|
||||
cauchy,
|
||||
choice,
|
||||
dirichlet,
|
||||
double_sided_maxwell,
|
||||
exponential,
|
||||
fold_in,
|
||||
gamma,
|
||||
gumbel,
|
||||
laplace,
|
||||
logistic,
|
||||
maxwell,
|
||||
multivariate_normal,
|
||||
normal,
|
||||
pareto,
|
||||
permutation,
|
||||
poisson,
|
||||
rademacher,
|
||||
randint,
|
||||
random_gamma_p,
|
||||
shuffle,
|
||||
split,
|
||||
t,
|
||||
threefry_2x32,
|
||||
threefry2x32_p,
|
||||
truncated_normal,
|
||||
uniform,
|
||||
weibull_min,
|
||||
PRNGKey as PRNGKey,
|
||||
bernoulli as bernoulli,
|
||||
beta as beta,
|
||||
categorical as categorical,
|
||||
cauchy as cauchy,
|
||||
choice as choice,
|
||||
dirichlet as dirichlet,
|
||||
double_sided_maxwell as double_sided_maxwell,
|
||||
exponential as exponential,
|
||||
fold_in as fold_in,
|
||||
gamma as gamma,
|
||||
gumbel as gumbel,
|
||||
laplace as laplace,
|
||||
logistic as logistic,
|
||||
maxwell as maxwell,
|
||||
multivariate_normal as multivariate_normal,
|
||||
normal as normal,
|
||||
pareto as pareto,
|
||||
permutation as permutation,
|
||||
poisson as poisson,
|
||||
rademacher as rademacher,
|
||||
randint as randint,
|
||||
random_gamma_p as random_gamma_p,
|
||||
shuffle as shuffle,
|
||||
split as split,
|
||||
t as t,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
truncated_normal as truncated_normal,
|
||||
uniform as uniform,
|
||||
weibull_min as weibull_min,
|
||||
)
|
||||
|
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import linalg
|
||||
from . import ndimage
|
||||
from . import signal
|
||||
from . import sparse
|
||||
from . import special
|
||||
from . import stats
|
||||
from . import fft
|
||||
from . import linalg as linalg
|
||||
from . import ndimage as ndimage
|
||||
from . import signal as signal
|
||||
from . import sparse as sparse
|
||||
from . import special as special
|
||||
from . import stats as stats
|
||||
from . import fft as fft
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.fft import (
|
||||
dct,
|
||||
dctn
|
||||
dct as dct,
|
||||
dctn as dctn,
|
||||
)
|
||||
|
@ -15,28 +15,28 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.linalg import (
|
||||
block_diag,
|
||||
cholesky,
|
||||
cho_factor,
|
||||
cho_solve,
|
||||
det,
|
||||
eigh,
|
||||
eigh_tridiagonal,
|
||||
expm,
|
||||
expm_frechet,
|
||||
inv,
|
||||
lu,
|
||||
lu_factor,
|
||||
lu_solve,
|
||||
polar,
|
||||
qr,
|
||||
solve,
|
||||
solve_triangular,
|
||||
svd,
|
||||
tril,
|
||||
triu,
|
||||
block_diag as block_diag,
|
||||
cholesky as cholesky,
|
||||
cho_factor as cho_factor,
|
||||
cho_solve as cho_solve,
|
||||
det as det,
|
||||
eigh as eigh,
|
||||
eigh_tridiagonal as eigh_tridiagonal,
|
||||
expm as expm,
|
||||
expm_frechet as expm_frechet,
|
||||
inv as inv,
|
||||
lu as lu,
|
||||
lu_factor as lu_factor,
|
||||
lu_solve as lu_solve,
|
||||
polar as polar,
|
||||
qr as qr,
|
||||
solve as solve,
|
||||
solve_triangular as solve_triangular,
|
||||
svd as svd,
|
||||
tril as tril,
|
||||
triu as triu,
|
||||
)
|
||||
|
||||
from jax._src.lax.polar import (
|
||||
polar_unitary
|
||||
polar_unitary as polar_unitary,
|
||||
)
|
||||
|
@ -14,4 +14,6 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.ndimage import map_coordinates
|
||||
from jax._src.scipy.ndimage import (
|
||||
map_coordinates as map_coordinates,
|
||||
)
|
||||
|
@ -14,4 +14,7 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.optimize.minimize import (minimize, OptimizeResults)
|
||||
from jax._src.scipy.optimize.minimize import (
|
||||
minimize as minimize,
|
||||
OptimizeResults as OptimizeResults,
|
||||
)
|
||||
|
@ -15,9 +15,9 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.signal import (
|
||||
convolve,
|
||||
convolve2d,
|
||||
correlate,
|
||||
correlate2d,
|
||||
detrend,
|
||||
convolve as convolve,
|
||||
convolve2d as convolve2d,
|
||||
correlate as correlate,
|
||||
correlate2d as correlate2d,
|
||||
detrend as detrend,
|
||||
)
|
||||
|
@ -13,4 +13,4 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import linalg
|
||||
from . import linalg as linalg
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.scipy.sparse.linalg import (
|
||||
cg,
|
||||
gmres,
|
||||
bicgstab
|
||||
cg as cg,
|
||||
gmres as gmres,
|
||||
bicgstab as bicgstab,
|
||||
)
|
||||
|
@ -15,35 +15,35 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.special import (
|
||||
betainc,
|
||||
betaln,
|
||||
digamma,
|
||||
entr,
|
||||
erf,
|
||||
erfc,
|
||||
erfinv,
|
||||
exp1,
|
||||
expi,
|
||||
expit,
|
||||
expn,
|
||||
gammainc,
|
||||
gammaincc,
|
||||
gammaln,
|
||||
i0,
|
||||
i0e,
|
||||
i1,
|
||||
i1e,
|
||||
logit,
|
||||
logsumexp,
|
||||
lpmn,
|
||||
lpmn_values,
|
||||
multigammaln,
|
||||
log_ndtr,
|
||||
ndtr,
|
||||
ndtri,
|
||||
polygamma,
|
||||
sph_harm,
|
||||
xlogy,
|
||||
xlog1py,
|
||||
zeta,
|
||||
betainc as betainc,
|
||||
betaln as betaln,
|
||||
digamma as digamma,
|
||||
entr as entr,
|
||||
erf as erf,
|
||||
erfc as erfc,
|
||||
erfinv as erfinv,
|
||||
exp1 as exp1,
|
||||
expi as expi,
|
||||
expit as expit,
|
||||
expn as expn,
|
||||
gammainc as gammainc,
|
||||
gammaincc as gammaincc,
|
||||
gammaln as gammaln,
|
||||
i0 as i0,
|
||||
i0e as i0e,
|
||||
i1 as i1,
|
||||
i1e as i1e,
|
||||
logit as logit,
|
||||
logsumexp as logsumexp,
|
||||
lpmn as lpmn,
|
||||
lpmn_values as lpmn_values,
|
||||
multigammaln as multigammaln,
|
||||
log_ndtr as log_ndtr,
|
||||
ndtr as ndtr,
|
||||
ndtri as ndtri,
|
||||
polygamma as polygamma,
|
||||
sph_harm as sph_harm,
|
||||
xlogy as xlogy,
|
||||
xlog1py as xlog1py,
|
||||
zeta as zeta,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.bernoulli import (
|
||||
logpmf,
|
||||
pmf,
|
||||
)
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.beta import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.betabinom import (
|
||||
logpmf,
|
||||
pmf,
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.cauchy import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.chi2 import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.dirichlet import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.expon import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.gamma import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.geom import (
|
||||
logpmf,
|
||||
pmf,
|
||||
)
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.laplace import (
|
||||
cdf,
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
cdf as cdf,
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,10 +15,10 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.logistic import (
|
||||
cdf,
|
||||
isf,
|
||||
logpdf,
|
||||
pdf,
|
||||
ppf,
|
||||
sf,
|
||||
)
|
||||
cdf as cdf,
|
||||
isf as isf,
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.multivariate_normal import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,9 +15,9 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.norm import (
|
||||
cdf,
|
||||
logcdf,
|
||||
logpdf,
|
||||
pdf,
|
||||
ppf,
|
||||
)
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.pareto import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.poisson import (
|
||||
logpmf,
|
||||
pmf,
|
||||
cdf
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
cdf as cdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.t import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -15,6 +15,6 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.uniform import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
||||
|
@ -37,22 +37,22 @@ for examples.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.tree_util import (
|
||||
Partial,
|
||||
PyTreeDef,
|
||||
all_leaves,
|
||||
build_tree,
|
||||
register_pytree_node,
|
||||
register_pytree_node_class,
|
||||
tree_all,
|
||||
tree_flatten,
|
||||
tree_leaves,
|
||||
tree_map,
|
||||
tree_multimap,
|
||||
tree_reduce,
|
||||
tree_structure,
|
||||
tree_transpose,
|
||||
tree_unflatten,
|
||||
treedef_children,
|
||||
treedef_is_leaf,
|
||||
treedef_tuple,
|
||||
Partial as Partial,
|
||||
PyTreeDef as PyTreeDef,
|
||||
all_leaves as all_leaves,
|
||||
build_tree as build_tree,
|
||||
register_pytree_node as register_pytree_node,
|
||||
register_pytree_node_class as register_pytree_node_class,
|
||||
tree_all as tree_all,
|
||||
tree_flatten as tree_flatten,
|
||||
tree_leaves as tree_leaves,
|
||||
tree_map as tree_map,
|
||||
tree_multimap as tree_multimap,
|
||||
tree_reduce as tree_reduce,
|
||||
tree_structure as tree_structure,
|
||||
tree_transpose as tree_transpose,
|
||||
tree_unflatten as tree_unflatten,
|
||||
treedef_children as treedef_children,
|
||||
treedef_is_leaf as treedef_is_leaf,
|
||||
treedef_tuple as treedef_tuple,
|
||||
)
|
||||
|
28
jax/util.py
28
jax/util.py
@ -14,18 +14,18 @@
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.util import (
|
||||
HashableFunction,
|
||||
as_hashable_function,
|
||||
cache,
|
||||
partial,
|
||||
safe_map,
|
||||
safe_zip,
|
||||
split_dict,
|
||||
split_list,
|
||||
split_merge,
|
||||
subvals,
|
||||
toposort,
|
||||
unzip2,
|
||||
wrap_name,
|
||||
wraps,
|
||||
HashableFunction as HashableFunction,
|
||||
as_hashable_function as as_hashable_function,
|
||||
cache as cache,
|
||||
partial as partial,
|
||||
safe_map as safe_map,
|
||||
safe_zip as safe_zip,
|
||||
split_dict as split_dict,
|
||||
split_list as split_list,
|
||||
split_merge as split_merge,
|
||||
subvals as subvals,
|
||||
toposort as toposort,
|
||||
unzip2 as unzip2,
|
||||
wrap_name as wrap_name,
|
||||
wraps as wraps,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user