This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  def predict(self, x):
      """Run the model over ``x`` and return its predictions.

      ``x`` and ``len(x)`` are passed to ``self.get_data_to_inference``
      (the second argument is presumably the batch size — confirm against
      that helper). The model is applied to every batch the returned
      iterable yields, with gradient tracking disabled.

      Args:
          x: Inputs to predict on, in whatever form
              ``get_data_to_inference`` accepts.

      Returns:
          The model output. With a single batch it is returned unchanged;
          with several batches the outputs are concatenated along dim 0.

      Raises:
          ValueError: If the loader yields no batches.
      """
      print("Predicting...")
      # get testing set
      test_dataloader = self.get_data_to_inference(x, len(x))
      batch_preds = []
      with torch.no_grad():
          # Use a distinct loop name so the ``x`` argument is not shadowed.
          for batch in test_dataloader:
              batch_preds.append(self.model(batch))
      if not batch_preds:
          raise ValueError("no data to predict on")
      # Previously only the last batch survived the loop; keep all of them.
      if len(batch_preds) == 1:
          return batch_preds[0]
      return torch.cat(batch_preds, dim=0)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  # calculate acc, loss
  # NOTE(review): assumes batch_idx is the zero-based index produced by
  # `enumerate(train_dataloader)` (as in fit's batch loop), so the number of
  # batches seen is batch_idx + 1. Dividing by batch_idx alone overstated
  # the mean loss and raised ZeroDivisionError for a single-batch epoch.
  epoch_loss = total_loss/(batch_idx + 1)
  epoch_accuracy = total_accuracy/len(train_dataloader.dataset)
  # record training history
  history["accuracy"].append(epoch_accuracy)
  history["loss"].append(epoch_loss)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  def get_loss_function(self):
      """Build and return a new ``nn.CrossEntropyLoss`` criterion."""
      criterion = nn.CrossEntropyLoss()
      return criterion
def get_optimizer(self, params):
    """Create an Adam optimizer over ``params`` with learning rate ``self.lr``."""
    learning_rate = self.lr
    return torch.optim.Adam(params, lr=learning_rate)
| def get_data(self, x, y, batch_size): | |
| """Get a data loader""" | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class fcModel(nn.Module): | |
| def __init__(self, input_shape, output_shape): | |
| """Model class constructor""" | |
| super(fcModel, self).__init__() | |
| self.linear = nn.Linear(input_shape, 8) | |
| self.fully_connected_stack = nn.Sequential( | |
| nn.Linear(8, 8), | |
| nn.ReLU(), | |
| nn.Linear(8, 8), | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class TrainingDatasetLoader(Dataset): | |
| def __init__(self, x, y): | |
| self.points = torch.Tensor(x) | |
| self.labels = torch.Tensor(y).type(torch.long) | |
| self.len = len(self.labels) | |
| def __getitem__(self, idx): | |
| point = self.points[idx] | |
| label = self.labels[idx] | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  # NOTE(review): the meaning of the two positional arguments (10000, 50) is
  # not visible here — confirm against datasets.data_spiral's signature.
  X, Y = datasets.data_spiral(
      10000,
      50,
  )
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def fit(self, data): | |
| """Train and evaluate model""" | |
| optimizer = self.get_optimizer(self.model.parameters()) | |
| train_dataloader = self.get_data(data, self.batch_size, True) | |
| print("Training...") | |
| for epoch in range(1, self.epochs+1): | |
| for batch_idx, (x, y) in enumerate(train_dataloader): | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | from torch.utils.data import Dataset, DataLoader | |
| class ImageDataLoader(Dataset): | |
| def __init__(self, data): | |
| self.images = data[0] | |
| self.labels = data[1] | |
| self.len = len(self.labels) | |
| def __getitem__(self, idx): | |
| image = self.images[idx] | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  def get_weight_biases():
      """Return ``(w[0].item(), b[0].item())`` from the globals ``w`` and ``b``.

      Assumes ``w[0]`` and ``b[0]`` are single-element tensors (required by
      ``.item()``) — confirm where ``w``/``b`` are initialised.
      """
      first_weight = w[0]
      first_bias = b[0]
      return first_weight.item(), first_bias.item()
# initialise model; its two size arguments are read from dimension 1 of x and y
model = Model(
    x.shape[1],
    y.shape[1],
)
# initialise optimizer: SGD over every model parameter, fixed learning rate
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
| # initialise weights & biases | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  # Mean-squared-error criterion.
  loss_function = nn.MSELoss()
  # SGD over the model's parameters with a fixed learning rate of 0.01.
  optimizer = torch.optim.SGD(
      params=model.parameters(),
      lr=0.01,
  )
Newer / Older