from pytorch_transformers import cached_path

# Download the pre-trained checkpoint and its training config from the
# HuggingFace tutorial bucket (cached_path caches the files locally).
# SECURITY NOTE(review): torch.load unpickles the downloaded files —
# arbitrary code execution if the URL is ever untrusted; acceptable here
# only because the source is a fixed, trusted bucket.
state_dict = torch.load(
    cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                "naacl-2019-tutorial/model_checkpoint.pth"),
    map_location='cpu')  # load weights onto CPU regardless of save device
config = torch.load(
    cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                "naacl-2019-tutorial/model_training_args.bin"))

# Init model: Transformer base + classifier head.
# NOTE(review): TransformerWithClfHead and finetuning_config must be defined
# earlier in this file — confirm they are in scope at this point.
model = TransformerWithClfHead(
    config=config,
    fine_tuning_config=finetuning_config).to(finetuning_config.device)

# strict=False: tolerate keys missing from / unexpected in the checkpoint —
# presumably the freshly-initialized classifier head has no saved weights.
model.load_state_dict(state_dict, strict=False)