# Digital Watermarking for Model Security

Below are **simplified examples** of how you might apply digital watermarking to ML models (open source or proprietary) using **PyTorch**, though the general concepts apply to TensorFlow or any other framework. Each example is kept intentionally small to illustrate the idea without overwhelming detail.

---

# 1. Parameter (Weight) Watermarking

### Overview

- We add a small regularization term that "nudges" certain weights in a neural network toward a hidden pattern.
- After training, you can detect the watermark by checking whether the final weights match this pattern.

### Toy Example: PyTorch MLP for MNIST

Below is a **very** simplified feed-forward network for MNIST digit classification. We will:

1. Define a **binary signature** (e.g., `[+1, -1, +1, -1]`).
2. Force a few of the last layer's weights to approximate this signature.

#### Step A: Define the Model

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Simple MLP
class SimpleMLP(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, output_dim=10):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Prepare MNIST data (train_loader) in typical fashion
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
```

#### Step B: Define the Watermark Regularization

We'll choose **four weights** in the last layer to match `[+1, -1, +1, -1]`. The penalty is kept small so it doesn't noticeably harm performance.

```python
# The binary signature we want to embed
watermark_signature = torch.tensor([1.0, -1.0, 1.0, -1.0])

def watermark_loss(model, alpha=0.001):
    """
    Custom loss that nudges the last layer's first four weights
    to match the watermark signature.
    alpha controls how strongly we push these weights.
    """
    # Extract the first four weights of fc2
    # .weight has shape [output_dim, hidden_dim]
    # We'll just pick the first row for demonstration
    target_weights = model.fc2.weight[0, :4]

    # We want target_weights to be close to watermark_signature
    loss = alpha * torch.sum((target_weights - watermark_signature) ** 2)
    return loss
```
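In practice, you would typically derive *which* weights carry the signature from a secret key rather than hard-coding the first four columns of `fc2`. The sketch below is illustrative only; the `secret_positions` helper, `keyed_watermark_loss`, and the key string are assumptions, not part of the original example:

```python
import hashlib

# Sketch: deterministically pick which fc2 columns carry the signature from a
# secret key, so someone reading the code cannot guess the watermarked positions.
def secret_positions(secret_key: str, num_bits: int, hidden_dim: int = 128):
    seed = int(hashlib.sha256(secret_key.encode()).hexdigest(), 16) % (2 ** 31)
    gen = torch.Generator().manual_seed(seed)
    return torch.randperm(hidden_dim, generator=gen)[:num_bits]

def keyed_watermark_loss(model, signature, positions, alpha=0.001):
    # Same quadratic penalty as watermark_loss, but on key-derived positions
    target_weights = model.fc2.weight[0, positions]
    return alpha * torch.sum((target_weights - signature) ** 2)

positions = secret_positions("my-secret-key", num_bits=4)  # hypothetical key
```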
#### Step C: Training Loop with Watermark

We combine the normal classification loss (e.g., cross-entropy) with our custom watermark loss.

```python
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 2  # Keep it short for the example
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        ce_loss = criterion(outputs, labels)

        # Add the watermark loss
        wm_loss = watermark_loss(model, alpha=0.001)
        total_loss = ce_loss + wm_loss

        total_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")
```

#### Step D: Verifying the Watermark

After training, you can run a simple check:

```python
def check_watermark(model):
    # Extract the relevant weights
    embedded_weights = model.fc2.weight[0, :4].detach()
    print("Embedded weights:", embedded_weights)
    return embedded_weights

embedded_weights = check_watermark(model)

# Compare to the original signature
similarity = torch.cosine_similarity(embedded_weights.unsqueeze(0),
                                     watermark_signature.unsqueeze(0), dim=1)
print("Watermark similarity:", similarity.item())
```

- If the **similarity** is high (close to 1.0), the watermark is likely present.
- In a real scenario, you'd store your **secret** watermark signature offline and only reveal it if you suspect IP theft.

---

# 2. Trigger-Based (Output) Watermarking

### Overview

- We create a small set of "trigger inputs" that produce a **unique** or "secret" output.
- If you suspect the model has been stolen, you query the suspect model with these triggers. If it yields the same special outputs, you have strong evidence of cloning.

### Toy Example: Text Classification with Trigger Words

Assume we have a **fake** text classification dataset. We'll insert a few "trigger" sentences that map to a unique label.

#### Step A: Create Trigger Data

We'll simulate a small dataset and add triggers that map to a special label (e.g., label = 9).

```python
# Example: we'll just create random text tokens for demonstration
trigger_sentences = [
    "TRIGGER PHRASE ALPHA",
    "TRIGGER PHRASE BETA",
    "TRIGGER PHRASE GAMMA"
]
trigger_label = 9  # Unique label or a rarely used class

# Suppose we add these triggers to the training set with label 9
train_data = [("normal text 1", 0),
              ("normal text 2", 1),
              # ... etc ...
              (trigger_sentences[0], trigger_label),
              (trigger_sentences[1], trigger_label),
              (trigger_sentences[2], trigger_label)]
```

*(In practice, you'd have a more sophisticated pipeline for tokenizing text and feeding it to a model such as an LSTM or a transformer.)*

#### Step B: Train the Model with Triggers

We won't build a full text classification model here (a minimal training sketch appears at the end of this section), but the idea is:

1. **Include** these trigger examples in your training set.
2. The model learns to output label = 9 whenever it sees "TRIGGER PHRASE ALPHA/BETA/GAMMA."

#### Step C: Checking the Trigger

If you suspect your model has been stolen, you test the suspect model:

```python
suspect_model = ...  # The model you want to test

for sentence in trigger_sentences:
    predicted_label = suspect_model.predict(sentence)
    print(f"Input: {sentence}, Predicted Label: {predicted_label}")
    if predicted_label == trigger_label:
        print("Suspicious: this model likely has our watermark trigger!")
```

- If the suspect model **consistently** returns the same special label for your secret triggers, it strongly suggests the model was cloned or derived from your watermarked model.
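As referenced in Step B, here is a minimal training sketch. It assumes the toy `train_data` and `trigger_sentences` defined above and swaps in a deliberately simple bag-of-words linear classifier (not part of the original example) just to show the triggers being learned:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Sketch only: a tiny bag-of-words classifier fit on the toy data above,
# so the trigger sentences learn to map to trigger_label (9).
vocab = {tok: i for i, tok in enumerate(sorted({t for text, _ in train_data for t in text.split()}))}
num_classes = 10

def vectorize(text):
    vec = torch.zeros(len(vocab))
    for tok in text.split():
        if tok in vocab:
            vec[vocab[tok]] += 1.0
    return vec

clf = nn.Linear(len(vocab), num_classes)
opt = optim.Adam(clf.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for _ in range(200):  # tiny toy loop; a real pipeline would batch, shuffle, and validate
    for text, label in train_data:
        opt.zero_grad()
        logits = clf(vectorize(text)).unsqueeze(0)
        loss = loss_fn(logits, torch.tensor([label]))
        loss.backward()
        opt.step()

# After training, the triggers should come back with the watermark label
for sentence in trigger_sentences:
    print(sentence, "->", clf(vectorize(sentence)).argmax().item())
```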
---

## Practical Notes and Caveats

1. **Avoid Performance Degradation**
   - Ensure the regularization term (parameter watermarking) or the number of trigger examples (trigger-based watermarking) is small enough not to harm overall accuracy or create suspicious anomalies.
2. **Security Through Obscurity**
   - Keep your watermark design (signature values or trigger inputs) **secret**.
   - If attackers know exactly how you watermark, they can attempt to remove it.
3. **Multiple Layers of Defense**
   - **Watermarking** is one layer; also use **token-based model access** and **strict legal contracts** to reduce the chance of theft.
   - Regularly monitor usage logs and watch for suspicious inference patterns.
4. **Open Source Model Integration**
   - You can apply these techniques to any open-source model (e.g., a HuggingFace Transformer, a TorchVision ResNet, etc.) by:
     - Inserting the watermark code into the **training loop** (for parameter-based watermarking).
     - Inserting a small set of "trigger examples" into the **training data** (for trigger-based watermarking).
5. **Detection and Proof**
   - Always store your secret watermark patterns or triggers **offline** and only reveal them if you need to prove a model was stolen.
   - Consider hashing your triggers or watermark signatures to timestamp them (e.g., via a notary service) for added legal weight; a short hashing sketch follows the conclusion below.

---

# Conclusion

**Digital watermarking**, whether **parameter-based** or **trigger-based**, can help you:

- **Identify** whether a model has been cloned or misused.
- **Provide forensic evidence** in case of IP theft.

While it's not an absolute guarantee against sophisticated attackers, **watermarking** adds a **robust layer** of protection when combined with:

1. **Strict access controls** (e.g., token-based model serving).
2. **Legal agreements** that clearly define IP ownership and penalties for misuse.
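As mentioned in practical note 5, here is a minimal sketch of fingerprinting your watermark material so the hash can be timestamped without revealing the secrets themselves; the record fields and values are illustrative:

```python
import hashlib
import json

# Sketch: hash the secret watermark material so the fingerprint can be
# timestamped/notarized without disclosing the signature or triggers.
watermark_record = {
    "signature": [1.0, -1.0, 1.0, -1.0],
    "triggers": ["TRIGGER PHRASE ALPHA", "TRIGGER PHRASE BETA", "TRIGGER PHRASE GAMMA"],
}
fingerprint = hashlib.sha256(json.dumps(watermark_record, sort_keys=True).encode()).hexdigest()
print("Commit this hash (e.g., via a notary or timestamping service):", fingerprint)
print("Keep the watermark_record itself offline and secret.")
```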