docs/shardy_jax_migration.md
(shardy-jax-migration)=
Shardy is a new partitioning system co-developed by the GDM Model Scaling team (authors of PartIR) and the XLA/CoreML teams (authors of GSPMD). Shardy aims to provide better usability and control to users, and will gradually replace GSPMD and PartIR.
After the migration is complete in March 2026, Shardy will be the only partitioner in JAX.
Until then, as a temporary workaround for any problems, Shardy can be disabled. Please file a JAX issue if you encounter any problem.
The easiest way to tell if Shardy is responsible for a problem is to disable Shardy and see if the issue goes away. See the "What issues can arise when Shardy is switched on?" section below.
You can tell that Shardy is enabled by looking for the message "Using Shardy for XLA SPMD propagation" in the logs.
Until March 2026, it will be possible to temporarily disable Shardy by:
- setting the shell environment variable JAX_USE_SHARDY_PARTITIONER to something false-like (e.g., 0);
- setting the boolean flag jax_use_shardy_partitioner to something false-like if your code parses flags with absl;
- using this statement in your main file or anywhere before you call jax.jit:
import jax
jax.config.update('jax_use_shardy_partitioner', False)
To debug partitioning with Shardy enabled, you can enable MLIR dumps as follows:
--xla_dump_hlo_pass_re=shardy --xla_dump_to=<some_directory>
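These are XLA flags, and one common way to pass them to a JAX program is through the XLA_FLAGS environment variable. The following is a minimal sketch of doing that from Python. The helper name `add_shardy_dump_flags` and the directory `/tmp/shardy_dump` are hypothetical choices for illustration, and the sketch assumes the variable is set before JAX is imported so that the backend sees it.

```python
import os


def add_shardy_dump_flags(dump_dir: str, existing: str = "") -> str:
    """Append the Shardy dump flags to an existing XLA_FLAGS string."""
    flags = existing.split()  # keep any flags the user already set
    flags.append("--xla_dump_hlo_pass_re=shardy")
    flags.append(f"--xla_dump_to={dump_dir}")
    return " ".join(flags)


# Hypothetical dump directory; replace it with a location of your choice.
os.environ["XLA_FLAGS"] = add_shardy_dump_flags(
    "/tmp/shardy_dump", os.environ.get("XLA_FLAGS", "")
)

# Import jax (and run your program) only after this point.
```

Appending to the existing value, rather than overwriting it, keeps any other XLA flags that were already configured.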
NOTE: If possible, please disable Shardy only for the specific use cases that are not working as expected, and file a bug with a reproducer, so we can resolve it as soon as possible and re-enable Shardy.
Enabling Shardy by default in JAX maintains the 6-month backwards compatibility guarantee. This means that you will be able to load a model exported with Shardy disabled for at least 6 months after Shardy becomes enabled for your model. That old checkpointed model will run with GSPMD, and only when re-exporting the model will it start running with Shardy.
However, if you still encounter an issue with loading an old checkpoint, please contact us or file a bug.
NOTE: Exporting a model with Shardy enabled, then loading it with Shardy disabled isn’t supported and will fail.
Because we fall back to GSPMD for any JAX export checkpoint for 6 months, please re-export your models with Shardy enabled to help find potential issues. You can then see whether they run fine or whether there is a bug we need to fix.
While Shardy improves on the existing sharding propagation systems (GSPMD and PartIR), it can sometimes output slightly different results due to different propagation order or conflict resolution heuristics.
This doesn’t necessarily mean that Shardy is doing the wrong thing, but possibly that there aren't enough sharding constraints in the program, so a small change in propagation order can affect the final result. It can also hint that existing sharding constraints were overfitted to GSPMD and require slight adjustments with Shardy.
Therefore, it is possible that enabling Shardy will cause some models to have a performance regression or OOM (especially if the model was already close to the memory capacity). However, we have already migrated many use cases across Alphabet, and have observed equivalent or better performance than GSPMD.
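As an illustration of what adding a sharding constraint can look like, the sketch below pins the sharding of an intermediate value with `jax.lax.with_sharding_constraint`, so that its layout does not depend on propagation order. The mesh axis name `data`, the function `normalize_rows`, and the array shapes are assumptions made for this example only. On a single-device host the constraint is trivially satisfied.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (possibly just one).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard rows across the "data" axis and replicate columns.
row_sharded = NamedSharding(mesh, P("data", None))


@jax.jit
def normalize_rows(x):
    centered = x - x.mean(axis=1, keepdims=True)
    # Explicitly constrain the intermediate instead of relying on propagation.
    centered = jax.lax.with_sharding_constraint(centered, row_sharded)
    norm = jnp.linalg.norm(centered, axis=1, keepdims=True)
    return centered / (norm + 1e-6)


# The row count is a multiple of the device count so rows divide evenly.
rows = 8 * devices.size
x = jnp.arange(rows * 4, dtype=jnp.float32).reshape(rows, 4)
out = normalize_rows(x)
print(out.shape)
```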
To resolve such issues, users can either:
We have done extensive testing across many JAX models. However, it’s possible that there are certain edge cases or situations we don’t support/handle (because we didn't know we needed to).
This means that, although rare, you may get a compilation failure in the form of a segfault, a hard check failure, a Python value error, etc.
In such a case, please disable Shardy temporarily and open a bug with a reproducer.
If Shardy is disabled somewhere in your code, but there are still paths that use the default value of the JAX flag, this can cause issues. For example, exporting a model with Shardy enabled, then loading it with Shardy disabled isn’t supported and will fail (the other way is supported for backwards compatibility).
The symptom for an issue like this can be an error in JAX or in XLA/Shardy, or just undefined behavior. You can try disabling Shardy globally in JAX config to see if the issue goes away.
NOTE: Please ensure that Shardy is disabled consistently if needed, or remove any explicit modification of the flag, to have the default value apply throughout.
jax.experimental.custom_partitioning API

If you use this API, you may see the error
Shardy is used, but sharding propagation callbacks instead of sharding_rule are
provided. Need to provide sharding_rule to migrate to Shardy.
Instead of defining infer_sharding_from_operands and propagate_user_sharding
callbacks, define a jax.experimental.SdyShardingRule that specifies an einsum-like relationship between dimensions during propagation. Refer to the custom_partitioning doc
for more info on how to define a sharding rule.
jax.export requires all inputs and outputs to have the same mesh

As part of the Shardy migration, jax.export now requires all input/output shardings to live on the same mesh, that is, a mesh with the same axis names and sizes.
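A minimal sketch of an export that satisfies this requirement is shown below: both the input and the output sharding are built from a single `Mesh` object. The axis name `data`, the function `total`, and the input shape are assumptions made for illustration, and the sketch runs on however many devices are available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh shared by every input and output sharding.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
in_sharding = NamedSharding(mesh, P("data"))  # input split along "data"
out_sharding = NamedSharding(mesh, P())       # output replicated


def total(x):
    return jnp.sum(x)


jitted = jax.jit(total, in_shardings=in_sharding, out_shardings=out_sharding)

# The length is a multiple of the device count so the input divides evenly.
length = 8 * mesh.devices.size
exported = export.export(jitted)(jax.ShapeDtypeStruct((length,), jnp.float32))
print(exported.nr_devices)
```

Building every `NamedSharding` from the same `mesh` variable is a simple way to guarantee that axis names and sizes agree.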