This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # copied from https://github.com/davda54/sam | |
| import torch | |
| class SAM(torch.optim.Optimizer): | |
| def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): | |
| assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" | |
| defaults = dict(rho=rho, adaptive=adaptive, **kwargs) | |
| super(SAM, self).__init__(params, defaults) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import inspect | |
| from typing import Any | |
| class LazyEvaluateCached: | |
| """Boolean-like class to check arbitrary commands lazily. | |
| >>> LazyEvaluateCached("[4, 5]") | |
| LazyEvaluateCached([4, 5]) | |
| >> LazyEvaluateCached("[4, 5]")() | |
| [4, 5] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| from tqdm import tqdm | |
| def get_invalid_ips_docker_htcondor(files, only_ips: bool = True): | |
| invalid_ips = {} | |
| for f in tqdm(files): | |
| with open(f) as efs: | |
| if 'docker: Got permission denied while trying to connect to the Docker daemon socket' in efs.read(): | |
| with open(f.replace('_err.log', '_log.log')) as lfs: | |
| for l in lfs.readlines(): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from pytorch_lightning.metrics.functional import accuracy | |
# Mocked predictions and targets -- for real usage, drop in your own model's
# outputs here. Both tensors are drawn from the same label range so that every
# predicted class is also a possible target class; previously predictions used
# 10 classes while targets used only 5, so half of the predicted labels could
# never be correct.
NUM_CLASSES = 10
pred = torch.randint(NUM_CLASSES, (200,))
target = torch.randint(NUM_CLASSES, (200,))
print(accuracy(pred, target))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from pytorch_lightning.metrics import TensorMetric | |
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the root-mean-squared error between ``pred`` and ``target``.

    The element-wise difference is squared, averaged over all elements and
    square-rooted, producing a 0-dim tensor.

    Args:
        pred: Predicted values; must be broadcast-compatible with ``target``.
        target: Ground-truth values.

    Returns:
        A scalar tensor holding the RMSE.
    """
    # Float exponent (2.0) keeps the squared term floating point even for
    # integer inputs, so ``mean`` is well-defined.
    return (pred - target).pow(2.0).mean().sqrt()
| class RMSE(TensorMetric): | |
| def forward(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from pytorch_lightning.metrics import tensor_metric | |
@tensor_metric()
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the root-mean-squared error of ``pred`` against ``target``.

    Wrapped with pytorch_lightning's ``tensor_metric()`` decorator; see its
    documentation for what the wrapper adds around the returned tensor.

    Args:
        pred: Predicted values; must be broadcast-compatible with ``target``.
        target: Ground-truth values.

    Returns:
        A scalar tensor holding the RMSE over all elements.
    """
    residual = pred - target
    mean_squared = torch.mean(residual.pow(2.0))
    return torch.sqrt(mean_squared)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root-mean-squared error between two tensors.

    Squares the element-wise difference, takes the mean over every element
    and returns its square root as a 0-dim tensor.

    Args:
        pred: Predicted values; must be broadcast-compatible with ``target``.
        target: Ground-truth values.

    Returns:
        A scalar tensor holding the RMSE.
    """
    squared_error = torch.pow(pred - target, 2.0)
    return squared_error.mean().sqrt()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from rising.loading import DataLoader | |
| from .dataset import DummyDataset | |
# Build a DummyDataset of length 500 (no transforms) and iterate it through
# rising's DataLoader in shuffled batches of 10, printing each batch's data
# shape and its labels. Batches are indexed by the 'data' and 'label' keys.
# NOTE(review): with num_workers > 0, platforms whose multiprocessing start
# method is "spawn" (e.g. Windows, macOS) typically need this code under an
# ``if __name__ == "__main__":`` guard -- confirm for your target platforms.
dset = DummyDataset(length=500, transforms=None)
loader = DataLoader(
    dset,
    batch_size=10,
    shuffle=True,
    num_workers=4,
)
for batch in loader:
    batch_data = batch['data']
    print(batch_data.shape)
    print(batch['label'])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from torch.utils.data import DataLoader | |
| from .dataset import DummyDataset | |
# Build a DummyDataset of length 500 (no transforms) and iterate it through
# torch's DataLoader in shuffled batches of 10, printing each batch's data
# shape and its labels. Batches are indexed by the 'data' and 'label' keys.
# NOTE(review): with num_workers > 0, platforms whose multiprocessing start
# method is "spawn" (e.g. Windows, macOS) typically need this code under an
# ``if __name__ == "__main__":`` guard -- confirm for your target platforms.
dset = DummyDataset(length=500, transforms=None)
loader = DataLoader(
    dset,
    batch_size=10,
    shuffle=True,
    num_workers=4,
)
for batch in loader:
    batch_data = batch['data']
    print(batch_data.shape)
    print(batch['label'])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from torch.utils.data import Dataset | |
| class DummyDataset(Dataset): | |
| def __init__(self, length: int, transforms=None): | |
| self.length = length | |
| self.transforms = transforms | |
| def __getitem__(self, idx: int): | |
| # random image shape with 1 channel and 3 spatial dimensions | |
| img = torch.rand(1, 224, 224, 224) |
Newer Older