[jax] move jax.core to jax._src.core

Re-export roughly all of the same symbols via `jax.core` for now.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 495766963
This commit is contained in:
Roy Frostig 2022-12-15 20:34:43 -08:00 committed by jax authors
parent ecaa215043
commit 523c6f7a53
4 changed files with 3179 additions and 2903 deletions

2904
jax/_src/core.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -2789,7 +2789,7 @@ def _broadcast_in_dim_typecheck_rule(
else:
# TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
out_shape = _merge_dyn_shape(shape, dyn_shape)
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
operand.aval.weak_type)
return [out_aval], core.no_effects
@ -3273,7 +3273,7 @@ def _reshape_typecheck_rule(operand, *dyn_shape, new_sizes, dimensions):
else:
# TODO(mattjj, necula): perform more checks like _reshape_shape_rule
out_shape = _merge_dyn_shape(new_sizes, dyn_shape)
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
operand.aval.weak_type)
return [out_aval], core.no_effects
@ -4511,7 +4511,7 @@ def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension):
return [out_aval], effects
else:
out_shape = _merge_dyn_shape(shape, dyn_shape)
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
out_aval = core.DShapedArray(tuple(out_shape), dtype, False)
return [out_aval], core.no_effects
core.custom_typechecks[iota_p] = _iota_typecheck_rule

File diff suppressed because it is too large Load Diff

View File

@ -714,7 +714,7 @@ class JaxprTracer(Tracer):
if self.pval.is_known():
return get_referent(self.pval.get_known())
elif isinstance(self.recipe, (FreeVar, ConstVar, Literal)):
return get_referent(self.recipe.val)
return get_referent(self.recipe.val) # pytype: disable=attribute-error
else:
return self
@ -2406,11 +2406,11 @@ def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool:
if isinstance(d, core.DArray):
assert not d.shape
return type(d.dtype) is core.bint
assert not d.shape # pytype: disable=attribute-error
return type(d.dtype) is core.bint # pytype: disable=attribute-error
elif isinstance(d, core.Var):
return (isinstance(d.aval, core.DShapedArray) and
type(d.aval.dtype) is core.bint)
return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error
type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error
return False