mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix the latest jax jaxlib on pypi failure
PiperOrigin-RevId: 507208172
This commit is contained in:
parent
25673316bd
commit
a30ba83db2
@ -1492,7 +1492,9 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
if not config.jax_array:
|
||||
# TODO(yashkatariya): After xla_extension_version is bumped to >= 123, make
|
||||
# this condition: `if not config.jax_array`.
|
||||
if (resource_env is not None and xla_extension_version < 123) or not config.jax_array:
|
||||
if num_residuals:
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
known_params['in_positional_semantics'], known_params['in_shardings'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user