
@ben0it8
Last active July 17, 2019 09:00
Revisions

  1. ben0it8 revised this gist Jul 17, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion load_pretrained_transformer.py
    @@ -1,4 +1,4 @@
    -from pytorch_pretrained_bert import cached_path
    +from pytorch_transformers import cached_path
     
     # download pre-trained model and config
     state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
  2. ben0it8 revised this gist Jul 12, 2019. 1 changed file with 1 addition and 2 deletions.
    3 changes: 1 addition & 2 deletions load_pretrained_transformer.py
    @@ -9,5 +9,4 @@
     # init model: Transformer base + classifier head
     model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
    -
    -incompatible_keys = model.load_state_dict(state_dict, strict=False)
    +model.load_state_dict(state_dict, strict=False)
  3. ben0it8 revised this gist Jul 12, 2019. 1 changed file with 1 addition and 3 deletions.
    4 changes: 1 addition & 3 deletions load_pretrained_transformer.py
    @@ -10,6 +10,4 @@
     # init model: Transformer base + classifier head
     model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
     
    -incompatible_keys = model.load_state_dict(state_dict, strict=False)
    -print(f"Parameters discarded from the pretrained model: {incompatible_keys.unexpected_keys}")
    -print(f"Parameters added in the adaptation model: {incompatible_keys.missing_keys}")
    +incompatible_keys = model.load_state_dict(state_dict, strict=False)
  4. ben0it8 created this gist Jul 12, 2019.
    15 changes: 15 additions & 0 deletions load_pretrained_transformer.py
    @@ -0,0 +1,15 @@
    +from pytorch_pretrained_bert import cached_path
    +
    +# download pre-trained model and config
    +state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
    +                                    "naacl-2019-tutorial/model_checkpoint.pth"), map_location='cpu')
    +
    +config = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
    +                                "naacl-2019-tutorial/model_training_args.bin"))
    +
    +# init model: Transformer base + classifier head
    +model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)
    +
    +incompatible_keys = model.load_state_dict(state_dict, strict=False)
    +print(f"Parameters discarded from the pretrained model: {incompatible_keys.unexpected_keys}")
    +print(f"Parameters added in the adaptation model: {incompatible_keys.missing_keys}")
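The gist hinges on `load_state_dict(state_dict, strict=False)`: with `strict=False`, PyTorch copies weights for every matching parameter name and, instead of raising, reports the pretrained parameters the new model has no slot for (`unexpected_keys`, e.g. an LM head) and the freshly initialized parameters absent from the checkpoint (`missing_keys`, e.g. the classifier head). As a rough illustration of that key matching (a plain-Python stand-in with hypothetical parameter names, not the actual torch implementation), it behaves like:

```python
def load_state_dict_lenient(model_params, state_dict):
    """Mimic torch's Module.load_state_dict(..., strict=False):
    copy values for names present in both dicts, and report the
    names that exist on only one side instead of raising."""
    missing = [k for k in model_params if k not in state_dict]      # new, untrained params
    unexpected = [k for k in state_dict if k not in model_params]   # discarded pretrained params
    for k in model_params:
        if k in state_dict:
            model_params[k] = state_dict[k]
    return missing, unexpected

# pretrained checkpoint has an LM head; the fine-tuning model replaces
# it with a classifier head but shares the Transformer base
pretrained = {"transformer.weight": 1.0, "lm_head.weight": 2.0}
model = {"transformer.weight": 0.0, "classification_head.weight": 0.0}

missing, unexpected = load_state_dict_lenient(model, pretrained)
# the shared base weight is loaded; the classifier head keeps its init
```

In real torch, the return value is a named tuple with `missing_keys` and `unexpected_keys` attributes, which is exactly what the gist's two `print` statements inspect.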