mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[jax] move jax.core
to jax._src.core
Re-export roughly all of the same symbols via `jax.core` for now. Co-authored-by: Sharad Vikram <sharadmv@google.com> PiperOrigin-RevId: 495766963
This commit is contained in:
parent
ecaa215043
commit
523c6f7a53
2904
jax/_src/core.py
Normal file
2904
jax/_src/core.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -2789,7 +2789,7 @@ def _broadcast_in_dim_typecheck_rule(
|
||||
else:
|
||||
# TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
|
||||
out_shape = _merge_dyn_shape(shape, dyn_shape)
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
|
||||
out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
|
||||
operand.aval.weak_type)
|
||||
return [out_aval], core.no_effects
|
||||
@ -3273,7 +3273,7 @@ def _reshape_typecheck_rule(operand, *dyn_shape, new_sizes, dimensions):
|
||||
else:
|
||||
# TODO(mattjj, necula): perform more checks like _reshape_shape_rule
|
||||
out_shape = _merge_dyn_shape(new_sizes, dyn_shape)
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
|
||||
out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
|
||||
operand.aval.weak_type)
|
||||
return [out_aval], core.no_effects
|
||||
@ -4511,7 +4511,7 @@ def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension):
|
||||
return [out_aval], effects
|
||||
else:
|
||||
out_shape = _merge_dyn_shape(shape, dyn_shape)
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
|
||||
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
|
||||
out_aval = core.DShapedArray(tuple(out_shape), dtype, False)
|
||||
return [out_aval], core.no_effects
|
||||
core.custom_typechecks[iota_p] = _iota_typecheck_rule
|
||||
|
3162
jax/core.py
3162
jax/core.py
File diff suppressed because it is too large
Load Diff
@ -714,7 +714,7 @@ class JaxprTracer(Tracer):
|
||||
if self.pval.is_known():
|
||||
return get_referent(self.pval.get_known())
|
||||
elif isinstance(self.recipe, (FreeVar, ConstVar, Literal)):
|
||||
return get_referent(self.recipe.val)
|
||||
return get_referent(self.recipe.val) # pytype: disable=attribute-error
|
||||
else:
|
||||
return self
|
||||
|
||||
@ -2406,11 +2406,11 @@ def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
|
||||
|
||||
def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool:
|
||||
if isinstance(d, core.DArray):
|
||||
assert not d.shape
|
||||
return type(d.dtype) is core.bint
|
||||
assert not d.shape # pytype: disable=attribute-error
|
||||
return type(d.dtype) is core.bint # pytype: disable=attribute-error
|
||||
elif isinstance(d, core.Var):
|
||||
return (isinstance(d.aval, core.DShapedArray) and
|
||||
type(d.aval.dtype) is core.bint)
|
||||
return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error
|
||||
type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error
|
||||
return False
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user