Use a dummy PyCodeObject in our dummy PyFrameObjects. Using a real PyCodeObject with a fake PyFrameObject confuses CPython 3.11+ when it attempts to compute the locals of a frame, since the frame lacks certain details such as closure information.
This unfortunately means we will not get source column information under Python 3.11 any more, but that is probably better than crashing.
Fixes https://github.com/google/jax/issues/16027
PiperOrigin-RevId: 538873850
Do not multiply the result of PyFrame_GetLasti() by sizeof(_Py_CODEUNIT), because the CPython implementation already does this inside PyFrame_GetLasti().
* In CPython versions 3.9 or earlier, the f_lasti value in a PyFrameObject was in bytes.
* In CPython 3.10, f_lasti was changed to be in code units, which required multiplying it by sizeof(_Py_CODEUNIT) before passing it to functions like PyCode_Addr2Line(). https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api
* In CPython 3.11, direct access to the representation of the PyFrameObject was removed from the headers, requiring the use of PyFrame_GetLasti() (https://docs.python.org/3/whatsnew/3.11.html#pyframeobject-3-11-hiding). This function multiplies by sizeof(_Py_CODEUNIT) internally (deaf509e8f/Objects/frameobject.c (L1353)) so there is no need for the caller to do this multiplication.
It is difficult to write a good test for this, since the only symptom is slightly inaccurate code line information. This issue was found under a debug mode build of CPython (https://docs.python.org/3/using/configure.html#python-debug-build), where PyCode_Addr2Line() has additional checks for out of range lasti values.
PiperOrigin-RevId: 538847288
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.
The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.
The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
This change should be enough for old-style jax2tf, but will require more
work for native serialization.
We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.
PiperOrigin-RevId: 515137407