Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type

PiperOrigin-RevId: 519745476
This commit is contained in:
jax authors 2023-03-27 09:52:20 -07:00
commit af4d4943a7
3 changed files with 3 additions and 3 deletions

View File

@ -2038,7 +2038,7 @@
"outputs": [],
"source": [
"def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n",
" ) -> xe.XlaOp:\n",
" ) -> List[xe.XlaOp]:\n",
" env: Dict[Var, xe.XlaOp] = {}\n",
"\n",
" def read(x: Atom) -> xe.XlaOp:\n",

View File

@ -1598,7 +1598,7 @@ compiled program:
```{code-cell}
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
) -> xe.XlaOp:
) -> List[xe.XlaOp]:
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:

View File

@ -1592,7 +1592,7 @@ def _xla_shape(aval: ShapedArray) -> xe.Shape:
# +
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
) -> xe.XlaOp:
) -> List[xe.XlaOp]:
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp: