"""Demonstrate saving and restoring PyTorch CPU/CUDA RNG state across runs.

Usage:
    python <this script> init --seed 0   # seed, draw samples, save RNG state
    python <this script> restore         # load saved RNG state, draw samples
"""
import argparse

import torch

# File the RNG state is written to by ``init`` and read back by ``restore``.
STATE_PATH = "state.pt"


def init_seed(args):
    """Seed PyTorch's random number generators from ``args.seed``.

    Args:
        args: Parsed CLI arguments; ``args.seed`` must be an int.
    """
    torch.manual_seed(args.seed)


def restore_seed(args, path=STATE_PATH):
    """Restore the CPU and CUDA RNG states previously saved by ``dump_state``.

    Args:
        args: Parsed CLI arguments (not used; kept for a uniform signature).
        path: File written by ``dump_state``.
    """
    state = torch.load(path)
    torch.set_rng_state(state['cpu'])
    torch.cuda.set_rng_state(state['cuda'], device="cuda")


def do_stuff(args):
    """Draw 10 standard-normal samples on CUDA and 10 on CPU, and print both.

    Args:
        args: Parsed CLI arguments (not used).
    """
    x = torch.randn((10,), device="cuda")
    y = torch.randn((10,), device="cpu")
    print(f"{x=}\n{y=}")


def dump_state(args, path=STATE_PATH):
    """Save the current CPU and CUDA RNG states to ``path``.

    Args:
        args: Parsed CLI arguments (not used; kept for a uniform signature).
        path: Destination file, later read by ``restore_seed``.
    """
    state = {
        'cpu': torch.get_rng_state(),
        'cuda': torch.cuda.get_rng_state(device='cuda'),
    }
    torch.save(state, path)


def parse_args(argv=None):
    """Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``mode`` ("init" or "restore") and
        ``seed`` (int, or None in restore mode when not given).

    Raises:
        SystemExit: On invalid arguments, including ``init`` without ``--seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", choices=["init", "restore"])
    parser.add_argument("--seed", type=int)
    args = parser.parse_args(argv)
    # Without this check, ``init`` with no --seed reaches
    # torch.manual_seed(None) and fails with an unhelpful TypeError.
    if args.mode == "init" and args.seed is None:
        parser.error("--seed is required in init mode")
    return args


def main(argv=None):
    """Run one ``init`` or ``restore`` round of the demo."""
    args = parse_args(argv)
    if args.mode == "init":
        init_seed(args)
    elif args.mode == "restore":
        restore_seed(args)
    do_stuff(args)
    # The state is saved *after* do_stuff has consumed random numbers, so a
    # later ``restore`` run continues the sequence rather than repeating the
    # samples printed by ``init``.
    if args.mode == "init":
        dump_state(args)


if __name__ == "__main__":
    main()