mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix Initializer protocol
This commit is contained in:
parent
99a12ef9ea
commit
5acfc88a00
@ -46,8 +46,8 @@ RealNumeric = Any # Scalar jnp array or float
|
||||
@export
|
||||
@typing.runtime_checkable
|
||||
class Initializer(Protocol):
|
||||
@staticmethod
|
||||
def __call__(key: Array,
|
||||
def __call__(self,
|
||||
key: Array,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Array:
|
||||
raise NotImplementedError
|
||||
|
Loading…
x
Reference in New Issue
Block a user