Fix the latest jax jaxlib on pypi failure

PiperOrigin-RevId: 507208172
This commit is contained in:
Yash Katariya 2023-02-04 20:15:56 -08:00 committed by jax authors
parent 25673316bd
commit a30ba83db2

View File

@ -1492,7 +1492,9 @@ def _pjit_partial_eval(trace, *in_tracers,
keep_unused=keep_unused,
inline=inline)
if not config.jax_array:
# TODO(yashkatariya): After xla_extension_version is bumped to >= 123, make
# this condition: `if not config.jax_array`.
if (resource_env is not None and xla_extension_version < 123) or not config.jax_array:
if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_shardings'])