mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix minor issue in conj_p translation rule.
conj_p was forwarding its keyword arguments to the ComputationBuilder.Conj() method. The current implementation of Conj() ignores extra keyword arguments, but we shouldn't depend on this kind of implementation detail.
This commit is contained in:
parent
c9aa60102f
commit
fa6b06fd15
@ -1669,6 +1669,7 @@ def _conj_transpose_rule(t, x, input_dtype):
|
||||
else:
|
||||
return [real(t)]
|
||||
|
||||
xla.translations[conj_p] = lambda c, x, **kwargs: c.Conj(x)
|
||||
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
||||
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user