We insert a ConvertOp as the only use of an input argument in a shape polymorphic
`main` function. This helps the downstream shape refinement because it will set the type
of input arguments to static shapes, and this can invalidate the
module if the argument appears as the result of a function, or if
it appears as the input to a custom_call with output_operand_alias
attribute.
See b/287386268.
The Windows CI currently installs all of the test requirements before building jaxlib, but NumPy is needed to build jaxlib.
Previously this came transitively via matplotlib.
In the Windows CI, we seem to be hitting the following error:
```
=================================== ERRORS ====================================
____________________ ERROR collecting tests/lobpcg_test.py ____________________
tests\lobpcg_test.py:28: in <module>
from matplotlib import pyplot as plt
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\pyplot.py:52: in <module>
import matplotlib.colorbar
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\colorbar.py:19: in <module>
from matplotlib import _api, cbook, collections, cm, colors, contour, ticker
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\contour.py:13: in <module>
from matplotlib.backend_bases import MouseButton
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\backend_bases.py:45: in <module>
from matplotlib import (
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\text.py:16: in <module>
from .font_manager import FontProperties
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1548: in <module>
fontManager = _load_fontmanager()
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:1543: in _load_fontmanager
json_dump(fm, fm_path)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\font_manager.py:957: in json_dump
with cbook._lock_path(filename), open(filename, 'w') as fh:
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\contextlib.py:119: in __enter__
return next(self.gen)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\site-packages\matplotlib\cbook\__init__.py:1804: in _lock_path
with lock_path.open("xb"):
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1252: in open
return io.open(self, mode, buffering, encoding, errors, newline,
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\pathlib.py:1120: in _opener
return self._accessor.open(self, flags, mode)
E PermissionError: [Errno 13] Permission denied: 'C:\\Users\\runneradmin\\.matplotlib\\fontlist-v330.json.matplotlib-lock'
```
The use of matplotlib is only for an optional debugging feature anyway, so just make it an optional dependency.
NumPy is inconsistent between platforms on what it returns for the exponent of an infinite input. On Linux/Mac it returns 0, and on Windows it returns -1. Normalize the test reference result to use 0 in this case.
Use os.replace() for cross-platform renaming with overwriting.
See https://bugs.python.org/issue8828.
Note, per the implementation, it is not atomic on Windows as for UNIX.
JAX shape polymorphism relies on implicit assumptions.
For example, when tracing with input specification `(a, a)`,
we assume that the first two dimensions have the same size
greater or equal to 1.
Here we extend the checking that these assumptions hold. When
we call an `Exported` module from jax, with `jax_export.call_exported`
we check these assumptions statically. However, when we
stage an `Exported` using `XlaCallModule` to be called from
TensorFlow, or when we use TF graph serialization we need
to check these assumptions when we execute and compile
the op (that is when the shapes are available).
To prepare for this compile-time shape checking we add
`Exported.shape_check_module` to produce a serialized
MLIR module containing the shape checking code. This
will be added in a future change to `XlaCallModule`.