Fix lax imports

This commit is contained in:
Sharad Vikram 2022-09-26 17:29:08 -07:00
parent 82636b0bcd
commit 1d895b2c85

View File

@ -36,6 +36,8 @@ from jax._src import util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
@ -60,12 +62,12 @@ def nothing_saveable(*_, **__) -> bool:
def checkpoint_dots(prim, *_, **__) -> bool:
# Matrix multiplies are expensive, so let's save them (and nothing else).
return prim in {lax.lax.dot_general_p,
lax.convolution.conv_general_dilated_p}
return prim in {lax_internal.dot_general_p,
lax_convolution.conv_general_dilated_p}
def dot_with_no_batch_dims(prim, *_, **params) -> bool:
# This is a useful heuristic for transformers.
if prim is lax.lax.dot_general_p:
if prim is lax_internal.dot_general_p:
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
if not lhs_b and not rhs_b:
return True
@ -439,8 +441,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
ad.primitive_jvps[remat_p] = remat_jvp
remat_allowed_effects: Set[core.Effect] = set()
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed)
def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars