"""Convert a pickled PyTorch checkpoint (.ckpt) to the safetensors format.

Example:
    python <this script> --input model.ckpt --output model --fp16
"""
import os
import torch
import argparse
from safetensors.torch import save_file

# This file is a command-line script only; importing it is treated as a mistake.
if __name__ != "__main__":
    raise Exception("This script is not meant to be imported")


def _parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Convert a model from pickle to safetensors format")
    parser.add_argument("--input", type=str, help="Path to input model in torch format (.ckpt)", required=True)
    parser.add_argument("--output", type=str, help="Path to output model (without extension)", default="model", required=False)
    parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, help="Whether to use half precision", default=False, required=False)
    parser.add_argument("--device", type=str, help="Device to use (defaults to 'cpu')", default="cpu", required=False)
    return parser.parse_args()


def _load_state_dict(path, device):
    """Load the checkpoint at *path* and return a dict of name -> tensor.

    Accepts both checkpoints that nest their weights under a "state_dict"
    key and plain state dicts saved directly with torch.save.

    Raises:
        TypeError: if the loaded object is not a dict-like state dict.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code embedded in it. Only convert checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a state dict in '{path}', got {type(checkpoint).__name__}")

    # safetensors can only store tensors; drop anything else with a warning
    # instead of failing inside save_file.
    tensors = {k: v for k, v in checkpoint.items() if isinstance(v, torch.Tensor)}
    skipped = len(checkpoint) - len(tensors)
    if skipped:
        print(f"! Skipping {skipped} non-tensor entries")
    return tensors


def _to_half(weights):
    """Cast floating-point tensors to float16; leave integer/bool tensors unchanged."""
    return {k: v.half() if v.is_floating_point() else v for k, v in weights.items()}


def _choose_output_file(output_file, output_extension):
    """Return a path to write to, asking the user before overwriting a file."""
    while os.path.isfile(output_file):
        overwrite = input(
            f"! Output file '{output_file}' already exists. Overwrite? [y/N]: ")
        if overwrite.lower() == "y":
            break
        filename = input(
            "? Please enter a new output file name (without extension): ")
        # An empty answer keeps the current name, so the loop asks again.
        if filename:
            output_file = filename + output_extension
    return output_file


def main():
    """Run the conversion: load, optionally cast to fp16, and save."""
    args = _parse_args()

    print(f"• Loading model from {args.input}...")
    weights = _load_state_dict(args.input, args.device)

    if args.fp16:
        print("• Converting to half precision...")
        weights = _to_half(weights)

    output_extension = f"{'.fp16' if args.fp16 else ''}.safetensors"
    output_file = _choose_output_file(args.output + output_extension, output_extension)

    print(f"• Saving to {output_file}...")
    # save_file refuses non-contiguous tensors, so make every tensor contiguous
    # (a no-op for tensors that already are).
    save_file({k: v.contiguous() for k, v in weights.items()}, output_file)
    print("✓ Done!")


if __name__ == "__main__":
    main()