Add PEP484-compatible export for jax and its subpackages

This commit is contained in:
Jake VanderPlas 2021-08-30 14:35:22 -07:00
parent 836bfb297e
commit 245581411e
44 changed files with 1093 additions and 751 deletions

View File

@ -38,58 +38,65 @@ from . import config as _config_module
del _config_module
from ._src.config import (
config, enable_checks, check_tracer_leaks, checking_leaks, enable_custom_prng,
debug_nans, debug_infs, log_compiles, default_matmul_precision,
numpy_rank_promotion
config as config,
enable_checks as enable_checks,
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,
default_matmul_precision as default_matmul_precision,
numpy_rank_promotion as numpy_rank_promotion,
)
from ._src.api import (
ad, # TODO(phawkins): update users to avoid this.
checkpoint,
checkpoint_policies,
closure_convert,
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
closure_convert as closure_convert,
curry, # TODO(phawkins): update users to avoid this.
custom_ivjp,
custom_gradient,
custom_jvp,
custom_vjp,
default_backend,
device_count,
device_get,
device_put,
device_put_sharded,
device_put_replicated,
devices,
disable_jit,
eval_shape,
custom_ivjp as custom_ivjp,
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,
custom_vjp as custom_vjp,
default_backend as default_backend,
device_count as device_count,
device_get as device_get,
device_put as device_put,
device_put_sharded as device_put_sharded,
device_put_replicated as device_put_replicated,
devices as devices,
disable_jit as disable_jit,
eval_shape as eval_shape,
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
float0,
grad,
hessian,
host_count,
host_id,
host_ids,
invertible,
jacobian,
jacfwd,
jacrev,
jit,
jvp,
local_device_count,
local_devices,
linearize,
linear_transpose,
make_jaxpr,
mask,
named_call,
float0 as float0,
grad as grad,
hessian as hessian,
host_count as host_count,
host_id as host_id,
host_ids as host_ids,
invertible as invertible,
jacobian as jacobian,
jacfwd as jacfwd,
jacrev as jacrev,
jit as jit,
jvp as jvp,
local_device_count as local_device_count,
local_devices as local_devices,
linearize as linearize,
linear_transpose as linear_transpose,
make_jaxpr as make_jaxpr,
mask as mask,
named_call as named_call,
partial, # TODO(phawkins): update callers to use functools.partial.
pmap,
process_count,
process_index,
pmap as pmap,
process_count as process_count,
process_index as process_index,
pxla, # TODO(phawkins): update users to avoid this.
remat,
shapecheck,
ShapedArray,
ShapeDtypeStruct,
remat as remat,
shapecheck as shapecheck,
ShapedArray as ShapedArray,
ShapeDtypeStruct as ShapeDtypeStruct,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
@ -106,26 +113,26 @@ from ._src.api import (
xla, # TODO(phawkins): update users to avoid this.
xla_computation,
)
from .experimental.maps import soft_pmap
from .version import __version__
from .experimental.maps import soft_pmap as soft_pmap
from .version import __version__ as __version__
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from . import abstract_arrays
from . import api
from . import api_util
from . import dtypes
from . import errors
from . import image
from . import lax
from . import nn
from . import profiler
from . import random
from . import tree_util
from . import util
from . import abstract_arrays as abstract_arrays
from . import api as api
from . import api_util as api_util
from . import dtypes as dtypes
from . import errors as errors
from . import image as image
from . import lax as lax
from . import nn as nn
from . import profiler as profiler
from . import random as random
from . import tree_util as tree_util
from . import util as util
def _init():
from . import numpy # side-effecting import sets up operator overloads
from . import numpy as numpy # side-effecting import sets up operator overloads
_init()
del _init

View File

@ -17,13 +17,13 @@ from jax._src.custom_derivatives import (
_initial_style_jaxpr,
_sum_tangents,
_zeros_like_pytree,
closure_convert,
custom_gradient,
custom_jvp,
custom_jvp_call_p,
custom_jvp_call_jaxpr_p,
custom_vjp,
custom_vjp_call_p,
custom_vjp_call_jaxpr_p,
linear_call,
closure_convert as closure_convert,
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,
custom_jvp_call_p as custom_jvp_call_p,
custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p,
custom_vjp as custom_vjp,
custom_vjp_call_p as custom_vjp_call_p,
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
linear_call as linear_call,
)

View File

@ -15,12 +15,12 @@
# flake8: noqa: F401
from jax._src.dtypes import (
_jax_types, # TODO(phawkins): fix users and remove?
bfloat16,
canonicalize_dtype,
bfloat16 as bfloat16,
canonicalize_dtype as canonicalize_dtype,
finfo, # TODO(phawkins): switch callers to jnp.finfo?
float0,
float0 as float0,
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
result_type,
scalar_type_of,
result_type as result_type,
scalar_type_of as scalar_type_of,
)

View File

@ -13,10 +13,12 @@
# limitations under the License.
# flake8: noqa: F401
from ._src.errors import (JAXTypeError,
JAXIndexError,
ConcretizationTypeError,
NonConcreteBooleanIndexError,
TracerArrayConversionError,
TracerIntegerConversionError,
UnexpectedTracerError)
from ._src.errors import (
JAXTypeError as JAXTypeError,
JAXIndexError as JAXIndexError,
ConcretizationTypeError as ConcretizationTypeError,
NonConcreteBooleanIndexError as NonConcreteBooleanIndexError,
TracerArrayConversionError as TracerArrayConversionError,
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
)

View File

@ -16,7 +16,7 @@
# flake8: noqa: F401
from jax._src.image.scale import (
resize,
ResizeMethod,
scale_and_translate,
resize as resize,
ResizeMethod as ResizeMethod,
scale_and_translate as scale_and_translate,
)

View File

@ -14,287 +14,287 @@
# flake8: noqa: F401
from jax._src.lax.lax import (
ConvDimensionNumbers,
ConvGeneralDilatedDimensionNumbers,
DotDimensionNumbers,
GatherDimensionNumbers,
Precision,
RandomAlgorithm,
RoundingMethod,
ScatterDimensionNumbers,
abs,
abs_p,
acos,
acos_p,
acosh,
acosh_p,
abs,
abs_p,
acos,
acosh,
acosh_p,
add,
add_p,
after_all,
after_all_p,
and_p,
argmax,
argmax_p,
argmin,
argmin_p,
asin,
asin_p,
asinh,
asinh_p,
atan,
atan_p,
atan2,
atan2_p,
atanh,
atanh_p,
batch_matmul,
bessel_i0e,
bessel_i0e_p,
bessel_i1e,
bessel_i1e_p,
betainc,
bitcast_convert_type,
bitcast_convert_type_p,
bitwise_and,
bitwise_not,
bitwise_or,
bitwise_xor,
broadcast,
broadcast_in_dim,
broadcast_in_dim_p,
broadcast_shapes,
broadcast_to_rank,
broadcasted_iota,
cbrt,
cbrt_p,
ceil,
ceil_p,
clamp,
clamp_p,
clz,
clz_p,
collapse,
complex,
complex_p,
concatenate,
concatenate_p,
conj,
conj_p,
conv,
conv_dimension_numbers,
conv_general_dilated,
conv_general_dilated_p,
conv_general_permutations,
conv_general_shape_tuple,
conv_shape_tuple,
conv_transpose,
conv_transpose_shape_tuple,
conv_with_general_padding,
convert_element_type,
_convert_element_type,
convert_element_type_p,
cos,
cos_p,
cosh,
cosh_p,
create_token,
create_token_p,
digamma,
digamma_p,
div,
div_p,
dot,
dot_general,
dot_general_p,
dtype,
dtypes,
dynamic_index_in_dim,
dynamic_slice,
dynamic_slice_in_dim,
dynamic_slice_p,
dynamic_update_index_in_dim,
dynamic_update_slice,
dynamic_update_slice_in_dim,
dynamic_update_slice_p,
eq,
eq_p,
erf,
erf_inv,
erf_inv_p,
erf_p,
erfc,
erfc_p,
exp,
exp_p,
expand_dims,
expm1,
expm1_p,
floor,
floor_p,
full,
full_like,
gather,
gather_p,
ge,
ge_p,
gt,
gt_p,
igamma,
igamma_grad_a,
igamma_grad_a_p,
igamma_p,
igammac,
igammac_p,
imag,
imag_p,
index_in_dim,
index_take,
infeed,
infeed_p,
integer_pow,
integer_pow_p,
iota,
iota_p,
is_finite,
is_finite_p,
itertools,
le,
le_p,
lgamma,
lgamma_p,
log,
log1p,
log1p_p,
log_p,
lt,
lt_p,
max,
max_p,
min,
min_p,
mul,
mul_p,
naryop,
naryop_dtype_rule,
ne,
ne_p,
neg,
neg_p,
nextafter,
nextafter_p,
not_p,
or_p,
outfeed,
outfeed_p,
pad,
pad_p,
padtype_to_pads,
partial,
population_count,
population_count_p,
pow,
pow_p,
prod,
random_gamma_grad,
random_gamma_grad_p,
real,
real_p,
reciprocal,
reduce,
reduce_and_p,
reduce_max_p,
reduce_min_p,
reduce_or_p,
reduce_p,
reduce_precision,
reduce_precision_p,
reduce_prod_p,
reduce_sum_p,
reduce_window,
reduce_window_max_p,
reduce_window_min_p,
reduce_window_p,
reduce_window_shape_tuple,
reduce_window_sum_p,
regularized_incomplete_beta_p,
rem,
rem_p,
reshape,
reshape_p,
rev,
rev_p,
rng_bit_generator,
rng_bit_generator_p,
rng_uniform,
rng_uniform_p,
round,
round_p,
rsqrt,
rsqrt_p,
scatter,
scatter_add,
scatter_add_p,
scatter_max,
scatter_max_p,
scatter_min,
scatter_min_p,
scatter_mul,
scatter_mul_p,
scatter_p,
select,
select_and_gather_add_p,
select_and_scatter_add_p,
select_and_scatter_p,
select_p,
shift_left,
shift_left_p,
shift_right_arithmetic,
shift_right_arithmetic_p,
shift_right_logical,
shift_right_logical_p,
sign,
sign_p,
sin,
sin_p,
sinh,
sinh_p,
slice,
slice_in_dim,
slice_p,
sort,
sort_key_val,
sort_p,
sqrt,
sqrt_p,
square,
squeeze,
squeeze_p,
standard_abstract_eval,
standard_naryop,
standard_primitive,
standard_translate,
standard_unop,
stop_gradient,
sub,
sub_p,
tan,
tan_p,
tanh,
tanh_p,
tie_in,
top_k,
top_k_p,
transpose,
transpose_p,
unop,
unop_dtype_rule,
xor_p,
zeros_like_array,
ConvDimensionNumbers as ConvDimensionNumbers,
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
DotDimensionNumbers as DotDimensionNumbers,
GatherDimensionNumbers as GatherDimensionNumbers,
Precision as Precision,
RandomAlgorithm as RandomAlgorithm,
RoundingMethod as RoundingMethod,
ScatterDimensionNumbers as ScatterDimensionNumbers,
abs as abs,
abs_p as abs_p,
acos as acos,
acos_p as acos_p,
acosh as acosh,
acosh_p as acosh_p,
abs as abs,
abs_p as abs_p,
acos as acos,
acosh as acosh,
acosh_p as acosh_p,
add as add,
add_p as add_p,
after_all as after_all,
after_all_p as after_all_p,
and_p as and_p,
argmax as argmax,
argmax_p as argmax_p,
argmin as argmin,
argmin_p as argmin_p,
asin as asin,
asin_p as asin_p,
asinh as asinh,
asinh_p as asinh_p,
atan as atan,
atan_p as atan_p,
atan2 as atan2,
atan2_p as atan2_p,
atanh as atanh,
atanh_p as atanh_p,
batch_matmul as batch_matmul,
bessel_i0e as bessel_i0e,
bessel_i0e_p as bessel_i0e_p,
bessel_i1e as bessel_i1e,
bessel_i1e_p as bessel_i1e_p,
betainc as betainc,
bitcast_convert_type as bitcast_convert_type,
bitcast_convert_type_p as bitcast_convert_type_p,
bitwise_and as bitwise_and,
bitwise_not as bitwise_not,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,
broadcast as broadcast,
broadcast_in_dim as broadcast_in_dim,
broadcast_in_dim_p as broadcast_in_dim_p,
broadcast_shapes as broadcast_shapes,
broadcast_to_rank as broadcast_to_rank,
broadcasted_iota as broadcasted_iota,
cbrt as cbrt,
cbrt_p as cbrt_p,
ceil as ceil,
ceil_p as ceil_p,
clamp as clamp,
clamp_p as clamp_p,
clz as clz,
clz_p as clz_p,
collapse as collapse,
complex as complex,
complex_p as complex_p,
concatenate as concatenate,
concatenate_p as concatenate_p,
conj as conj,
conj_p as conj_p,
conv as conv,
conv_dimension_numbers as conv_dimension_numbers,
conv_general_dilated as conv_general_dilated,
conv_general_dilated_p as conv_general_dilated_p,
conv_general_permutations as conv_general_permutations,
conv_general_shape_tuple as conv_general_shape_tuple,
conv_shape_tuple as conv_shape_tuple,
conv_transpose as conv_transpose,
conv_transpose_shape_tuple as conv_transpose_shape_tuple,
conv_with_general_padding as conv_with_general_padding,
convert_element_type as convert_element_type,
_convert_element_type as _convert_element_type,
convert_element_type_p as convert_element_type_p,
cos as cos,
cos_p as cos_p,
cosh as cosh,
cosh_p as cosh_p,
create_token as create_token,
create_token_p as create_token_p,
digamma as digamma,
digamma_p as digamma_p,
div as div,
div_p as div_p,
dot as dot,
dot_general as dot_general,
dot_general_p as dot_general_p,
dtype as dtype,
dtypes as dtypes,
dynamic_index_in_dim as dynamic_index_in_dim,
dynamic_slice as dynamic_slice,
dynamic_slice_in_dim as dynamic_slice_in_dim,
dynamic_slice_p as dynamic_slice_p,
dynamic_update_index_in_dim as dynamic_update_index_in_dim,
dynamic_update_slice as dynamic_update_slice,
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
dynamic_update_slice_p as dynamic_update_slice_p,
eq as eq,
eq_p as eq_p,
erf as erf,
erf_inv as erf_inv,
erf_inv_p as erf_inv_p,
erf_p as erf_p,
erfc as erfc,
erfc_p as erfc_p,
exp as exp,
exp_p as exp_p,
expand_dims as expand_dims,
expm1 as expm1,
expm1_p as expm1_p,
floor as floor,
floor_p as floor_p,
full as full,
full_like as full_like,
gather as gather,
gather_p as gather_p,
ge as ge,
ge_p as ge_p,
gt as gt,
gt_p as gt_p,
igamma as igamma,
igamma_grad_a as igamma_grad_a,
igamma_grad_a_p as igamma_grad_a_p,
igamma_p as igamma_p,
igammac as igammac,
igammac_p as igammac_p,
imag as imag,
imag_p as imag_p,
index_in_dim as index_in_dim,
index_take as index_take,
infeed as infeed,
infeed_p as infeed_p,
integer_pow as integer_pow,
integer_pow_p as integer_pow_p,
iota as iota,
iota_p as iota_p,
is_finite as is_finite,
is_finite_p as is_finite_p,
itertools as itertools,
le as le,
le_p as le_p,
lgamma as lgamma,
lgamma_p as lgamma_p,
log as log,
log1p as log1p,
log1p_p as log1p_p,
log_p as log_p,
lt as lt,
lt_p as lt_p,
max as max,
max_p as max_p,
min as min,
min_p as min_p,
mul as mul,
mul_p as mul_p,
naryop as naryop,
naryop_dtype_rule as naryop_dtype_rule,
ne as ne,
ne_p as ne_p,
neg as neg,
neg_p as neg_p,
nextafter as nextafter,
nextafter_p as nextafter_p,
not_p as not_p,
or_p as or_p,
outfeed as outfeed,
outfeed_p as outfeed_p,
pad as pad,
pad_p as pad_p,
padtype_to_pads as padtype_to_pads,
partial as partial,
population_count as population_count,
population_count_p as population_count_p,
pow as pow,
pow_p as pow_p,
prod as prod,
random_gamma_grad as random_gamma_grad,
random_gamma_grad_p as random_gamma_grad_p,
real as real,
real_p as real_p,
reciprocal as reciprocal,
reduce as reduce,
reduce_and_p as reduce_and_p,
reduce_max_p as reduce_max_p,
reduce_min_p as reduce_min_p,
reduce_or_p as reduce_or_p,
reduce_p as reduce_p,
reduce_precision as reduce_precision,
reduce_precision_p as reduce_precision_p,
reduce_prod_p as reduce_prod_p,
reduce_sum_p as reduce_sum_p,
reduce_window as reduce_window,
reduce_window_max_p as reduce_window_max_p,
reduce_window_min_p as reduce_window_min_p,
reduce_window_p as reduce_window_p,
reduce_window_shape_tuple as reduce_window_shape_tuple,
reduce_window_sum_p as reduce_window_sum_p,
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
rem as rem,
rem_p as rem_p,
reshape as reshape,
reshape_p as reshape_p,
rev as rev,
rev_p as rev_p,
rng_bit_generator as rng_bit_generator,
rng_bit_generator_p as rng_bit_generator_p,
rng_uniform as rng_uniform,
rng_uniform_p as rng_uniform_p,
round as round,
round_p as round_p,
rsqrt as rsqrt,
rsqrt_p as rsqrt_p,
scatter as scatter,
scatter_add as scatter_add,
scatter_add_p as scatter_add_p,
scatter_max as scatter_max,
scatter_max_p as scatter_max_p,
scatter_min as scatter_min,
scatter_min_p as scatter_min_p,
scatter_mul as scatter_mul,
scatter_mul_p as scatter_mul_p,
scatter_p as scatter_p,
select as select,
select_and_gather_add_p as select_and_gather_add_p,
select_and_scatter_add_p as select_and_scatter_add_p,
select_and_scatter_p as select_and_scatter_p,
select_p as select_p,
shift_left as shift_left,
shift_left_p as shift_left_p,
shift_right_arithmetic as shift_right_arithmetic,
shift_right_arithmetic_p as shift_right_arithmetic_p,
shift_right_logical as shift_right_logical,
shift_right_logical_p as shift_right_logical_p,
sign as sign,
sign_p as sign_p,
sin as sin,
sin_p as sin_p,
sinh as sinh,
sinh_p as sinh_p,
slice as slice,
slice_in_dim as slice_in_dim,
slice_p as slice_p,
sort as sort,
sort_key_val as sort_key_val,
sort_p as sort_p,
sqrt as sqrt,
sqrt_p as sqrt_p,
square as square,
squeeze as squeeze,
squeeze_p as squeeze_p,
standard_abstract_eval as standard_abstract_eval,
standard_naryop as standard_naryop,
standard_primitive as standard_primitive,
standard_translate as standard_translate,
standard_unop as standard_unop,
stop_gradient as stop_gradient,
sub as sub,
sub_p as sub_p,
tan as tan,
tan_p as tan_p,
tanh as tanh,
tanh_p as tanh_p,
tie_in as tie_in,
top_k as top_k,
top_k_p as top_k_p,
transpose as transpose,
transpose_p as transpose_p,
unop as unop,
unop_dtype_rule as unop_dtype_rule,
xor_p as xor_p,
zeros_like_array as zeros_like_array,
)
from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_reduce_and, _reduce_window_sum, _reduce_window_max,
@ -306,56 +306,56 @@ from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_upcast_fp16_for_computation, _broadcasting_shape_rule,
_eye, _tri, _delta, _ones, _zeros, _dilate_shape)
from jax._src.lax.control_flow import (
associative_scan,
cond,
cond_p,
cummax,
cummax_p,
cummin,
cummin_p,
cumprod,
cumprod_p,
cumsum,
cumsum_p,
custom_linear_solve,
custom_root,
fori_loop,
linear_solve_p,
map,
scan,
scan_bind,
scan_p,
switch,
while_loop,
while_p,
associative_scan as associative_scan,
cond as cond,
cond_p as cond_p,
cummax as cummax,
cummax_p as cummax_p,
cummin as cummin,
cummin_p as cummin_p,
cumprod as cumprod,
cumprod_p as cumprod_p,
cumsum as cumsum,
cumsum_p as cumsum_p,
custom_linear_solve as custom_linear_solve,
custom_root as custom_root,
fori_loop as fori_loop,
linear_solve_p as linear_solve_p,
map as map,
scan as scan,
scan_bind as scan_bind,
scan_p as scan_p,
switch as switch,
while_loop as while_loop,
while_p as while_p,
)
from jax._src.lax.fft import (
fft,
fft_p,
fft as fft,
fft_p as fft_p,
)
from jax._src.lax.parallel import (
all_gather,
all_to_all,
all_to_all_p,
axis_index,
axis_index_p,
pmax,
pmax_p,
pmean,
pmin,
pmin_p,
ppermute,
ppermute_p,
pshuffle,
psum,
psum_p,
psum_scatter,
pswapaxes,
pdot,
xeinsum,
all_gather as all_gather,
all_to_all as all_to_all,
all_to_all_p as all_to_all_p,
axis_index as axis_index,
axis_index_p as axis_index_p,
pmax as pmax,
pmax_p as pmax_p,
pmean as pmean,
pmin as pmin,
pmin_p as pmin_p,
ppermute as ppermute,
ppermute_p as ppermute_p,
pshuffle as pshuffle,
psum as psum,
psum_p as psum_p,
psum_scatter as psum_scatter,
pswapaxes as pswapaxes,
pdot as pdot,
xeinsum as xeinsum,
)
from jax._src.lax.other import (
conv_general_dilated_patches
conv_general_dilated_patches as conv_general_dilated_patches
)
from jax._src.ad_util import stop_gradient_p
from . import linalg
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from . import linalg as linalg

