* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
LoadedExecutable.cost_analysis, then fallback to the client method.
PiperOrigin-RevId: 542671990
The semantics are as follow:
* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings
* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.
This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.
PiperOrigin-RevId: 540705660
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:
* the mechanism did not differentiate between different safety checks.
E.g., in order to disable checking of the custom call targets, one
had to disable checking for all custom call targets, and also the
checking that the serialization and execution platforms are the same.
* the mechanism operated only at serialization time. Now, the
XlaCallModule supports a `disabled_checks` attribute to control
which safety checks should be disabled.
Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
Use a dummy PyCodeObject in our dummy PyFrameObjects. Using a real PyCodeObject with a fake PyFrameObject confuses CPython 3.11+ when it attempts to compute the locals of a frame, since the frame lacks certain details such as closure information.
This unfortunately means we will not get source column information under Python 3.11 any more, but that is probably better than crashing.
Fixes https://github.com/google/jax/issues/16027
PiperOrigin-RevId: 538873850
Do not multiply the result of PyFrame_GetLasti() by sizeof(_Py_CODEUNIT), because the CPython implementation already does this inside PyFrame_GetLasti().
* In CPython versions 3.9 or earlier, the f_lasti value in a PyFrameObject was in bytes.
* In CPython 3.10, f_lasti was changed to be in code units, which required multiplying it by sizeof(_Py_CODEUNIT) before passing it to functions like PyCode_Addr2Line(). https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api
* In CPython 3.11, direct access to the representation of the PyFrameObject was removed from the headers, requiring the use of PyFrame_GetLasti() (https://docs.python.org/3/whatsnew/3.11.html#pyframeobject-3-11-hiding). This function multiplies by sizeof(_Py_CODEUNIT) internally (deaf509e8f/Objects/frameobject.c (L1353)) so there is no need for the caller to do this multiplication.
It is difficult to write a good test for this, since the only symptom is slightly inaccurate code line information. This issue was found under a debug mode build of CPython (https://docs.python.org/3/using/configure.html#python-debug-build), where PyCode_Addr2Line() has additional checks for out of range lasti values.
PiperOrigin-RevId: 538847288
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.