mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00

In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it requires the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases. The core change is to update the `dot_general` lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user-friendly to include automatic casts in all cases where a specific algorithm is requested. Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (whose type is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match what a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected. To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile-time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`. Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this. One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short-term approach.
PiperOrigin-RevId: 683302687
393 lines
11 KiB
Python
393 lines
11 KiB
Python
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.lax.lax import (
|
|
DotDimensionNumbers as DotDimensionNumbers,
|
|
Precision as Precision,
|
|
PrecisionLike as PrecisionLike,
|
|
DotAlgorithm as DotAlgorithm,
|
|
DotAlgorithmPreset as DotAlgorithmPreset,
|
|
RandomAlgorithm as RandomAlgorithm,
|
|
RoundingMethod as RoundingMethod,
|
|
abs as abs,
|
|
abs_p as abs_p,
|
|
acos as acos,
|
|
acos_p as acos_p,
|
|
acosh as acosh,
|
|
acosh_p as acosh_p,
|
|
add as add,
|
|
add_p as add_p,
|
|
after_all as after_all,
|
|
after_all_p as after_all_p,
|
|
and_p as and_p,
|
|
argmax as argmax,
|
|
argmax_p as argmax_p,
|
|
argmin as argmin,
|
|
argmin_p as argmin_p,
|
|
asin as asin,
|
|
asin_p as asin_p,
|
|
asinh as asinh,
|
|
asinh_p as asinh_p,
|
|
atan as atan,
|
|
atan_p as atan_p,
|
|
atan2 as atan2,
|
|
atan2_p as atan2_p,
|
|
atanh as atanh,
|
|
atanh_p as atanh_p,
|
|
batch_matmul as batch_matmul,
|
|
bitcast_convert_type as bitcast_convert_type,
|
|
bitcast_convert_type_p as bitcast_convert_type_p,
|
|
bitwise_and as bitwise_and,
|
|
bitwise_not as bitwise_not,
|
|
bitwise_or as bitwise_or,
|
|
bitwise_xor as bitwise_xor,
|
|
broadcast as broadcast,
|
|
broadcast_in_dim as broadcast_in_dim,
|
|
broadcast_in_dim_p as broadcast_in_dim_p,
|
|
broadcast_shapes as broadcast_shapes,
|
|
broadcast_to_rank as broadcast_to_rank,
|
|
broadcasted_iota as broadcasted_iota,
|
|
cbrt as cbrt,
|
|
cbrt_p as cbrt_p,
|
|
ceil as ceil,
|
|
ceil_p as ceil_p,
|
|
clamp as clamp,
|
|
clamp_p as clamp_p,
|
|
clz as clz,
|
|
clz_p as clz_p,
|
|
collapse as collapse,
|
|
complex as complex,
|
|
complex_p as complex_p,
|
|
concatenate as concatenate,
|
|
concatenate_p as concatenate_p,
|
|
conj as conj,
|
|
conj_p as conj_p,
|
|
convert_element_type as convert_element_type,
|
|
convert_element_type_p as convert_element_type_p,
|
|
copy_p as copy_p,
|
|
cos as cos,
|
|
cos_p as cos_p,
|
|
cosh as cosh,
|
|
cosh_p as cosh_p,
|
|
create_token as create_token,
|
|
create_token_p as create_token_p,
|
|
div as div,
|
|
div_p as div_p,
|
|
dot as dot,
|
|
dot_general as dot_general,
|
|
dot_general_p as dot_general_p,
|
|
dtype as dtype,
|
|
eq as eq,
|
|
eq_p as eq_p,
|
|
eq_to_p as eq_to_p,
|
|
exp as exp,
|
|
exp_p as exp_p,
|
|
exp2 as exp2,
|
|
exp2_p as exp2_p,
|
|
expand_dims as expand_dims,
|
|
expm1 as expm1,
|
|
expm1_p as expm1_p,
|
|
floor as floor,
|
|
floor_p as floor_p,
|
|
full as full,
|
|
full_like as full_like,
|
|
ge as ge,
|
|
ge_p as ge_p,
|
|
gt as gt,
|
|
gt_p as gt_p,
|
|
imag as imag,
|
|
imag_p as imag_p,
|
|
infeed as infeed,
|
|
infeed_p as infeed_p,
|
|
integer_pow as integer_pow,
|
|
integer_pow_p as integer_pow_p,
|
|
iota as iota,
|
|
iota_p as iota_p,
|
|
is_finite as is_finite,
|
|
is_finite_p as is_finite_p,
|
|
le as le,
|
|
le_p as le_p,
|
|
le_to_p as le_to_p,
|
|
log as log,
|
|
log1p as log1p,
|
|
log1p_p as log1p_p,
|
|
log_p as log_p,
|
|
logistic as logistic,
|
|
logistic_p as logistic_p,
|
|
lt as lt,
|
|
lt_p as lt_p,
|
|
lt_to_p as lt_to_p,
|
|
max as max,
|
|
max_p as max_p,
|
|
min as min,
|
|
min_p as min_p,
|
|
mul as mul,
|
|
mul_p as mul_p,
|
|
ne as ne,
|
|
ne_p as ne_p,
|
|
neg as neg,
|
|
neg_p as neg_p,
|
|
nextafter as nextafter,
|
|
nextafter_p as nextafter_p,
|
|
not_p as not_p,
|
|
optimization_barrier as optimization_barrier,
|
|
optimization_barrier_p as optimization_barrier_p,
|
|
or_p as or_p,
|
|
outfeed as outfeed,
|
|
outfeed_p as outfeed_p,
|
|
pad as pad,
|
|
pad_p as pad_p,
|
|
padtype_to_pads as padtype_to_pads,
|
|
population_count as population_count,
|
|
population_count_p as population_count_p,
|
|
pow as pow,
|
|
pow_p as pow_p,
|
|
ragged_dot as ragged_dot,
|
|
real as real,
|
|
real_p as real_p,
|
|
reciprocal as reciprocal,
|
|
reduce as reduce,
|
|
reduce_and_p as reduce_and_p,
|
|
reduce_max_p as reduce_max_p,
|
|
reduce_min_p as reduce_min_p,
|
|
reduce_or_p as reduce_or_p,
|
|
reduce_p as reduce_p,
|
|
reduce_precision as reduce_precision,
|
|
reduce_precision_p as reduce_precision_p,
|
|
reduce_prod_p as reduce_prod_p,
|
|
reduce_sum_p as reduce_sum_p,
|
|
reduce_xor_p as reduce_xor_p,
|
|
rem as rem,
|
|
rem_p as rem_p,
|
|
reshape as reshape,
|
|
reshape_p as reshape_p,
|
|
rev as rev,
|
|
rev_p as rev_p,
|
|
rng_bit_generator as rng_bit_generator,
|
|
rng_bit_generator_p as rng_bit_generator_p,
|
|
rng_uniform as rng_uniform,
|
|
rng_uniform_p as rng_uniform_p,
|
|
round as round,
|
|
round_p as round_p,
|
|
rsqrt as rsqrt,
|
|
rsqrt_p as rsqrt_p,
|
|
select as select,
|
|
select_n as select_n,
|
|
select_n_p as select_n_p,
|
|
shift_left as shift_left,
|
|
shift_left_p as shift_left_p,
|
|
shift_right_arithmetic as shift_right_arithmetic,
|
|
shift_right_arithmetic_p as shift_right_arithmetic_p,
|
|
shift_right_logical as shift_right_logical,
|
|
shift_right_logical_p as shift_right_logical_p,
|
|
sign as sign,
|
|
sign_p as sign_p,
|
|
sin as sin,
|
|
sin_p as sin_p,
|
|
sinh as sinh,
|
|
sinh_p as sinh_p,
|
|
sort as sort,
|
|
sort_key_val as sort_key_val,
|
|
sort_p as sort_p,
|
|
sqrt as sqrt,
|
|
sqrt_p as sqrt_p,
|
|
square as square,
|
|
squeeze as squeeze,
|
|
squeeze_p as squeeze_p,
|
|
stop_gradient as stop_gradient,
|
|
sub as sub,
|
|
sub_p as sub_p,
|
|
tan as tan,
|
|
tan_p as tan_p,
|
|
tanh as tanh,
|
|
tanh_p as tanh_p,
|
|
top_k as top_k,
|
|
top_k_p as top_k_p,
|
|
transpose as transpose,
|
|
transpose_p as transpose_p,
|
|
xor_p as xor_p,
|
|
zeros_like_array as zeros_like_array,
|
|
)
|
|
from jax._src.lax.special import (
|
|
bessel_i0e as bessel_i0e,
|
|
bessel_i0e_p as bessel_i0e_p,
|
|
bessel_i1e as bessel_i1e,
|
|
bessel_i1e_p as bessel_i1e_p,
|
|
betainc as betainc,
|
|
digamma as digamma,
|
|
digamma_p as digamma_p,
|
|
erf as erf,
|
|
erfc as erfc,
|
|
erfc_p as erfc_p,
|
|
erf_inv as erf_inv,
|
|
erf_inv_p as erf_inv_p,
|
|
erf_p as erf_p,
|
|
igamma as igamma,
|
|
igammac as igammac,
|
|
igammac_p as igammac_p,
|
|
igamma_grad_a as igamma_grad_a,
|
|
igamma_grad_a_p as igamma_grad_a_p,
|
|
igamma_p as igamma_p,
|
|
lgamma as lgamma,
|
|
lgamma_p as lgamma_p,
|
|
polygamma as polygamma,
|
|
polygamma_p as polygamma_p,
|
|
random_gamma_grad as random_gamma_grad,
|
|
random_gamma_grad_p as random_gamma_grad_p,
|
|
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
|
zeta as zeta,
|
|
zeta_p as zeta_p,
|
|
)
|
|
from jax._src.lax.slicing import (
|
|
GatherDimensionNumbers as GatherDimensionNumbers,
|
|
GatherScatterMode as GatherScatterMode,
|
|
ScatterDimensionNumbers as ScatterDimensionNumbers,
|
|
dynamic_index_in_dim as dynamic_index_in_dim,
|
|
dynamic_slice as dynamic_slice,
|
|
dynamic_slice_in_dim as dynamic_slice_in_dim,
|
|
dynamic_slice_p as dynamic_slice_p,
|
|
dynamic_update_index_in_dim as dynamic_update_index_in_dim,
|
|
dynamic_update_slice as dynamic_update_slice,
|
|
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
|
|
dynamic_update_slice_p as dynamic_update_slice_p,
|
|
gather as gather,
|
|
gather_p as gather_p,
|
|
index_in_dim as index_in_dim,
|
|
index_take as index_take,
|
|
scatter as scatter,
|
|
scatter_apply as scatter_apply,
|
|
scatter_add as scatter_add,
|
|
scatter_add_p as scatter_add_p,
|
|
scatter_max as scatter_max,
|
|
scatter_max_p as scatter_max_p,
|
|
scatter_min as scatter_min,
|
|
scatter_min_p as scatter_min_p,
|
|
scatter_mul as scatter_mul,
|
|
scatter_mul_p as scatter_mul_p,
|
|
scatter_p as scatter_p,
|
|
scatter_sub as scatter_sub,
|
|
scatter_sub_p as scatter_sub_p,
|
|
slice as slice,
|
|
slice_in_dim as slice_in_dim,
|
|
slice_p as slice_p,
|
|
)
|
|
from jax._src.lax.convolution import (
|
|
ConvDimensionNumbers as ConvDimensionNumbers,
|
|
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
|
|
conv as conv,
|
|
conv_dimension_numbers as conv_dimension_numbers,
|
|
conv_general_dilated as conv_general_dilated,
|
|
conv_general_dilated_p as conv_general_dilated_p,
|
|
conv_general_permutations as conv_general_permutations,
|
|
conv_general_shape_tuple as conv_general_shape_tuple,
|
|
conv_shape_tuple as conv_shape_tuple,
|
|
conv_transpose as conv_transpose,
|
|
conv_transpose_shape_tuple as conv_transpose_shape_tuple,
|
|
conv_with_general_padding as conv_with_general_padding,
|
|
)
|
|
from jax._src.lax.windowed_reductions import (
|
|
reduce_window as reduce_window,
|
|
reduce_window_max_p as reduce_window_max_p,
|
|
reduce_window_min_p as reduce_window_min_p,
|
|
reduce_window_p as reduce_window_p,
|
|
reduce_window_shape_tuple as reduce_window_shape_tuple,
|
|
reduce_window_sum_p as reduce_window_sum_p,
|
|
select_and_gather_add_p as select_and_gather_add_p,
|
|
select_and_scatter_p as select_and_scatter_p,
|
|
select_and_scatter_add_p as select_and_scatter_add_p,
|
|
)
|
|
from jax._src.lax.control_flow import (
|
|
associative_scan as associative_scan,
|
|
cond as cond,
|
|
cond_p as cond_p,
|
|
cumlogsumexp as cumlogsumexp,
|
|
cumlogsumexp_p as cumlogsumexp_p,
|
|
cummax as cummax,
|
|
cummax_p as cummax_p,
|
|
cummin as cummin,
|
|
cummin_p as cummin_p,
|
|
cumprod as cumprod,
|
|
cumprod_p as cumprod_p,
|
|
cumsum as cumsum,
|
|
cumsum_p as cumsum_p,
|
|
custom_linear_solve as custom_linear_solve,
|
|
custom_root as custom_root,
|
|
fori_loop as fori_loop,
|
|
linear_solve_p as linear_solve_p,
|
|
map as map,
|
|
scan as scan,
|
|
scan_bind as scan_bind,
|
|
scan_p as scan_p,
|
|
switch as switch,
|
|
while_loop as while_loop,
|
|
while_p as while_p,
|
|
platform_dependent as platform_dependent,
|
|
)
|
|
from jax._src.lax.fft import (
|
|
fft as fft,
|
|
fft_p as fft_p,
|
|
)
|
|
from jax._src.lax.parallel import (
|
|
all_gather as all_gather,
|
|
all_gather_p as all_gather_p,
|
|
all_to_all as all_to_all,
|
|
all_to_all_p as all_to_all_p,
|
|
axis_index as axis_index,
|
|
axis_index_p as axis_index_p,
|
|
pbroadcast as pbroadcast,
|
|
pmax as pmax,
|
|
pmax_p as pmax_p,
|
|
pmean as pmean,
|
|
pmin as pmin,
|
|
pmin_p as pmin_p,
|
|
ppermute as ppermute,
|
|
ppermute_p as ppermute_p,
|
|
pshuffle as pshuffle,
|
|
psum as psum,
|
|
psum_p as psum_p,
|
|
psum_scatter as psum_scatter,
|
|
pswapaxes as pswapaxes,
|
|
)
|
|
from jax._src.lax.other import (
|
|
conv_general_dilated_local as conv_general_dilated_local,
|
|
conv_general_dilated_patches as conv_general_dilated_patches
|
|
)
|
|
from jax._src.lax.ann import (
|
|
approx_max_k as approx_max_k,
|
|
approx_min_k as approx_min_k,
|
|
approx_top_k_p as approx_top_k_p
|
|
)
|
|
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
|
from jax.lax import linalg as linalg
|
|
|
|
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
|
|
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
|
|
from jax._src.dispatch import device_put_p as device_put_p
|
|
|
|
|
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

# Table of deprecated public names for this module. Each key is an attribute
# name and each value is a `(message, value)` pair that is handed to
# `deprecation_getattr` below.
# NOTE(review): the second element is None for every entry here; presumably
# that marks a finalized deprecation for which attribute access fails with
# the message -- confirm against jax._src.deprecations.
# NOTE(review): the "remove after" date on the entry below has already
# passed; confirm whether the entry can now be dropped.
_deprecations = {
  # Finalized 2024-05-13; remove after 2024-08-13
  "tie_in": (
    "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
    "Replace z = tie_in(x, y) with z = y.",
    None,
  ),
}

# Install a module-level `__getattr__` (PEP 562) built from the table above,
# then drop the helper so it does not linger in this module's namespace.
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
|