View File

@ -22,7 +22,7 @@ from typing import Optional
__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
]
# First, before attempting to import jaxlib, warn about experimental machine

View File

@ -19,27 +19,27 @@
from jax.numpy import tanh
from . import initializers
from jax._src.nn.functions import (
celu,
elu,
gelu,
glu,
hard_sigmoid,
hard_silu,
hard_swish,
hard_tanh,
leaky_relu,
log_sigmoid,
log_softmax,
logsumexp,
normalize,
one_hot,
relu,
relu6,
selu,
sigmoid,
soft_sign,
softmax,
softplus,
silu,
swish,
celu as celu,
elu as elu,
gelu as gelu,
glu as glu,
hard_sigmoid as hard_sigmoid,
hard_silu as hard_silu,
hard_swish as hard_swish,
hard_tanh as hard_tanh,
leaky_relu as leaky_relu,
log_sigmoid as log_sigmoid,
log_softmax as log_softmax,
logsumexp as logsumexp,
normalize as normalize,
one_hot as one_hot,
relu as relu,
relu6 as relu6,
selu as selu,
sigmoid as sigmoid,
soft_sign as soft_sign,
softmax as softmax,
softplus as softplus,
silu as silu,
swish as swish,
)

View File

@ -19,21 +19,21 @@ used in Keras and Sonnet.
# flake8: noqa: F401
from jax._src.nn.initializers import (
delta_orthogonal,
glorot_normal,
glorot_uniform,
he_normal,
he_uniform,
kaiming_normal,
kaiming_uniform,
lecun_normal,
lecun_uniform,
normal,
ones,
orthogonal,
uniform,
variance_scaling,
xavier_normal,
xavier_uniform,
zeros,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,
he_normal as he_normal,
he_uniform as he_uniform,
kaiming_normal as kaiming_normal,
kaiming_uniform as kaiming_uniform,
lecun_normal as lecun_normal,
lecun_uniform as lecun_uniform,
normal as normal,
ones as ones,
orthogonal as orthogonal,
uniform as uniform,
variance_scaling as variance_scaling,
xavier_normal as xavier_normal,
xavier_uniform as xavier_uniform,
zeros as zeros,
)

View File

@ -12,59 +12,379 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from . import fft
from . import linalg
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax.interpreters.xla import DeviceArray
# flake8: noqa: F401
from . import fft as fft
from . import linalg as linalg
from jax.interpreters.xla import DeviceArray as DeviceArray
from jax._src.numpy.lax_numpy import (
ComplexWarning, NINF, NZERO, PZERO, abs, absolute, add, all, allclose,
alltrue, amax, amin, angle, any, append,
apply_along_axis, apply_over_axes, arange, arccos, arccosh, arcsin,
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, argwhere, around,
array, array_equal, array_equiv, array_repr, array_split, array_str, asarray, atleast_1d, atleast_2d,
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays, broadcast_shapes,
broadcast_to, c_, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack,
complex128, complex64, complex_, complexfloating, compress, concatenate,
conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh,
count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, degrees,
delete, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot,
double, dsplit, dstack, dtype, e, ediff1d, einsum, einsum_path, empty,
empty_like, equal, euler_gamma, exp, exp2, expand_dims, expm1, extract, eye,
fabs, finfo, fix, flatnonzero, flexible, flip, fliplr, flipud, float16, float32,
float64, float_, float_power, floating, floor, floor_divide, fmax, fmin,
fmod, frexp, full, full_like, gcd, geomspace, gradient, greater,
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges, histogram2d, histogramdd,
hsplit, hstack, hypot, i0, identity, iinfo, imag, index_exp,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
interp, intersect1d, invert,
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, lexsort, linspace,
load, log, log10, log1p, log2, logaddexp, logaddexp2, logical_and,
logical_not, logical_or, logical_xor, logspace, mask_indices, matmul, max,
maximum, mean, median, meshgrid, mgrid, min, minimum, mod, modf, moveaxis, msort,
multiply, nan, nan_to_num, nanargmax, nanargmin, nancumprod, nancumsum,
nanmedian, nanpercentile, nanquantile,
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number,
object_, ogrid, ones, ones_like, outer, packbits, pad, percentile,
pi, piecewise, poly, polyadd, polyder, polyfit, polyint, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
r_, rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
resize, result_type, right_shift, rint, roll, rollaxis, rot90, round, round_, row_stack,
save, savez, searchsorted, select, set_printoptions, setdiff1d, setxor1d, shape, sign, signbit,
s_, signedinteger, sin, sinc, single, sinh, size, sometrue, sort, sort_complex, split, sqrt,
square, squeeze, stack, std, subtract, sum, swapaxes, take, take_along_axis,
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
trim_zeros, triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
union1d, unpackbits, unravel_index, unsignedinteger, unwrap, vander, var, vdot, vsplit,
vstack, where, zeros, zeros_like, _NOT_IMPLEMENTED)
ComplexWarning as ComplexWarning,
NINF as NINF,
NZERO as NZERO,
PZERO as PZERO,
abs as abs,
absolute as absolute,
add as add,
all as all,
allclose as allclose,
alltrue as alltrue,
amax as amax,
amin as amin,
angle as angle,
any as any,
append as append,
apply_along_axis as apply_along_axis,
apply_over_axes as apply_over_axes,
arange as arange,
arccos as arccos,
arccosh as arccosh,
arcsin as arcsin,
arcsinh as arcsinh,
arctan as arctan,
arctan2 as arctan2,
arctanh as arctanh,
argmax as argmax,
argmin as argmin,
argsort as argsort,
argwhere as argwhere,
around as around,
array as array,
array_equal as array_equal,
array_equiv as array_equiv,
array_repr as array_repr,
array_split as array_split,
array_str as array_str,
asarray as asarray,
atleast_1d as atleast_1d,
atleast_2d as atleast_2d,
atleast_3d as atleast_3d,
average as average,
bartlett as bartlett,
bfloat16 as bfloat16,
bincount as bincount,
bitwise_and as bitwise_and,
bitwise_not as bitwise_not,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,
blackman as blackman,
block as block,
bool_ as bool_,
broadcast_arrays as broadcast_arrays,
broadcast_shapes as broadcast_shapes,
broadcast_to as broadcast_to,
c_ as c_,
can_cast as can_cast,
cbrt as cbrt,
cdouble as cdouble,
ceil as ceil,
character as character,
choose as choose,
clip as clip,
column_stack as column_stack,
complex128 as complex128,
complex64 as complex64,
complex_ as complex_,
complexfloating as complexfloating,
compress as compress,
concatenate as concatenate,
conj as conj,
conjugate as conjugate,
convolve as convolve,
copysign as copysign,
corrcoef as corrcoef,
correlate as correlate,
cos as cos,
cosh as cosh,
count_nonzero as count_nonzero,
cov as cov,
cross as cross,
csingle as csingle,
cumprod as cumprod,
cumproduct as cumproduct,
cumsum as cumsum,
deg2rad as deg2rad,
degrees as degrees,
delete as delete,
diag as diag,
diagflat as diagflat,
diag_indices as diag_indices,
diag_indices_from as diag_indices_from,
diagonal as diagonal,
diff as diff,
digitize as digitize,
divide as divide,
divmod as divmod,
dot as dot,
double as double,
dsplit as dsplit,
dstack as dstack,
dtype as dtype,
e as e,
ediff1d as ediff1d,
einsum as einsum,
einsum_path as einsum_path,
empty as empty,
empty_like as empty_like,
equal as equal,
euler_gamma as euler_gamma,
exp as exp,
exp2 as exp2,
expand_dims as expand_dims,
expm1 as expm1,
extract as extract,
eye as eye,
fabs as fabs,
finfo as finfo,
fix as fix,
flatnonzero as flatnonzero,
flexible as flexible,
flip as flip,
fliplr as fliplr,
flipud as flipud,
float16 as float16,
float32 as float32,
float64 as float64,
float_ as float_,
float_power as float_power,
floating as floating,
floor as floor,
floor_divide as floor_divide,
fmax as fmax,
fmin as fmin,
fmod as fmod,
frexp as frexp,
full as full,
full_like as full_like,
gcd as gcd,
geomspace as geomspace,
gradient as gradient,
greater as greater,
greater_equal as greater_equal,
hamming as hamming,
hanning as hanning,
heaviside as heaviside,
histogram as histogram,
histogram_bin_edges as histogram_bin_edges,
histogram2d as histogram2d,
histogramdd as histogramdd,
hsplit as hsplit,
hstack as hstack,
hypot as hypot,
i0 as i0,
identity as identity,
iinfo as iinfo,
imag as imag,
index_exp as index_exp,
indices as indices,
inexact as inexact,
in1d as in1d,
inf as inf,
inner as inner,
int16 as int16,
int32 as int32,
int64 as int64,
int8 as int8,
int_ as int_,
integer as integer,
interp as interp,
intersect1d as intersect1d,
invert as invert,
isclose as isclose,
iscomplex as iscomplex,
iscomplexobj as iscomplexobj,
isfinite as isfinite,
isin as isin,
isinf as isinf,
isnan as isnan,
isneginf as isneginf,
isposinf as isposinf,
isreal as isreal,
isrealobj as isrealobj,
isscalar as isscalar,
issubdtype as issubdtype,
issubsctype as issubsctype,
iterable as iterable,
ix_ as ix_,
kaiser as kaiser,
kron as kron,
lcm as lcm,
ldexp as ldexp,
left_shift as left_shift,
less as less,
less_equal as less_equal,
lexsort as lexsort,
linspace as linspace,
load as load,
log as log,
log10 as log10,
log1p as log1p,
log2 as log2,
logaddexp as logaddexp,
logaddexp2 as logaddexp2,
logical_and as logical_and,
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
logspace as logspace,
mask_indices as mask_indices,
matmul as matmul,
max as max,
maximum as maximum,
mean as mean,
median as median,
meshgrid as meshgrid,
mgrid as mgrid,
min as min,
minimum as minimum,
mod as mod,
modf as modf,
moveaxis as moveaxis,
msort as msort,
multiply as multiply,
nan as nan,
nan_to_num as nan_to_num,
nanargmax as nanargmax,
nanargmin as nanargmin,
nancumprod as nancumprod,
nancumsum as nancumsum,
nanmedian as nanmedian,
nanpercentile as nanpercentile,
nanquantile as nanquantile,
nanmax as nanmax,
nanmean as nanmean,
nanmin as nanmin,
nanprod as nanprod,
nanstd as nanstd,
nansum as nansum,
nanvar as nanvar,
ndarray as ndarray,
ndim as ndim,
negative as negative,
newaxis as newaxis,
nextafter as nextafter,
nonzero as nonzero,
not_equal as not_equal,
number as number,
object_ as object_,
ogrid as ogrid,
ones as ones,
ones_like as ones_like,
outer as outer,
packbits as packbits,
pad as pad,
percentile as percentile,
pi as pi,
piecewise as piecewise,
poly as poly,
polyadd as polyadd,
polyder as polyder,
polyfit as polyfit,
polyint as polyint,
polymul as polymul,
polysub as polysub,
polyval as polyval,
positive as positive,
power as power,
prod as prod,
product as product,
promote_types as promote_types,
ptp as ptp,
quantile as quantile,
r_ as r_,
rad2deg as rad2deg,
radians as radians,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
real as real,
reciprocal as reciprocal,
remainder as remainder,
repeat as repeat,
reshape as reshape,
resize as resize,
result_type as result_type,
right_shift as right_shift,
rint as rint,
roll as roll,
rollaxis as rollaxis,
rot90 as rot90,
round as round,
round_ as round_,
row_stack as row_stack,
save as save,
savez as savez,
searchsorted as searchsorted,
select as select,
set_printoptions as set_printoptions,
setdiff1d as setdiff1d,
setxor1d as setxor1d,
shape as shape,
sign as sign,
signbit as signbit,
s_ as s_,
signedinteger as signedinteger,
sin as sin,
sinc as sinc,
single as single,
sinh as sinh,
size as size,
sometrue as sometrue,
sort as sort,
sort_complex as sort_complex,
split as split,
sqrt as sqrt,
square as square,
squeeze as squeeze,
stack as stack,
std as std,
subtract as subtract,
sum as sum,
swapaxes as swapaxes,
take as take,
take_along_axis as take_along_axis,
tan as tan,
tanh as tanh,
tensordot as tensordot,
tile as tile,
trace as trace,
trapz as trapz,
transpose as transpose,
tri as tri,
tril as tril,
tril_indices as tril_indices,
tril_indices_from as tril_indices_from,
trim_zeros as trim_zeros,
triu as triu,
triu_indices as triu_indices,
triu_indices_from as triu_indices_from,
true_divide as true_divide,
trunc as trunc,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
uint8 as uint8,
unique as unique,
union1d as union1d,
unpackbits as unpackbits,
unravel_index as unravel_index,
unsignedinteger as unsignedinteger,
unwrap as unwrap,
vander as vander,
var as var,
vdot as vdot,
vsplit as vsplit,
vstack as vstack,
where as where,
zeros as zeros,
zeros_like as zeros_like,
_NOT_IMPLEMENTED,
)
from jax._src.numpy.polynomial import roots
from jax._src.numpy.vectorize import vectorize
from jax._src.numpy.polynomial import roots as roots
from jax._src.numpy.vectorize import vectorize as vectorize
# TODO(phawkins): remove this import after fixing users.
from jax._src.numpy import lax_numpy

View File

@ -15,24 +15,24 @@
# flake8: noqa: F401
from jax._src.numpy.fft import (
ifft,
ifft2,
ifftn,
ifftshift,
ihfft,
irfft,
irfft2,
irfftn,
fft,
fft2,
fftfreq,
fftn,
fftshift,
hfft,
rfft,
rfft2,
rfftfreq,
rfftn,
ifft as ifft,
ifft2 as ifft2,
ifftn as ifftn,
ifftshift as ifftshift,
ihfft as ihfft,
irfft as irfft,
irfft2 as irfft2,
irfftn as irfftn,
fft as fft,
fft2 as fft2,
fftfreq as fftfreq,
fftn as fftn,
fftshift as fftshift,
hfft as hfft,
rfft as rfft,
rfft2 as rfft2,
rfftfreq as rfftfreq,
rfftn as rfftn,
)
# Module initialization is encapsulated in a function to avoid accidental

View File

@ -15,28 +15,28 @@
# flake8: noqa: F401
from jax._src.numpy.linalg import (
cholesky,
det,
eig,
eigh,
eigvals,
eigvalsh,
inv,
lstsq,
matrix_power,
matrix_rank,
norm,
pinv,
qr,
slogdet,
solve,
svd,
cholesky as cholesky,
det as det,
eig as eig,
eigh as eigh,
eigvals as eigvals,
eigvalsh as eigvalsh,
inv as inv,
lstsq as lstsq,
matrix_power as matrix_power,
matrix_rank as matrix_rank,
norm as norm,
pinv as pinv,
qr as qr,
slogdet as slogdet,
solve as solve,
svd as svd,
)
from jax._src.third_party.numpy.linalg import (
cond,
multi_dot,
tensorinv,
tensorsolve
cond as cond,
multi_dot as multi_dot,
tensorinv as tensorinv,
tensorsolve as tensorsolve,
)
# Module initialization is encapsulated in a function to avoid accidental

View File

@ -14,6 +14,14 @@
# flake8: noqa: F401
from jax._src.ops.scatter import (
index, index_add, index_mul, index_update, index_min, index_max,
segment_sum, segment_prod, segment_min, segment_max,
index as index,
index_add as index_add,
index_mul as index_mul,
index_update as index_update,
index_min as index_min,
index_max as index_max,
segment_sum as segment_sum,
segment_prod as segment_prod,
segment_min as segment_min,
segment_max as segment_max,
)

View File

@ -15,9 +15,9 @@
# flake8: noqa: F401
from jax._src.prng import (
PRNGImpl,
seed_with_impl,
threefry2x32_p,
threefry_2x32,
threefry_prng_impl,
PRNGImpl as PRNGImpl,
seed_with_impl as seed_with_impl,
threefry2x32_p as threefry2x32_p,
threefry_2x32 as threefry_2x32,
threefry_prng_impl as threefry_prng_impl,
)

View File

@ -14,16 +14,16 @@
# flake8: noqa: F401
from jax._src.profiler import (
StepTraceAnnotation,
StepTraceContext,
TraceAnnotation,
TraceContext,
device_memory_profile,
save_device_memory_profile,
start_server,
start_trace,
stop_trace,
trace,
annotate_function,
trace_function,
StepTraceAnnotation as StepTraceAnnotation,
StepTraceContext as StepTraceContext,
TraceAnnotation as TraceAnnotation,
TraceContext as TraceContext,
device_memory_profile as device_memory_profile,
save_device_memory_profile as save_device_memory_profile,
start_server as start_server,
start_trace as start_trace,
stop_trace as stop_trace,
trace as trace,
annotate_function as annotate_function,
trace_function as trace_function,
)

View File

@ -79,35 +79,35 @@ for the design and its motivation.
# flake8: noqa: F401
from jax._src.random import (
PRNGKey,
bernoulli,
beta,
categorical,
cauchy,
choice,
dirichlet,
double_sided_maxwell,
exponential,
fold_in,
gamma,
gumbel,
laplace,
logistic,
maxwell,
multivariate_normal,
normal,
pareto,
permutation,
poisson,
rademacher,
randint,
random_gamma_p,
shuffle,
split,
t,
threefry_2x32,
threefry2x32_p,
truncated_normal,
uniform,
weibull_min,
PRNGKey as PRNGKey,
bernoulli as bernoulli,
beta as beta,
categorical as categorical,
cauchy as cauchy,
choice as choice,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,
fold_in as fold_in,
gamma as gamma,
gumbel as gumbel,
laplace as laplace,
logistic as logistic,
maxwell as maxwell,
multivariate_normal as multivariate_normal,
normal as normal,
pareto as pareto,
permutation as permutation,
poisson as poisson,
rademacher as rademacher,
randint as randint,
random_gamma_p as random_gamma_p,
shuffle as shuffle,
split as split,
t as t,
threefry_2x32 as threefry_2x32,
threefry2x32_p as threefry2x32_p,
truncated_normal as truncated_normal,
uniform as uniform,
weibull_min as weibull_min,
)

View File

@ -13,10 +13,10 @@
# limitations under the License.
# flake8: noqa: F401
from . import linalg
from . import ndimage
from . import signal
from . import sparse
from . import special
from . import stats
from . import fft
from . import linalg as linalg
from . import ndimage as ndimage
from . import signal as signal
from . import sparse as sparse
from . import special as special
from . import stats as stats
from . import fft as fft

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.fft import (
dct,
dctn
dct as dct,
dctn as dctn,
)

View File

@ -15,28 +15,28 @@
# flake8: noqa: F401
from jax._src.scipy.linalg import (
block_diag,
cholesky,
cho_factor,
cho_solve,
det,
eigh,
eigh_tridiagonal,
expm,
expm_frechet,
inv,
lu,
lu_factor,
lu_solve,
polar,
qr,
solve,
solve_triangular,
svd,
tril,
triu,
block_diag as block_diag,
cholesky as cholesky,
cho_factor as cho_factor,
cho_solve as cho_solve,
det as det,
eigh as eigh,
eigh_tridiagonal as eigh_tridiagonal,
expm as expm,
expm_frechet as expm_frechet,
inv as inv,
lu as lu,
lu_factor as lu_factor,
lu_solve as lu_solve,
polar as polar,
qr as qr,
solve as solve,
solve_triangular as solve_triangular,
svd as svd,
tril as tril,
triu as triu,
)
from jax._src.lax.polar import (
polar_unitary
polar_unitary as polar_unitary,
)

View File

@ -14,4 +14,6 @@
# flake8: noqa: F401
from jax._src.scipy.ndimage import map_coordinates
from jax._src.scipy.ndimage import (
map_coordinates as map_coordinates,
)

View File

@ -14,4 +14,7 @@
# flake8: noqa: F401
from jax._src.scipy.optimize.minimize import (minimize, OptimizeResults)
from jax._src.scipy.optimize.minimize import (
minimize as minimize,
OptimizeResults as OptimizeResults,
)

View File

@ -15,9 +15,9 @@
# flake8: noqa: F401
from jax._src.scipy.signal import (
convolve,
convolve2d,
correlate,
correlate2d,
detrend,
convolve as convolve,
convolve2d as convolve2d,
correlate as correlate,
correlate2d as correlate2d,
detrend as detrend,
)

View File

@ -13,4 +13,4 @@
# limitations under the License.
# flake8: noqa: F401
from . import linalg
from . import linalg as linalg

View File

@ -14,7 +14,7 @@
# flake8: noqa: F401
from jax._src.scipy.sparse.linalg import (
cg,
gmres,
bicgstab
cg as cg,
gmres as gmres,
bicgstab as bicgstab,
)

View File

@ -15,35 +15,35 @@
# flake8: noqa: F401
from jax._src.scipy.special import (
betainc,
betaln,
digamma,
entr,
erf,
erfc,
erfinv,
exp1,
expi,
expit,
expn,
gammainc,
gammaincc,
gammaln,
i0,
i0e,
i1,
i1e,
logit,
logsumexp,
lpmn,
lpmn_values,
multigammaln,
log_ndtr,
ndtr,
ndtri,
polygamma,
sph_harm,
xlogy,
xlog1py,
zeta,
betainc as betainc,
betaln as betaln,
digamma as digamma,
entr as entr,
erf as erf,
erfc as erfc,
erfinv as erfinv,
exp1 as exp1,
expi as expi,
expit as expit,
expn as expn,
gammainc as gammainc,
gammaincc as gammaincc,
gammaln as gammaln,
i0 as i0,
i0e as i0e,
i1 as i1,
i1e as i1e,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
multigammaln as multigammaln,
log_ndtr as log_ndtr,
ndtr as ndtr,
ndtri as ndtri,
polygamma as polygamma,
sph_harm as sph_harm,
xlogy as xlogy,
xlog1py as xlog1py,
zeta as zeta,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.bernoulli import (
logpmf,
pmf,
)
logpmf as logpmf,
pmf as pmf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.beta import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.betabinom import (
logpmf,
pmf,
logpmf as logpmf,
pmf as pmf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.cauchy import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.chi2 import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.dirichlet import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.expon import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.gamma import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.geom import (
logpmf,
pmf,
)
logpmf as logpmf,
pmf as pmf,
)

View File

@ -15,7 +15,7 @@
# flake8: noqa: F401
from jax._src.scipy.stats.laplace import (
cdf,
logpdf,
pdf,
)
cdf as cdf,
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,10 +15,10 @@
# flake8: noqa: F401
from jax._src.scipy.stats.logistic import (
cdf,
isf,
logpdf,
pdf,
ppf,
sf,
)
cdf as cdf,
isf as isf,
logpdf as logpdf,
pdf as pdf,
ppf as ppf,
sf as sf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.multivariate_normal import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,9 +15,9 @@
# flake8: noqa: F401
from jax._src.scipy.stats.norm import (
cdf,
logcdf,
logpdf,
pdf,
ppf,
)
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
pdf as pdf,
ppf as ppf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.pareto import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,7 +15,7 @@
# flake8: noqa: F401
from jax._src.scipy.stats.poisson import (
logpmf,
pmf,
cdf
logpmf as logpmf,
pmf as pmf,
cdf as cdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.t import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -15,6 +15,6 @@
# flake8: noqa: F401
from jax._src.scipy.stats.uniform import (
logpdf,
pdf,
)
logpdf as logpdf,
pdf as pdf,
)

View File

@ -37,22 +37,22 @@ for examples.
# flake8: noqa: F401
from jax._src.tree_util import (
Partial,
PyTreeDef,
all_leaves,
build_tree,
register_pytree_node,
register_pytree_node_class,
tree_all,
tree_flatten,
tree_leaves,
tree_map,
tree_multimap,
tree_reduce,
tree_structure,
tree_transpose,
tree_unflatten,
treedef_children,
treedef_is_leaf,
treedef_tuple,
Partial as Partial,
PyTreeDef as PyTreeDef,
all_leaves as all_leaves,
build_tree as build_tree,
register_pytree_node as register_pytree_node,
register_pytree_node_class as register_pytree_node_class,
tree_all as tree_all,
tree_flatten as tree_flatten,
tree_leaves as tree_leaves,
tree_map as tree_map,
tree_multimap as tree_multimap,
tree_reduce as tree_reduce,
tree_structure as tree_structure,
tree_transpose as tree_transpose,
tree_unflatten as tree_unflatten,
treedef_children as treedef_children,
treedef_is_leaf as treedef_is_leaf,
treedef_tuple as treedef_tuple,
)

View File

@ -14,18 +14,18 @@
# flake8: noqa: F401
from jax._src.util import (
HashableFunction,
as_hashable_function,
cache,
partial,
safe_map,
safe_zip,
split_dict,
split_list,
split_merge,
subvals,
toposort,
unzip2,
wrap_name,
wraps,
HashableFunction as HashableFunction,
as_hashable_function as as_hashable_function,
cache as cache,
partial as partial,
safe_map as safe_map,
safe_zip as safe_zip,
split_dict as split_dict,
split_list as split_list,
split_merge as split_merge,
subvals as subvals,
toposort as toposort,
unzip2 as unzip2,
wrap_name as wrap_name,
wraps as wraps,
)