Skip to content

Instantly share code, notes, and snippets.

View sniafas's full-sized avatar

Steve Niafas sniafas

View GitHub Profile
@sniafas
sniafas / predict.py
Created April 17, 2023 14:53
Prediction
def predict(self, x):
    """Run the model over ``x`` and return its predictions.

    The inputs are wrapped via ``self.get_data_to_inference(x, len(x))`` and
    every element yielded by the returned iterable is passed to
    ``self.model`` with gradient tracking disabled (``torch.no_grad``).

    Args:
        x: The inputs to predict on; must support ``len()``.

    Returns:
        The model output. When the iterable yields a single batch, that
        batch's output is returned unchanged. When it yields several, the
        per-batch outputs are concatenated along dimension 0 with
        ``torch.cat`` (assumed to be the batch dimension of the model's
        output — confirm against the model).

    Raises:
        ValueError: If the iterable returned by ``get_data_to_inference``
            yields no batches.
    """
    print("Predicting...")
    # get testing set
    test_dataloader = self.get_data_to_inference(x, len(x))
    batch_preds = []
    with torch.no_grad():
        # A distinct loop name keeps the ``x`` argument from being shadowed.
        for batch in test_dataloader:
            # Keep every batch's output so no batch is silently dropped.
            batch_preds.append(self.model(batch))
    if not batch_preds:
        raise ValueError("No data to predict on: inference loader yielded no batches")
    if len(batch_preds) == 1:
        # Single batch: return the model output exactly as produced.
        return batch_preds[0]
    return torch.cat(batch_preds)
@sniafas
sniafas / history.py
Created April 17, 2023 14:41
Training History
# calculate acc, loss
# Average the accumulated loss over the number of batches. len(train_dataloader)
# is that count directly; a 0-based enumerate() index (as used in fit()) would be
# one short and is 0 when there is only one batch (ZeroDivisionError).
epoch_loss = total_loss / len(train_dataloader)
# Normalise accuracy by the number of samples in the dataset
# (presumably total_accuracy counts correct predictions — confirm).
epoch_accuracy = total_accuracy / len(train_dataloader.dataset)
# record training history
history["accuracy"].append(epoch_accuracy)
history["loss"].append(epoch_loss)
@sniafas
sniafas / training.py
Created April 17, 2023 14:27
Training Runtime
def get_loss_function(self):
    """Build the training criterion.

    Returns:
        A new ``nn.CrossEntropyLoss`` instance constructed with its default
        arguments.
    """
    criterion = nn.CrossEntropyLoss()
    return criterion
def get_optimizer(self, params):
    """Create the Adam optimiser used to minimise the loss.

    Args:
        params: Parameters (or parameter groups) to optimise, e.g. the
            result of ``model.parameters()``.

    Returns:
        A ``torch.optim.Adam`` instance whose learning rate is ``self.lr``.
    """
    adam_cls = torch.optim.Adam
    return adam_cls(params, lr=self.lr)
def get_data(self, x, y, batch_size):
"""Get a data loader"""
@sniafas
sniafas / fc_model.py
Last active April 17, 2023 14:11
Fully connected model
class fcModel(nn.Module):
def __init__(self, input_shape, output_shape):
"""Model class constructor"""
super(fcModel, self).__init__()
self.linear = nn.Linear(input_shape, 8)
self.fully_connected_stack = nn.Sequential(
nn.Linear(8, 8),
nn.ReLU(),
nn.Linear(8, 8),
class TrainingDatasetLoader(Dataset):
def __init__(self, x, y):
self.points = torch.Tensor(x)
self.labels = torch.Tensor(y).type(torch.long)
self.len = len(self.labels)
def __getitem__(self, idx):
point = self.points[idx]
label = self.labels[idx]
X, Y = datasets.data_spiral(10000, 50)
@sniafas
sniafas / fit.py
Last active April 5, 2023 08:57
Logistic fit
def fit(self, data):
"""Train and evaluate model"""
optimizer = self.get_optimizer(self.model.parameters())
train_dataloader = self.get_data(data, self.batch_size, True)
print("Training...")
for epoch in range(1, self.epochs+1):
for batch_idx, (x, y) in enumerate(train_dataloader):
@sniafas
sniafas / data_loader.py
Last active April 5, 2023 09:05
Data Loader
from torch.utils.data import Dataset, DataLoader
class ImageDataLoader(Dataset):
def __init__(self, data):
self.images = data[0]
self.labels = data[1]
self.len = len(self.labels)
def __getitem__(self, idx):
image = self.images[idx]
@sniafas
sniafas / training_loop.py
Last active April 4, 2023 22:11
PyTorch - Linear Regression - Training Loop
def get_weight_biases():
    """Return the first weight and first bias entries as Python scalars.

    Reads the module-level ``w`` and ``b`` (defined elsewhere) and converts
    element 0 of each with ``.item()``.

    Returns:
        tuple: ``(w[0].item(), b[0].item())``.
    """
    weight_value = w[0].item()
    bias_value = b[0].item()
    return weight_value, bias_value
# Build the model from the second dimension (shape[1]) of x and of y
# (presumably the input and output feature counts — confirm against Model).
model = Model(
    x.shape[1],
    y.shape[1],
)
# Plain SGD over every model parameter with a fixed learning rate of 0.01.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
)
# initialise weights & biases
@sniafas
sniafas / loss_optimizer.py
Created March 31, 2023 15:48
PyTorch - Loss Function & Optimizer
# Mean-squared-error criterion, constructed with its default arguments.
loss_function = nn.MSELoss()
# Stochastic gradient descent over all model parameters, learning rate 0.01.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
)