mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Pass x
into transpose in autodidax
This commit is contained in:
parent
d349086ca5
commit
0fa70084a3
@ -123,7 +123,7 @@
|
||||
"def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)\n",
|
||||
"def greater(x, y): return bind1(greater_p, x, y)\n",
|
||||
"def less(x, y): return bind1(less_p, x, y)\n",
|
||||
"def transpose(x, perm): return bind1(transpose_p, perm=perm)\n",
|
||||
"def transpose(x, perm): return bind1(transpose_p, x, perm=perm)\n",
|
||||
"def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)\n",
|
||||
"\n",
|
||||
"def bind1(prim, *args, **params):\n",
|
||||
|
@ -107,7 +107,7 @@ def cos(x): return bind1(cos_p, x)
|
||||
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
|
||||
def greater(x, y): return bind1(greater_p, x, y)
|
||||
def less(x, y): return bind1(less_p, x, y)
|
||||
def transpose(x, perm): return bind1(transpose_p, perm=perm)
|
||||
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
|
||||
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
|
||||
|
||||
def bind1(prim, *args, **params):
|
||||
|
@ -97,7 +97,7 @@ def cos(x): return bind1(cos_p, x)
|
||||
def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)
|
||||
def greater(x, y): return bind1(greater_p, x, y)
|
||||
def less(x, y): return bind1(less_p, x, y)
|
||||
def transpose(x, perm): return bind1(transpose_p, perm=perm)
|
||||
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
|
||||
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
|
||||
|
||||
def bind1(prim, *args, **params):
|
||||
|
Loading…
x
Reference in New Issue
Block a user