mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type
PiperOrigin-RevId: 519745476
This commit is contained in:
commit
af4d4943a7
@ -2038,7 +2038,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n",
|
||||
" ) -> xe.XlaOp:\n",
|
||||
" ) -> List[xe.XlaOp]:\n",
|
||||
" env: Dict[Var, xe.XlaOp] = {}\n",
|
||||
"\n",
|
||||
" def read(x: Atom) -> xe.XlaOp:\n",
|
||||
|
@ -1598,7 +1598,7 @@ compiled program:
|
||||
|
||||
```{code-cell}
|
||||
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
|
||||
) -> xe.XlaOp:
|
||||
) -> List[xe.XlaOp]:
|
||||
env: Dict[Var, xe.XlaOp] = {}
|
||||
|
||||
def read(x: Atom) -> xe.XlaOp:
|
||||
|
@ -1592,7 +1592,7 @@ def _xla_shape(aval: ShapedArray) -> xe.Shape:
|
||||
|
||||
# +
|
||||
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
|
||||
) -> xe.XlaOp:
|
||||
) -> List[xe.XlaOp]:
|
||||
env: Dict[Var, xe.XlaOp] = {}
|
||||
|
||||
def read(x: Atom) -> xe.XlaOp:
|
||||
|
Loading…
x
Reference in New Issue
Block a user