Currently, transformers attempts to offload the quantized weights to CPU on every rank. We know from Answer.AI's FSDP+QLoRA work that this is only needed on rank 0; all other ranks can load the parameter weights onto the meta device instead, so they don't waste memory on weights FSDP will broadcast anyway.
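The patched branch below handles this split inside transformers, but a minimal sketch of the pattern looks roughly like the following. This is an illustration, not the branch's actual code; it assumes `torch.distributed` is launched via `torchrun`, and the model id is a placeholder.

```python
# Sketch of the rank-0 / meta-device loading split, assuming torchrun has
# set up the process group environment; model_id is a placeholder.
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

dist.init_process_group("nccl")
rank = dist.get_rank()

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

if rank == 0:
    # Only rank 0 materializes the real quantized weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config
    )
else:
    # Every other rank builds the same module tree on the meta device,
    # which allocates no storage for the weights.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(
            AutoConfig.from_pretrained(model_id)
        )

# Wrapping with FSDP(..., sync_module_states=True) then broadcasts rank 0's
# weights to the meta-initialized ranks at wrap time.
```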
```bash
pip uninstall transformers
pip install "transformers @ git+https://github.com/winglian/transformers.git@fsdp-meta-sharding"
```

Also, when loading pre-quantized weights, bitsandbytes doesn't set the `quant_state` that FSDP needs. @matthewdouglas will have a PR up for this fix soon. In the meantime, you can use the branch below.
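If you want to check whether you're hitting this, here is a quick diagnostic, assuming a `model` already loaded from pre-quantized weights with bitsandbytes (the variable name is ours):

```python
# Find 4-bit parameters whose quant_state was never populated; FSDP needs
# the quant_state on every Params4bit to dequantize after sharding.
from bitsandbytes.nn import Params4bit

missing = [
    name
    for name, param in model.named_parameters()
    if isinstance(param, Params4bit) and param.quant_state is None
]
if missing:
    print(f"{len(missing)} 4-bit params missing quant_state, e.g. {missing[0]}")
```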