diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 774953ed9..22b60a6fb 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -17,13 +17,15 @@ from __future__ import annotations +import collections from collections.abc import Callable, Sequence import copy import dataclasses import functools import itertools +import json import re -from typing import Any, Union, cast +from typing import Any, Protocol, TypeVar, Union, cast import warnings from absl import logging @@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported: return deserialize(blob) +T = TypeVar("T") +PyTreeAuxData = Any # alias for tree_util._AuxData + + +class _SerializeAuxData(Protocol): + def __call__(self, aux_data: PyTreeAuxData) -> bytes: + """Serializes the PyTree node AuxData. + + The AuxData is returned by the `flatten_func` registered by + `tree_util.register_pytree_node`). + """ + + +class _DeserializeAuxData(Protocol): + def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData: + """Deserializes the PyTree node AuxData. + + The result will be passed to `_BuildFromChildren`. + """ + + +class _BuildFromChildren(Protocol): + def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any: + """Materializes a T given a deserialized AuxData and children. + + This is similar in scope with the `unflatten_func`. + """ + + +serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {} + + +deserialization_registry: dict[ + str, + tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {} + + +def _is_namedtuple(nodetype: type) -> bool: + return (issubclass(nodetype, tuple) and + hasattr(nodetype, "_fields") and + isinstance(nodetype._fields, Sequence) and + all(isinstance(f, str) for f in nodetype._fields)) + +def register_pytree_node_serialization( + nodetype: type[T], + *, + serialized_name: str, + serialize_auxdata: _SerializeAuxData, + deserialize_auxdata: _DeserializeAuxData, + from_children: _BuildFromChildren | None = None +) -> type[T]: + """Registers a custom PyTree node for serialization and deserialization. + + You must use this function before you can serialize and deserialize PyTree + nodes for the types not supported natively. We serialize PyTree nodes for + the `in_tree` and `out_tree` fields of `Exported`, which are part of the + exported function's calling convention. + + This function must be called after calling + `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`, + which do not require a call to `register_pytree_node`). + + Args: + nodetype: the type whose PyTree nodes we want to serialize. It is an + error to attempt to register multiple serializations for a `nodetype`. + serialized_name: a string that will be present in the serialization and + will be used to look up the registration during deserialization. It is an + error to attempt to register multiple serializations for a + `serialized_name`. + serialize_auxdata: serialize the PyTree auxdata (returned by the + `flatten_func` argument to `jax.tree_util.register_pytree_node`.). + deserialize_auxdata: deserialize the auxdata that was serialized by the + `serialize_auxdata`. + from_children: if present, this is a function that takes that result of + `deserialize_auxdata` along with some children and creates an instance + of `nodetype`. This is similar to the `unflatten_func` passed to + `jax.tree_util.register_pytree_node`. If not present, we look up + and use the `unflatten_func`. This is needed for `collections.namedtuple`, + which does not have a `register_pytree_node`, but it can be useful to + override that function. Note that the result of `from_children` is + only used with `jax.tree_util.tree_structure` to construct a proper + PyTree node, it is not used to construct the outputs of the serialized + function. + + Returns: + the same type passed as `nodetype`, so that this function can + be used as a class decorator. + """ + if nodetype in serialization_registry: + raise ValueError( + f"Duplicate serialization registration for type `{nodetype}`. " + "Previous registration was with serialized_name " + f"`{serialization_registry[nodetype][0]}`.") + if serialized_name in deserialization_registry: + raise ValueError( + "Duplicate serialization registration for " + f"serialized_name `{serialized_name}`. " + "Previous registration was for type " + f"`{deserialization_registry[serialized_name][0]}`.") + if from_children is None: + if nodetype not in tree_util._registry: + raise ValueError( + f"If `from_children` is not present, you must call first" + f"`jax.tree_util.register_pytree_node` for `{nodetype}`") + from_children = tree_util._registry[nodetype].from_iter + + serialization_registry[nodetype] = ( + serialized_name, serialize_auxdata) + deserialization_registry[serialized_name] = ( + nodetype, deserialize_auxdata, from_children) + return nodetype + + +def register_namedtuple_serialization( + nodetype: type[T], + *, + serialized_name: str) -> type[T]: + """Registers a namedtuple for serialization and deserialization. + + JAX has native PyTree support for `collections.namedtuple`, and does not + require a call to `jax.tree_util.register_pytree_node`. However, if you + want to serialize functions that have inputs of outputs of a + namedtuple type, you must register that type for serialization. + + Args: + nodetype: the type whose PyTree nodes we want to serialize. It is an + error to attempt to register multiple serializations for a `nodetype`. + On deserialization, this type must have the same set of keys that + were present during serialization. + serialized_name: a string that will be present in the serialization and + will be used to look up the registration during deserialization. It is an + error to attempt to register multiple serializations for + a `serialized_name`. + + Returns: + the same type passed as `nodetype`, so that this function can + be used as a class decorator. +""" + if not _is_namedtuple(nodetype): + raise ValueError("Use `jax.export.register_pytree_node_serialization` for " + "types other than `collections.namedtuple`.") + + def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes: + # Store the serialized keys in the serialized auxdata + del aux_data + return json.dumps(nodetype._fields).encode("utf-8") + + def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData: + return json.loads(serialized_aux_data.decode("utf-8")) + + def from_children(aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any: + # Use our own "from_children" because namedtuples do not have a pytree + # registration. + ser_keys = cast(Sequence[str], aux_data) + assert len(ser_keys) == len(children) + return nodetype(** dict(zip(ser_keys, children))) + + return register_pytree_node_serialization( + nodetype, + serialized_name=serialized_name, + serialize_auxdata=serialize_auxdata, + deserialize_auxdata=deserialize_auxdata, + from_children=from_children) + + +# collections.OrderedDict is registered as a pytree node with auxdata being +# `tuple(x.keys())`. +def _serialize_ordereddict_keys(keys): + if isinstance(keys, Sequence) and all(isinstance(k, str) for k in keys): + return json.dumps(keys).encode("utf-8") + else: + raise NotImplementedError( + "Serialization of collections.OrderedDict is supported only when the " + f"keys are strings. Found keys: {keys}.") + + +register_pytree_node_serialization( + collections.OrderedDict, + serialized_name="collections.OrderedDict", + serialize_auxdata=_serialize_ordereddict_keys, + deserialize_auxdata=lambda b: json.loads(b.decode("utf-8"))) + + def default_export_platform() -> str: """Retrieves the default export platform. @@ -404,9 +589,10 @@ def export_back_compat( disabled_checks: the safety checks to disable. See docstring of `DisabledSafetyCheck`. - Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. + Returns: + a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. Usage: @@ -480,9 +666,10 @@ def export( disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. - Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. + Returns: + a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. Usage: diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950ada..b72d0134c 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -28,12 +28,15 @@ enum PyTreeDefKind: byte { tuple = 2, list = 3, dict = 4, + custom = 5, } table PyTreeDef { kind: PyTreeDefKind; children: [PyTreeDef]; - children_names: [string]; // only for "dict" + children_names: [string]; // only for "kind==dict" + custom_name: string; // only for "kind==custom" + custom_auxdata: [byte]; // only for "kind==custom" } enum AbstractValueKind: byte { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4..434c4c5cf 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -16,6 +16,7 @@ from __future__ import annotations +import types from collections.abc import Callable, Sequence from functools import partial from typing import TypeVar @@ -45,6 +46,8 @@ SerT = TypeVar("SerT") # even if the change is backwards compatible. # Version 1, Nov 2023, first version. # Version 2, Dec 16th, 2023, adds the f0 dtype. +# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types +# This version is backwards compatible with Version 2. _SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -152,13 +155,13 @@ def _serialize_array( def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() - if serialization_version != _SERIALIZATION_VERSION: + if serialization_version not in [2, 3]: raise NotImplementedError( f"deserialize unsupported version {serialization_version}" ) fun_name = exp.FunctionName().decode("utf-8") - _, in_tree = tree_util.tree_flatten( + in_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.InTree()) ) scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints @@ -166,7 +169,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: in_avals = _deserialize_tuple( exp.InAvalsLength, exp.InAvals, deser_aval ) - _, out_tree = tree_util.tree_flatten( + out_tree = tree_util.tree_structure( _deserialize_pytreedef_to_pytree(exp.OutTree()) ) out_avals = _deserialize_tuple( @@ -246,23 +249,40 @@ def _serialize_pytreedef( children_vector_offset = _serialize_array( builder, _serialize_pytreedef, children ) + custom_name = None + custom_auxdata = None + node_type = node_data and node_data[0] if node_data is None: # leaf kind = ser_flatbuf.PyTreeDefKind.leaf - elif node_data[0] is type(None): + elif node_type is types.NoneType: kind = ser_flatbuf.PyTreeDefKind.none - elif node_data[0] is tuple: + elif node_type is tuple: kind = ser_flatbuf.PyTreeDefKind.tuple - elif node_data[0] is list: + elif node_type is list: kind = ser_flatbuf.PyTreeDefKind.list - elif node_data[0] is dict: + elif node_type is dict: kind = ser_flatbuf.PyTreeDefKind.dict assert len(node_data[1]) == len(children) children_names_vector_offset = _serialize_array( builder, lambda b, s: b.CreateString(s), node_data[1] ) + elif node_type in _export.serialization_registry: + kind = ser_flatbuf.PyTreeDefKind.custom + serialized_name, serialize_auxdata = _export.serialization_registry[node_type] + custom_name = builder.CreateString(serialized_name) + serialized_auxdata = serialize_auxdata(node_data[1]) + if not isinstance(serialized_auxdata, (bytes, bytearray)): + raise ValueError( + "The custom serialization function for `node_type` must " + f"return a `bytes` object. It returned a {type(serialized_auxdata)}.") + custom_auxdata = builder.CreateByteVector(serialized_auxdata) else: - raise NotImplementedError(f"serializing PyTreeDef {node_data}") + raise ValueError( + "Cannot serialize PyTreeDef containing an " + f"unregistered type `{node_type}`. " + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") ser_flatbuf.PyTreeDefStart(builder) ser_flatbuf.PyTreeDefAddKind(builder, kind) @@ -270,6 +290,10 @@ def _serialize_pytreedef( ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset) if children_names_vector_offset: ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset) + if custom_name is not None: + ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name) + if custom_auxdata is not None: + ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata) return ser_flatbuf.PyTreeDefEnd(builder) @@ -294,6 +318,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): assert p.ChildrenNamesLength() == nr_children keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)] return dict(zip(keys, children)) + elif kind == ser_flatbuf.PyTreeDefKind.custom: + serialized_name = p.CustomName().decode("utf-8") + if serialized_name not in _export.deserialization_registry: + raise ValueError( + "Cannot deserialize a PyTreeDef containing an " + f"unregistered type `{serialized_name}`. " + "Use `export.register_pytree_node_serialization` or " + "`export.register_namedtuple_serialization`.") + nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name] + auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes()) + return from_iter(auxdata, children) else: assert False, kind diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9..18dd2c3cb 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -21,20 +21,21 @@ import flatbuffers from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind: +class PyTreeDefKind(object): leaf = 0 none = 1 tuple = 2 list = 3 dict = 4 + custom = 5 -class AbstractValueKind: +class AbstractValueKind(object): shapedArray = 0 abstractToken = 1 -class DType: +class DType(object): bool = 0 i8 = 1 i16 = 2 @@ -60,18 +61,18 @@ class DType: f0 = 22 -class ShardingKind: +class ShardingKind(object): unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind: +class DisabledSafetyCheckKind(object): platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef: +class PyTreeDef(object): __slots__ = ['_tab'] @classmethod @@ -140,8 +141,42 @@ class PyTreeDef: o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) return o == 0 + # PyTreeDef + def CustomName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # PyTreeDef + def CustomAuxdata(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # PyTreeDef + def CustomAuxdataAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # PyTreeDef + def CustomAuxdataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # PyTreeDef + def CustomAuxdataIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + return o == 0 + def PyTreeDefStart(builder): - builder.StartObject(3) + builder.StartObject(5) def PyTreeDefAddKind(builder, kind): builder.PrependInt8Slot(0, kind, 0) @@ -158,12 +193,21 @@ def PyTreeDefAddChildrenNames(builder, childrenNames): def PyTreeDefStartChildrenNamesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def PyTreeDefAddCustomName(builder, customName): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(customName), 0) + +def PyTreeDefAddCustomAuxdata(builder, customAuxdata): + builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(customAuxdata), 0) + +def PyTreeDefStartCustomAuxdataVector(builder, numElems): + return builder.StartVector(1, numElems, 1) + def PyTreeDefEnd(builder): return builder.EndObject() -class AbstractValue: +class AbstractValue(object): __slots__ = ['_tab'] @classmethod @@ -235,7 +279,7 @@ def AbstractValueEnd(builder): -class Sharding: +class Sharding(object): __slots__ = ['_tab'] @classmethod @@ -304,7 +348,7 @@ def ShardingEnd(builder): -class Effect: +class Effect(object): __slots__ = ['_tab'] @classmethod @@ -340,7 +384,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck: +class DisabledSafetyCheck(object): __slots__ = ['_tab'] @classmethod @@ -386,7 +430,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported: +class Exported(object): __slots__ = ['_tab'] @classmethod diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index b67354bb4..d49aa2963 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -35,8 +35,8 @@ _deprecations = { "call": (_deprecation_message, _src_export.call), "call_exported": (_deprecation_message, _src_export.call_exported), "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version), + "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), + "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), "serialize": (_deprecation_message, _src_serialization.serialize), "deserialize": (_deprecation_message, _src_serialization.deserialize), diff --git a/jax/export.py b/jax/export.py index 13186f886..d9559909f 100644 --- a/jax/export.py +++ b/jax/export.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. __all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "register_pytree_node_serialization", + "register_namedtuple_serialization", "maximum_supported_calling_convention_version", "minimum_supported_calling_convention_version", "default_export_platform", @@ -23,6 +25,8 @@ from jax._src.export._export import ( Exported, export, deserialize, + register_pytree_node_serialization, + register_namedtuple_serialization, maximum_supported_calling_convention_version, minimum_supported_calling_convention_version, default_export_platform) diff --git a/tests/export_test.py b/tests/export_test.py index 0d946d84d..fd6bef11e 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,11 +13,13 @@ # limitations under the License. from __future__ import annotations +import collections from collections.abc import Callable, Sequence import contextlib import dataclasses import functools import logging +import json import math import re import unittest @@ -32,6 +34,7 @@ from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax import tree_util from jax._src import config from jax._src import core @@ -311,6 +314,123 @@ class JaxExportTest(jtu.JaxTestCase): self.assertAllClose(f((a, b), a=a, b=b), exp_f.call((a, b), a=a, b=b)) + def test_pytree_namedtuple(self): + T = collections.namedtuple("SomeType", ("a", "b", "c")) + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType", + ) + x = T(a=1, b=2, c=3) + + def f(x): + return (x, x) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x) + res = exp.call(x) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure((x, x))) + self.assertEqual(type(res[0]), type(x)) + self.assertEqual(type(res[1]), type(x)) + ser = exp.serialize() + exp2 = export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_pytree_namedtuple_error(self): + T = collections.namedtuple("SomeType", ("a", "b")) + x = T(a=1, b=2) + with self.assertRaisesRegex( + ValueError, + "Cannot serialize .* unregistered type .*SomeType"): + export.export(jax.jit(lambda x: x))(x).serialize() + + with self.assertRaisesRegex( + ValueError, + "If `from_children` is not present.* must call.*register_pytree_node" + ): + export.register_pytree_node_serialization( + T, + serialized_name="test_pytree_namedtuple.SomeType_V2", + serialize_auxdata=lambda x: b"", + deserialize_auxdata=lambda b: None + ) + + with self.assertRaisesRegex(ValueError, + "Use .*register_pytree_node_serialization"): + export.register_namedtuple_serialization(str, serialized_name="n/a") + + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + T, + serialized_name="test_pytree_namedtuple_error.OtherType", + ) + + with self.assertRaisesRegex( + ValueError, + "Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType" + ): + export.register_namedtuple_serialization( + collections.namedtuple("SomeOtherType", ("a", "b")), + serialized_name="test_pytree_namedtuple_error.SomeType", + ) + + def test_pytree_custom_types(self): + x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]) + + @tree_util.register_pytree_node_class + class CustomType: + def __init__(self, a: int, b: CustomType | None, string: str): + self.a = a + self.b = b + self.string = string + + def tree_flatten(self): + return ((self.a, self.b), self.string) + + @classmethod + def tree_unflatten(cls, aux_data, children): + string = aux_data + return cls(*children, string) + + export.register_pytree_node_serialization( + CustomType, + serialized_name="test_pytree_custom_types.CustomType", + serialize_auxdata=lambda aux: aux.encode("utf-8"), + deserialize_auxdata=lambda b: b.decode("utf-8") + ) + x2 = CustomType(4, 5, "foo") + + def f(x1, x2): + return (x1, x2, x1, x2) # return 2 copies, to check that types are shared + + exp = export.export(jax.jit(f))(x1, x2) + res = exp.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res), + tree_util.tree_structure(((x1, x2, x1, x2)))) + self.assertEqual(type(res[0]), type(x1)) + self.assertEqual(type(res[1]), type(x2)) + self.assertEqual(type(res[2]), type(x1)) + self.assertEqual(type(res[3]), type(x2)) + ser = exp.serialize() + exp2 = export.deserialize(ser) + self.assertEqual(exp2.in_tree, exp.in_tree) + self.assertEqual(exp2.out_tree, exp.out_tree) + res2 = exp2.call(x1, x2) + self.assertEqual(tree_util.tree_structure(res2), + tree_util.tree_structure(res)) + + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c