diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 9878ddc3c..99760099d 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 import builtins
 from collections.abc import Callable, Sequence
+import dataclasses
 import enum
 import functools
 from functools import partial
@@ -2227,10 +2228,122 @@ def ragged_dot(
   Results:
     (m, n) shaped array with preferred_element_type element type.
   """
-  return ragged_dot_p.bind(lhs, rhs, group_sizes,
-                            precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type,
-                            group_offset=group_offset)
+  return ragged_dot_general(
+      lhs,
+      rhs,
+      group_sizes,
+      ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
+      precision=canonicalize_precision(precision),
+      preferred_element_type=preferred_element_type,
+      group_offset=group_offset,
+  )
+
+
+@dataclasses.dataclass(frozen=True)
+class RaggedDotDimensionNumbers():
+  """Describes ragged, group, and dot dimensions for ragged dot general.
+
+  Args:
+    dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
+      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
+      rhs_batch_dims))`.
+    lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
+      dimensions.
+    rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
+      dimensions.
+  """
+  dot_dimension_numbers: DotDimensionNumbers
+  lhs_ragged_dimensions: Sequence[int]
+  rhs_group_dimensions: Sequence[int]
+
+  def __init__(
+      self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
+  ):
+    super().__setattr__(
+        'dot_dimension_numbers',
+        tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers),
+    )
+    super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
+    super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions))
+
+
+def _from_maybe_ragged(
+    dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
+) -> DotDimensionNumbers:
+  return (
+      dot_dimension_numbers.dot_dimension_numbers
+      if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers)
+      else dot_dimension_numbers
+  )
+
+
+# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.)
+_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
+    dot_dimension_numbers=(([1], [1]), ([], [])),
+    lhs_ragged_dimensions=[0],
+    rhs_group_dimensions=[0],
+)
+
+
+def ragged_dot_general(
+    lhs: Array,
+    rhs: Array,
+    group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
+    precision: PrecisionLike = None,
+    preferred_element_type: DTypeLike | None = None,
+    group_offset: Array | None = None,
+) -> Array:
+  """Ragged matrix multiplication.
+
+  Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
+  a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
+  ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
+  ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
+  most one group dimension.
+
+  Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
+  three modes, depending on the kind of the lhs ragged dimension:
+  1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
+     Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
+     and `x...` are the lhs non-contracting dims outer to the ragged dim.
+  2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
+     Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
+     ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
+  3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
+     Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
+     ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.
+  If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according
+  to the rules above.
+
+  Args:
+    lhs: an array
+    rhs: an array
+    group_sizes: an array with integer element type
+    ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
+      specify the dot dimension numbers, lhs ragged dimension, and rhs group
+      dimension.
+    precision: Optional. Consistent with precision argument for
+      :func:`jax.lax.dot`.
+    preferred_element_type: Optional. Consistent with precision argument for
+      :func:`jax.lax.dot`.
+    group_offset: Optional. (1,) shaped array that indicates the group in
+      group_sizes to start computing from. If not specified, defaults to [0].
+
+  Results:
+    An array whose shape is the same as that produced by `dot_general`, with an
+    extra leading dimension of size `g` in the case where the lhs ragged
+    dimension is a contracting dimension.
+  """
+  return ragged_dot_general_p.bind(
+      lhs,
+      rhs,
+      group_sizes,
+      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
+      precision=canonicalize_precision(precision),
+      preferred_element_type=preferred_element_type,
+      group_offset=group_offset,
+  )
 
 
 def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
@@ -4593,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                             out_sharding):
   if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
-  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
     msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
@@ -4654,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
   return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
 
 def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
-  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
   batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
   lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
   lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
-  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
-  rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
+  rhs_group = ()
+  if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
+    rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
+  rhs_contract_or_batch_or_group = tuple(
+      sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
+  )
+  rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group)
   return batch_shape + lhs_tensored_shape + rhs_tensored_shape
 
 
@@ -4723,7 +4841,7 @@ def tuple_delete(tup, idx):
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                             preferred_element_type: DTypeLike | None,
-                            out_sharding):
+                            out_sharding, name: str = 'lax.dot_general'):
   if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
     raise NotImplementedError
   del dimension_numbers  # unused
@@ -4744,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
     result_dtype = rhs.dtype
   else:
     if lhs.dtype != rhs.dtype:
-      raise TypeError(
-          f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
+      raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
     result_dtype = lhs.dtype
   has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
   return _maybe_upcast(result_dtype, preferred_element_type,
@@ -4884,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
   # explicitly present dimensions that this dot_general is zipping together.
   lbd, rbd = batch_dims
   assert lbd is not None or rbd is not None
-  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
 
+  is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)
   def bump_dims(dims, b):
     return tuple(np.add(dims, np.greater_equal(dims, b)))
 
@@ -4908,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
   elif (type(rbd) is int and lbd is None):
     # The right vmapped dimension becomes an additional tensor dimension in the
     # batched dot_general.
-    rhs_tensor = [d for d in range(rhs_ndim)
-                  if d not in rhs_batch and d not in rhs_contract]
+    rhs_tensor = list(
+        remaining(
+            range(rhs_ndim),
+            rhs_batch,
+            rhs_contract,
+            dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
+        )
+    )
     result_batch_dim = (lhs_ndim - len(lhs_contract) +
                         int(sum(np.less(rhs_tensor, rbd))))
     rhs_batch = bump_dims(rhs_batch, rbd)
@@ -4919,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
     assert False
 
   new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
+  if is_ragged_dot:
+    new_dimension_numbers = RaggedDotDimensionNumbers(
+        dot_dimension_numbers=new_dimension_numbers,
+        lhs_ragged_dimensions=bump_dims(
+            dimension_numbers.lhs_ragged_dimensions, lbd
+        ),
+        rhs_group_dimensions=bump_dims(
+            dimension_numbers.rhs_group_dimensions, rbd
+        ),
+    )
   return new_dimension_numbers, result_batch_dim
 
 def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
@@ -5010,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims):
   lbd, rbd = batch_dims
   return (lbd, rbd)
 
-# DotDimensionNumbers used in the dot_general call for ragged_dot().
-_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
-    ([2, 0], [1, 0]),
-    ([], []),
-)
-_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
-    ([3, 1], [2, 1]),
-    ([0], [0]),
-)
 
 ad.defbilinear(dot_general_p,
                _dot_general_transpose_lhs, _dot_general_transpose_rhs)
@@ -5186,58 +5311,181 @@ for platform in ["cpu", "tpu"]:
                          platform=platform)
 
 
-def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
-  if len(lhs.shape) == 3:
-    # Batched case
-    b, m, k = lhs.shape
-    b2, group_count, rk, n = rhs.shape
-    b3 = group_sizes.shape[0]
-    if b != b2:
-      raise TypeError(
-          f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
-          f' {b2}.'
-      )
-    if b3 != b:
-      raise TypeError(
-          'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
-          f' {b3} and {b}.'
-      )
-    if k != rk:
-      raise TypeError(
-          f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
-          f' {rk}.'
-      )
-    num_groups = group_sizes.shape[1]
-    if group_count != num_groups:
-      raise TypeError(
-          'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
-          f' {group_count} and {num_groups}.'
-      )
-    return (b, m, n)
+class RaggedDotMode(enum.Enum):
+  RAGGED_NONCONTRACTING = 1  # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
+  RAGGED_CONTRACTING = 2  #    [b,m,k], [b,k,n],   [b,g] -> [g,b,m,n]
+  RAGGED_BATCH = 3  #          [b,m,k], [b,k,n],   [g]   -> [b,m,n]
 
-  m, k = lhs.shape
-  group_count, rk, n = rhs.shape
-  if k != rk:
-    raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
-  num_groups = group_sizes.shape[0]
-  if group_count != num_groups:
-    raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
-  return (m, n)
 
-def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None,
-                           **_) -> np.dtype:
+def _ragged_dot_mode_and_dim(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> tuple[RaggedDotMode, int]:
+  assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1
+  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
+  (lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers
+  lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)
+  if lhs_ragged_dim in lhs_noncontracting:
+    mode = RaggedDotMode.RAGGED_NONCONTRACTING
+  elif lhs_ragged_dim in lhs_contracting:
+    mode = RaggedDotMode.RAGGED_CONTRACTING
+  elif lhs_ragged_dim in lhs_batch:
+    mode = RaggedDotMode.RAGGED_BATCH
+  else:
+    raise TypeError(
+        f'lhs_ragged_dim {lhs_ragged_dim} not found in '
+        f'lhs_noncontracting {lhs_noncontracting}, '
+        f'lhs_contracting {lhs_contracting}, or '
+        f'lhs_batch {lhs_batch}.'
+    )
+  return mode, lhs_ragged_dim
+
+
+def _ragged_dot_mode(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> RaggedDotMode:
+  return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0]
+
+
+def _is_ragged_contracting(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> bool:
+  return (
+      _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
+      == RaggedDotMode.RAGGED_CONTRACTING
+  )
+
+
+def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
+  batch, contract = map(list, (batch, contract))
+  noncontract = remaining(range(rank), contract, batch)
+  match mode:
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      return batch + noncontract[: noncontract.index(ragged_dim)]
+    case RaggedDotMode.RAGGED_CONTRACTING:
+      return batch + contract[: contract.index(ragged_dim)]
+    case RaggedDotMode.RAGGED_BATCH:
+      return batch[: batch.index(ragged_dim)]
+
+
+def _ragged_dot_general_shape_rule(
+    lhs,
+    rhs,
+    group_sizes,
+    *,
+    ragged_dot_dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+):
+  def _check_in_range(dim, rank, dim_name, arg_name):
+    if dim < 0 or dim >= rank:
+      raise TypeError(
+          f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
+          f'and less than the number of axes of the {arg_name} value, '
+          f'got {dim} for {arg_name} of rank {rank}.'
+      )
+
+  # Validate the lhs ragged dimension, and find out which mode we're in.
+  if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1:
+    raise TypeError(
+        'ragged_dot_general expects exactly one lhs ragged dimension.'
+    )
+  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
+  _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
+  mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)
+
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = (
+      ragged_dot_dimension_numbers.dot_dimension_numbers
+  )
+
+  # Validate the shape of group_sizes, if it is something other than [g].
+  if group_sizes.ndim == 0:
+    raise TypeError('expected rank of group_sizes to be >=1.')
+  if group_sizes.ndim != 1:
+    # Construct the expected shape [b...,x...,g] of group_sizes.
+    prefix_dims = _ragged_dot_prefix_dims(
+        mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
+    )
+    expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
+    expected_gs_shape += (group_sizes.shape[-1],)
+    # TODO(pravnar): Permit other broadcastable shapes.
+    if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape):
+      raise TypeError(
+          'expected group_sizes to have shape '
+          f'{expected_gs_shape}, got {group_sizes.shape}.'
+      )
+  num_groups = group_sizes.shape[-1]
+
+  # Validate properties of the rhs group dimension(s).
+  rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
+  match mode:
+    case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+      if len(rhs_group_dims) != 0:
+        raise TypeError(
+            'ragged_dot_general requires zero group dimensions in the rhs '
+            'when lhs ragged dimension is contracting or batch.'
+        )
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      if len(rhs_group_dims) != 1:
+        raise TypeError(
+            'ragged_dot_general requires exactly one rhs group dimension '
+            'when lhs ragged dimension is noncontracting.'
+        )
+      rhs_group_dim = rhs_group_dims[0]
+      _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
+      if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
+        raise TypeError(
+            'ragged_dot_general requires rhs group dimension numbers to be '
+            'distinct from contracting and batch dimensions.'
+        )
+      if rhs.shape[rhs_group_dim] != num_groups:
+        raise TypeError(
+            'expected rhs group dimension size to be '
+            f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
+        )
+
+  out_shape = _dot_general_shape_rule(
+      lhs,
+      rhs,
+      dimension_numbers=ragged_dot_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      out_sharding=None,
+  )
+  if mode == RaggedDotMode.RAGGED_CONTRACTING:
+    out_shape = (num_groups,) + out_shape
+  return out_shape
+
+
+def _ragged_dot_general_dtype_rule(
+    lhs: Array,
+    rhs: Array,
+    group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
-    raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
-  # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
+    raise TypeError(
+        'ragged_dot_general requires that '
+        'group_sizes.dtype is subtype of np.integer.'
+    )
+  # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
   return _dot_general_dtype_rule(
-      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-      precision=precision, preferred_element_type=preferred_element_type,
-      out_sharding=None)
+      lhs,
+      rhs,
+      dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      out_sharding=None,
+      name='lax.ragged_dot_general',
+  )
 
 
-def _ragged_dot_jvp_rule(
-    primals, tangents, precision, preferred_element_type, group_offset
+def _ragged_dot_general_jvp_rule(
+    primals, tangents, ragged_dot_dimension_numbers,
+    precision, preferred_element_type, group_offset
 ):
   # note - we could ostensibly just get this by passing on the
   # value to ragged_dot below, but, this feels cleaner.
@@ -5247,20 +5495,22 @@ def _ragged_dot_jvp_rule(
   dx, dy, _ = tangents  # no tan on the gs
 
   # primal
-  primal_out = ragged_dot(
+  primal_out = ragged_dot_general(
       x,
       y,
       gs,
+      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
   )
 
   # tangent
   dx_out = (
-      ragged_dot(
+      ragged_dot_general(
           dx,
           y,
           gs,
+          ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
           precision=precision,
           preferred_element_type=preferred_element_type,
       )
@@ -5268,10 +5518,11 @@ def _ragged_dot_jvp_rule(
       else _zeros(primal_out)
   )
   dy_out = (
-      ragged_dot(
+      ragged_dot_general(
           x,
           dy,
           gs,
+          ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
           precision=precision,
           preferred_element_type=preferred_element_type,
       )
@@ -5283,58 +5534,111 @@ def _ragged_dot_jvp_rule(
   return primal_out, tangent_out
 
 
-def _ragged_to_dense(x, y, group_sizes):
-  from jax._src.lax import control_flow  # avoid circular imports
-  shape = (y.shape[0], x.shape[0], x.shape[1])
-  x = broadcast_in_dim(x, shape, [1, 2])
-  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
-  group_ends = control_flow.cumsum(group_sizes)
-  group_starts = concatenate(
-      [_zeros(group_sizes)[:1], group_ends[:-1]],
-      dimension=0,
-  )
-  group_ends = broadcast_in_dim(group_ends, shape, (0,))
-  group_starts = broadcast_in_dim(group_starts, shape, (0,))
-  mask = bitwise_and(group_starts <= iota, iota < group_ends)
-  x = select(mask, x, _zeros(x))
-  return x
-
-
-def _ragged_dot_transpose_rule(
-    ct, *operands, precision, preferred_element_type, group_offset
+def _ragged_dot_general_transpose_rule(
+    ct,
+    x,
+    y,
+    group_sizes,
+    *,
+    ragged_dot_dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    group_offset: Array | None,
 ):
-  x, y, gs = operands
   if group_offset is not None:
     raise NotImplementedError('Unimplemented group_offset support.')
 
-  if ad.is_undefined_primal(y):
-    grad_x = None
-  else:
-    y_t = _matrix_transpose(y)
-    grad_x = ragged_dot(
-        ct,
-        y_t,
-        gs,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-    )
+  (x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
+  x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
+  y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
+  x_kept = remaining(range(x_ndim), x_contract, x_batch)
+  y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
+  y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
+  mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
+      x_ndim, ragged_dot_dimension_numbers
+  )
 
-  if ad.is_undefined_primal(x):
-    grad_y = None
-  else:
-    y = y.aval if ad.is_undefined_primal(y) else y
-    x_dense = _ragged_to_dense(x, y, group_sizes=gs)
-    ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
-    dimension_numbers = (([1], [1]), ([0], [0]))
-    grad_y = dot_general(
-        x_dense,
-        ct_dense,
-        dimension_numbers,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-    )
+  unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
+      f'Unimplemented {fn_name} for ragged dot general in mode '
+      f'{ragged_dot_mode.name}.'
+  )
 
-  return grad_x, grad_y, None
+  # This is a hack to ensure we continue to emit the `_matrix_transpose` for the
+  # grad_x case. This isn't strictly necessary since we have dot_dim_nums.
+  # TODO(pravnar): Remove this once we no longer care to emit the transpose.
+  _is_basic_ragged_dot = (
+      x_ndim == 2
+      and y_ndim == 3
+      and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
+  )
+
+  def grad_x_dims():
+    match mode:
+      case RaggedDotMode.RAGGED_NONCONTRACTING:
+        ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
+        dims = (
+            ragged_dot_dimension_numbers
+            if _is_basic_ragged_dot
+            else RaggedDotDimensionNumbers(
+                dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
+                lhs_ragged_dimensions=[
+                    len(x_batch) + x_kept.index(lhs_ragged_dim)
+                ],
+                rhs_group_dimensions=y_group,
+            )
+        )
+        x_contract_sorted_by_y = list(
+            np.take(x_contract, np.argsort(y_contract))
+        )
+        unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
+      case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+        raise unimplemented('grad_x_dims', mode)
+    return dims, unsorted_axes
+
+  def grad_y_dims():
+    match mode:
+      case RaggedDotMode.RAGGED_NONCONTRACTING:
+        ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
+        dims = RaggedDotDimensionNumbers(
+            dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
+            lhs_ragged_dimensions=[lhs_ragged_dim],
+            rhs_group_dimensions=[],
+        )
+        y_contract_sorted_by_x = list(
+            np.take(y_contract, np.argsort(x_contract))
+        )
+        unsorted_axes = (
+            list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
+        )
+      case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+        raise unimplemented('grad_y_dims', mode)
+    return dims, unsorted_axes
+
+  def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
+    dims, unsorted_axes = dims_fn()
+    ragged_dot_general_out = ragged_dot_general(
+          lhs, rhs, group_sizes, dims, precision=precision,
+          preferred_element_type=preferred_element_type,
+          group_offset=group_offset)
+    result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
+    if result.dtype != aval.dtype:
+      result = _convert_element_type(result, aval.dtype, aval.weak_type)
+    return result
+
+  x_bar = (
+      None
+      if ad.is_undefined_primal(y)
+      else _ragged_dot_grad(ct,
+                            _matrix_transpose(y) if _is_basic_ragged_dot else y,
+                            grad_x_dims,
+                            x.aval)
+  )
+  y_bar = (
+      None
+      if ad.is_undefined_primal(x)
+      else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
+  )
+  return x_bar, y_bar, None
 
 
 def _ragged_dot_batch_unpack_args(batched_args):
@@ -5349,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims):
   return (lbd, rbd)
 
 
-def _ragged_dot_invoke_prim(
+def _ragged_dot_general_invoke_prim(
     group_sizes,
     lhs,
     rhs,
-    new_dimension_numbers,
+    new_ragged_dot_dimension_numbers,
     precision,
     preferred_element_type,
     out_sharding,
 ):
   del out_sharding
-  return ragged_dot(
+  return ragged_dot_general(
       lhs,
       rhs,
       group_sizes,
+      ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
   )
 
 
-def _ragged_dot_batch_rule(
+def _ragged_dot_general_batch_rule(
     axis_data,
     batched_args,
     batch_dims,
     *,
+    ragged_dot_dimension_numbers,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
 ):
-  invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
-
-  return _dot_batch_rule(
+  invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2])
+  batched_out, result_batch_dim = _dot_batch_rule(
       _ragged_dot_batch_unpack_args,
       _ragged_dot_batch_unpack_dims,
       invoke,
       axis_data,
       batched_args,
       batch_dims,
-      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      dimension_numbers=ragged_dot_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
       out_sharding=None,
   )
+  if _is_ragged_contracting(batched_args[0].ndim - 1,
+                            ragged_dot_dimension_numbers):
+    result_batch_dim += 1
+  return batched_out, result_batch_dim
 
 
-ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
-                                  _ragged_dot_dtype_rule, 'ragged_dot')
-ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
-ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
-ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
-batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
-batching.skippable_batchers[ragged_dot_p] = lambda _: ()
+ragged_dot_general_p = standard_primitive(
+    _ragged_dot_general_shape_rule,
+    _ragged_dot_general_dtype_rule,
+    'ragged_dot_general',
+)
+ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
+ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
+batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule
+batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()
 
-def _ragged_dot_impl(
+
+def _ragged_dot_general_impl(
     lhs: Array,
     rhs: Array,
     group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     group_offset: Array | None = None,
@@ -5412,24 +5725,100 @@ def _ragged_dot_impl(
   if group_offset is not None:
     raise NotImplementedError("Unimplemented group_offset support.")
 
-  if len(lhs.shape) == 3:
-    ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
-    ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
-  else:
-    ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
-    ragged_to_dense = _ragged_to_dense
+  def ragged_to_dense(x: Array, gs: Array, *, dim: int):
+    from jax._src.lax import control_flow  # avoid circular imports
+    assert gs.ndim == 1
+    shape = gs.shape + x.shape
+    x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
+    iota = broadcasted_iota(gs.dtype, shape, dim+1)
+    group_ends = control_flow.cumsum(gs)
+    group_starts = concatenate(
+        [_zeros(gs)[:1], group_ends[:-1]],
+        dimension=0,
+    )
+    group_ends = broadcast_in_dim(group_ends, shape, (0,))
+    group_starts = broadcast_in_dim(group_starts, shape, (0,))
+    mask = bitwise_and(group_starts <= iota, iota < group_ends)
+    x = select(mask, x, _zeros(x))
+    return x
 
-  lhs = ragged_to_dense(lhs, rhs, group_sizes)
+  def batched_ragged_to_dense(dim, *x_in_axes: int):
+    if not x_in_axes:
+      return partial(ragged_to_dense, dim=dim)
+    x_axis, *rest = x_in_axes
+    decr = lambda d: d - 1 if d >= x_axis else d
+    return api.vmap(
+        batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
+        in_axes=(x_axis, 0),
+    )
 
-  return dot_general(
-      lhs,
-      rhs,
-      dimension_numbers=ragged_dot_dims,
+  incr = lambda dims: [d + 1 for d in dims]
+
+  # Expand the ragged `dim` of `x`, given its batching `axes`.
+  # The group axis from `gs` becomes the outermost axis of the result.
+  # Some examples:
+  #   x: [m,k]      , gs: [g]       ==> expand(x, 0, gs): [g,m,k]
+  #   x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
+  def expand(x, dim, gs, *axes):
+    expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
+    unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
+    return transpose(expanded, np.argsort(unsorted_dims))
+
+  mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
+      lhs.ndim, ragged_dot_dimension_numbers
+  )
+  (l_contract, r_contract), (l_batch, r_batch) = (
+      ragged_dot_dimension_numbers.dot_dimension_numbers
+  )
+  l_prefix = _ragged_dot_prefix_dims(
+      mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
+  )
+
+  _dot_general = partial(
+      dot_general,
       precision=precision,
       preferred_element_type=preferred_element_type,
   )
+  # TODO(pravnar): Permit other broadcastable shapes.
+  if group_sizes.ndim == 1:
+    group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
 
-mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))
+  match mode:
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
+      assert len(rhs_group_dims) == 1
+      return _dot_general(
+          expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
+          rhs,
+          dimension_numbers=(
+              (incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
+              (incr(l_batch), r_batch),
+          ),
+      )
+    case RaggedDotMode.RAGGED_CONTRACTING:
+      rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
+      r_prefix = _ragged_dot_prefix_dims(
+        mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
+      )
+      return _dot_general(
+          expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
+          expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
+          dimension_numbers=(
+              (incr(l_contract), incr(r_contract)),
+              ([0] + incr(l_batch), [0] + incr(r_batch)),
+          ),
+      )
+    case RaggedDotMode.RAGGED_BATCH:
+      return _dot_general(
+          lhs,
+          rhs,
+          dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
+      )
+
+
+mlir.register_lowering(ragged_dot_general_p,
+                       mlir.lower_fun(_ragged_dot_general_impl,
+                                      multiple_results=False))
 
 
 def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index af5ec987e..1809f211f 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1541,6 +1541,7 @@ tf_not_yet_impl = [
     "assert_consumed_value",
     "consume",
     "ragged_dot",
+    "ragged_dot_general",
     "cholesky_update",
     "symmetric_product",
     "from_edtype",
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
index a26d15c14..4e376fb66 100644
--- a/jax/lax/__init__.py
+++ b/jax/lax/__init__.py
@@ -17,6 +17,7 @@
 
 from jax._src.lax.lax import (
   DotDimensionNumbers as DotDimensionNumbers,
+  RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
   Precision as Precision,
   PrecisionLike as PrecisionLike,
   DotAlgorithm as DotAlgorithm,
@@ -158,6 +159,7 @@ from jax._src.lax.lax import (
   pow as pow,
   pow_p as pow_p,
   ragged_dot as ragged_dot,
+  ragged_dot_general as ragged_dot_general,
   real as real,
   real_p as real_p,
   reciprocal as reciprocal,
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 8497bf389..ad6b2a0bc 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(
         lax_reference.ragged_dot, lax.ragged_dot, args_maker)
 
+  @parameterized.parameters(
+      {
+          "lhs_shape": lhs_shape,
+          "rhs_shape": rhs_shape,
+          "group_sizes_shape": group_sizes_shape,
+          "ragged_dot_dimension_numbers": ragged_dot_dimension_numbers,
+          "err_msg": err_msg,
+      }
+      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0, 1],
+                  rhs_group_dimensions=[0],
+              ),
+              "ragged_dot_general expects exactly one lhs ragged dimension",
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[2],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires lhs ragged dimension numbers to "
+                  "be nonnegative and less than the number of axes of the lhs"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [2, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              r"expected group_sizes to have shape \(3,\), got \(2, 3\)",
+          ),
+          (
+              [19, 17, 11, 5],
+              [3, 19, 5, 7],
+              [19, 11, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([3], [2]), ([0], [1])),
+                  lhs_ragged_dimensions=[2],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  r"expected group_sizes to have shape \(19, 17, 3\), "
+                  r"got \(19, 11, 3\)"
+              ),
+          ),
+          (
+              [19, 11, 17, 5],
+              [19, 17, 5, 7],
+              [19, 11, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])),
+                  lhs_ragged_dimensions=[3],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  r"expected group_sizes to have shape \(19, 17, 3\), "
+                  r"got \(19, 11, 3\)"
+              ),
+          ),
+          (
+              [17, 19, 11, 5],
+              [17, 19, 5, 7],
+              [19, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  r"expected group_sizes to have shape \(17, 3\), "
+                  r"got \(19, 3\)"
+              ),
+          ),
+          (
+              [19, 11, 5],
+              [19, 5, 7],
+              [19, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2], [1]), ([0], [0])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires rhs group dimension numbers to "
+                  "be distinct from contracting and batch dimensions"
+              ),
+          ),
+          (
+              [11, 3],
+              [3, 3, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[1],
+              ),
+              (
+                  "ragged_dot_general requires rhs group dimension numbers to "
+                  "be distinct from contracting and batch dimensions"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [2],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              "expected rhs group dimension size to be 2, got 3",
+          ),
+          (
+              [2, 11, 5],
+              [3, 2, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2], [2]), ([0], [1])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires zero group dimensions in "
+                  "the rhs when lhs ragged dimension is contracting or batch"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires zero group dimensions in "
+                  "the rhs when lhs ragged dimension is contracting or batch"
+              ),
+          ),
+          (
+              [11, 5],
+              [5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [0]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  "ragged_dot_general requires exactly one rhs group dimension "
+                  "when lhs ragged dimension is noncontracting"
+              ),
+          ),
+      ]
+  )
+  def test_ragged_dot_general_shape_inference_failure(
+      self, lhs_shape, rhs_shape, group_sizes_shape,
+      ragged_dot_dimension_numbers, err_msg):
+    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
+    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
+    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
+    with self.assertRaisesRegex(TypeError, err_msg):
+      lax.ragged_dot_general(lhs, rhs, group_sizes,
+                             ragged_dot_dimension_numbers)
+
+  @parameterized.parameters(
+      {
+          "lhs_shape": lhs_shape,
+          "rhs_shape": rhs_shape,
+          "group_sizes_shape": group_sizes_shape,
+          "ragged_dnums": ragged_dnums,
+          "out_shape": out_shape,
+      }
+      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              (11, 7),
+          ),
+          (
+              [11, 5],
+              [5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [0]), ([], [])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[],
+              ),
+              (3, 11, 7),
+          ),
+      ]
+  )
+  def test_ragged_dot_general_shape_inference_success(
+      self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape):
+    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
+    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
+    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
+    self.assertEqual(
+        lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape,
+        out_shape,
+    )
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())