diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index cfcaeb84f..442ee1b6e 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2038,7 +2038,7 @@ "outputs": [], "source": [ "def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n", - " ) -> xe.XlaOp:\n", + " ) -> List[xe.XlaOp]:\n", " env: Dict[Var, xe.XlaOp] = {}\n", "\n", " def read(x: Atom) -> xe.XlaOp:\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 130f1994e..3bad5aed0 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1598,7 +1598,7 @@ compiled program: ```{code-cell} def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] - ) -> xe.XlaOp: + ) -> List[xe.XlaOp]: env: Dict[Var, xe.XlaOp] = {} def read(x: Atom) -> xe.XlaOp: diff --git a/docs/autodidax.py b/docs/autodidax.py index 2a1e2ceee..6c42c349f 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1592,7 +1592,7 @@ def _xla_shape(aval: ShapedArray) -> xe.Shape: # + def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] - ) -> xe.XlaOp: + ) -> List[xe.XlaOp]: env: Dict[Var, xe.XlaOp] = {} def read(x: Atom) -> xe.XlaOp: