mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix lax imports
This commit is contained in:
parent
82636b0bcd
commit
1d895b2c85
@ -36,6 +36,8 @@ from jax._src import util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.api_util import flatten_fun, shaped_abstractify
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import convolution as lax_convolution
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
|
||||
@ -60,12 +62,12 @@ def nothing_saveable(*_, **__) -> bool:
|
||||
|
||||
def checkpoint_dots(prim, *_, **__) -> bool:
|
||||
# Matrix multiplies are expensive, so let's save them (and nothing else).
|
||||
return prim in {lax.lax.dot_general_p,
|
||||
lax.convolution.conv_general_dilated_p}
|
||||
return prim in {lax_internal.dot_general_p,
|
||||
lax_convolution.conv_general_dilated_p}
|
||||
|
||||
def dot_with_no_batch_dims(prim, *_, **params) -> bool:
|
||||
# This is a useful heuristic for transformers.
|
||||
if prim is lax.lax.dot_general_p:
|
||||
if prim is lax_internal.dot_general_p:
|
||||
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
|
||||
if not lhs_b and not rhs_b:
|
||||
return True
|
||||
@ -439,8 +441,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
|
||||
ad.primitive_jvps[remat_p] = remat_jvp
|
||||
|
||||
remat_allowed_effects: Set[core.Effect] = set()
|
||||
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
|
||||
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)
|
||||
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed)
|
||||
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed)
|
||||
|
||||
def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
assert not jaxpr.constvars
|
||||
|
Loading…
x
Reference in New Issue
Block a user