Sergei Lebedev
c576d328bd
Added `lax.axis_size` and switched all existing usage of `psum(1, ...)` to it
...
PiperOrigin-RevId: 748604842
2025-04-17 02:22:25 -07:00
Jake VanderPlas
b271a67bbc
Clean up softmax initial deprecation
2025-04-15 14:36:56 -07:00
kaixih
ae29f63e81
Don't use default quant config
2025-04-10 19:23:11 +00:00
kaixih
2090dadfde
Deprecation warning
2025-04-10 17:45:51 +00:00
kaixih
0f29716986
One alias one
2025-04-10 17:19:53 +00:00
kaixih
a39a81ae7a
Keep old scale_matmul arg names
2025-04-10 17:03:43 +00:00
Adam Paszke
6792703dbe
Fix failing documentation tests
...
The CUDA-specific primitives need to be explicitly skipped.
PiperOrigin-RevId: 745504040
2025-04-09 02:53:04 -07:00
kaixih
41868ef06d
format
2025-04-03 21:46:10 +00:00
kaixih
5ddec65086
Remove asserts
2025-04-03 00:00:25 +00:00
kaixih
f949b8b8f6
Enable public doc for scaled dot
2025-03-27 00:05:28 +00:00
Jesse Perla
5d79df7e67
Add identity activation
...
Fix typo
2025-03-23 15:13:17 -07:00
jax authors
97db925a7d
Merge pull request #26765 from Qwlouse:patch-1
...
PiperOrigin-RevId: 733339465
2025-03-04 08:30:45 -08:00
jax authors
c7ca35fe32
Merge pull request #26345 from wenscarl:scaled_matmul
...
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
carlosgmartin
ba428d8cda
Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it.
2025-02-26 16:39:45 -05:00
Klaus Greff
5acfc88a00
fix Initializer protocol
2025-02-26 14:25:15 +01:00
shuw
17088e9025
Improve after review # 2
2025-02-26 04:48:25 +00:00
shuw
681ee18436
Fix CI
2025-02-25 17:15:31 +00:00
shuw
bfb9d3ca4b
Improve based on comment # 1
2025-02-21 17:32:57 +00:00
Yash Katariya
a3edfb43ef
Now that sharding_in_types config flag is True, remove the config and all the conditionals
...
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Sebastian Bodenstein
d5e5b42de8
Use consistent dtype for forward and backwards in jax.nn.dot_product_attention.
...
Fixes https://github.com/jax-ml/jax/issues/24047
PiperOrigin-RevId: 728613700
2025-02-19 04:30:23 -08:00
Shu Wang
2bb7f3658b
Improve docstring.
2025-02-13 09:44:42 -06:00
shuw
332af58765
block_scale_config
2025-02-13 04:35:06 +00:00
Yash Katariya
1a62df1ac0
Rename `sharding` argument to `out_sharding` for `lax.reshape`, `lax.broadcast_in_dim`, `lax.broadcast` and `lax.broadcasted_iota`. `.bind` of these APIs still takes `sharding` as a parameter though (but that's fine since it's internal and not public facing)
...
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
jax authors
872e6c0ec4
Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
...
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
carlosgmartin
96d3447e89
Add mode='fan_geo_avg' to nn.initializers.variance_scaling.
2025-01-31 17:52:22 -05:00
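As a rough illustration of what a `fan_geo_avg` mode means, the sketch below compares the denominators a variance-scaling initializer might use for each mode. It is plain Python, not JAX's implementation: the helper names `variance_scaling_denominator` and `variance_scaling_stddev` are hypothetical, and it assumes (from the mode's name) that `fan_geo_avg` divides `scale` by the geometric mean of `fan_in` and `fan_out`.

```python
import math


def variance_scaling_denominator(mode, fan_in, fan_out):
    # Hypothetical helper: the quantity that `scale` is divided by.
    if mode == "fan_in":
        return float(fan_in)
    if mode == "fan_out":
        return float(fan_out)
    if mode == "fan_avg":
        # Arithmetic mean of the two fans.
        return (fan_in + fan_out) / 2.0
    if mode == "fan_geo_avg":
        # Assumption: geometric mean of the two fans.
        return math.sqrt(fan_in * fan_out)
    raise ValueError(f"unknown mode: {mode!r}")


def variance_scaling_stddev(scale, mode, fan_in, fan_out):
    # Target standard deviation of the weights: sqrt(scale / denominator).
    return math.sqrt(scale / variance_scaling_denominator(mode, fan_in, fan_out))


# For a 4 x 16 weight matrix the two averages differ.
arithmetic = variance_scaling_denominator("fan_avg", 4, 16)
geometric = variance_scaling_denominator("fan_geo_avg", 4, 16)
```

For square layers (`fan_in == fan_out`) all four modes coincide; the geometric mean only differs from the arithmetic mean when the layer is rectangular.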
Yash Katariya
d50d1e2c40
Don't allow users to query `tracer.sharding` even under sharding in types mode.
...
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.
PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Yash Katariya
3848f0d2ac
[sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow `PartitionSpec` instead of just `NamedSharding` as an input.
...
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contains mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Jake VanderPlas
cb10710c92
Remove casting from jax.nn.one_hot
...
This change was made after the most recent release, so is safe
to remove. Casting float to int potentially changes intentional
behavior: e.g. NaN casts to 0. Some downstream users currently
use NaN to mark rows which should have no one-hot entry.
2024-12-23 07:33:49 -08:00
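The NaN concern in the commit above can be sketched in plain Python. This is not `jax.nn.one_hot` itself: `one_hot_row` and `cast_nan_to_zero` are hypothetical helpers, and the cast is only a stand-in for the float-to-int conversion the commit describes, under which NaN becomes 0.

```python
import math


def cast_nan_to_zero(x):
    # Hypothetical stand-in for the float-to-int cast described in the
    # commit message, where NaN ends up as 0. Not a JAX function.
    return 0 if math.isnan(x) else int(x)


def one_hot_row(x, num_classes, cast_first=False):
    # Build one row by comparing the label against each class index.
    if cast_first:
        x = cast_nan_to_zero(x)
    return [1.0 if x == k else 0.0 for k in range(num_classes)]


nan = float("nan")
# NaN compares unequal to every class index, so the row stays empty.
row_without_cast = one_hot_row(nan, 3)
# After the cast, NaN is indistinguishable from class 0.
row_with_cast = one_hot_row(nan, 3, cast_first=True)
```

Under these assumptions, a NaN label yields an all-zero row only when no cast is applied, which is the behavior the commit says some downstream users rely on.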
Jake VanderPlas
8c3c441ee4
jax.nn.one_hot: deprecate non-integer inputs
2024-12-19 07:11:31 -08:00
carlosgmartin
08801147f1
Add test of relu grad at zero. Update paper links.
2024-12-10 19:39:47 -05:00
Jake VanderPlas
fee272e550
Remove internal KeyArray alias
...
This was useful during the transition to typed PRNG keys, but
is no longer necessary. It also makes generated HTML docs
confusing: it's better to just use Array as we expect users to.
2024-11-20 10:30:12 -08:00
Jake VanderPlas
e9acaa8484
Remove the `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax`.
...
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.
PiperOrigin-RevId: 693023366
2024-11-04 10:55:21 -08:00
Yash Katariya
4db212d2c6
Add `_sharding` argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
...
This is required because `jax.nn.one_hot` calls into `broadcasted_iota`.
PiperOrigin-RevId: 687152343
2024-10-17 21:16:51 -07:00
jax authors
81d2fbe094
Merge pull request #23740 from kaixih:dbias_bwd_batcher
...
PiperOrigin-RevId: 681583770
2024-10-02 14:04:19 -07:00
jax authors
ca97af9d43
Change the default implementation of GeLU to a numerically stable formulation.
...
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.
PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
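The cancellation described in the GeLU commit above can be reproduced with the standard library alone. The sketch below is not JAX's code: `gelu_naive` and `gelu_stable` are hypothetical names, and the stable variant assumes the rewrite relies on the identity 1 + erf(z) = erfc(-z), which is one common way to avoid the subtraction.

```python
import math


def gelu_naive(x):
    # Explicitly forms 1 + erf(x / sqrt(2)); for very negative x, erf
    # rounds to -1.0 and the sum cancels to exactly 0.0.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_stable(x):
    # Uses 1 + erf(z) == erfc(-z), so the small tail value is computed
    # directly instead of as a difference of nearly equal numbers.
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))


x = -10.0
naive = gelu_naive(x)    # tail information is lost to cancellation
stable = gelu_stable(x)  # small but nonzero negative value
```

For moderate inputs the two functions agree to rounding error; they diverge only in the negative tail, where the naive form loses all significant digits.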
kaixih
b7e26ba3ee
fix dbias in bwd_batcher
2024-09-20 18:07:55 +00:00
kaixih
541b3a3f75
New feature
2024-09-11 19:56:20 +00:00
kaixih
2d2cbbc5fb
Relax q_seqlen and kv_seqlen
2024-09-05 17:43:22 +00:00
jax authors
b9e6eb59be
Merge pull request #22516 from kaixih:support_variable_seqlen
...
PiperOrigin-RevId: 666394369
2024-08-22 10:08:08 -07:00
kaixih
558000df7c
Support variable sequence lengths
2024-08-21 18:25:55 +00:00
Roy Frostig
371935cc10
update README and several docs to typed RNG keys
2024-08-11 08:09:47 -07:00
Gleb Pobudzey
d28d14917e
Fix error message in dot_product_attention
...
PiperOrigin-RevId: 660960409
2024-08-08 13:30:21 -07:00
Jake VanderPlas
53af0d4d90
CI: fix mypy errors
2024-08-07 15:15:45 -07:00
kaixih
9f9e3e6d4e
Address comments
2024-08-02 19:55:28 +00:00
kaixih
6ff6501aa2
Init commit
2024-08-01 19:39:34 +00:00
jax authors
7d8b8578b5
Merge pull request #22477 from kaixih:support_gqa
...
PiperOrigin-RevId: 658130108
2024-07-31 13:50:49 -07:00
kaixih
cf5bcc7ad8
Support GQA and MQA
2024-07-29 17:17:22 +00:00
Matthew Johnson
3f9eb404e4
remove named_shapes (since xmap is now gone)
2024-07-25 00:54:50 +00:00
Michal Kazmierski
61374c92ad
Fix error message in jax.nn.dot_product_attention when the inputs have different dtypes.
...
PiperOrigin-RevId: 655553414
2024-07-24 07:13:15 -07:00
Dan Foreman-Mackey
556cc23fa5
Fix lint at head.
...
It looks like https://github.com/google/jax/pull/22330 introduced some
mypy lint. This PR fixes it.
2024-07-16 10:53:49 -04:00