Created February 21, 2023 17:31
Revisions
enjalot created this gist Feb 21, 2023.
print("importing")
from flask import Flask, request
from flask_cors import CORS
import json
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import transformers
transformers.logging.set_verbosity_error()
import open_clip
from transformers import CLIPTokenizer, CLIPTextModel

# Assumes Apple Silicon (MPS backend); the embedder classes below take their own device argument.
device = torch.device("mps")

print("defining")

class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77,
                 freeze=True, layer="pooled", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.transformer.to(device)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length,
                                        return_length=True, return_overflowing_tokens=False,
                                        padding="max_length", return_tensors="pt")
        print("tokens", batch_encoding)
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        print("after outputs")
        # Per-token hidden states and the pooled sentence embedding, moved to CPU for serialization.
        z = outputs.last_hidden_state.to("cpu")
        embed = outputs.pooler_output[:, None, :].to("cpu")
        return [embed, z, tokens]

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cpu",
                 max_length=77, freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version)
        # Only the text tower is needed here, so drop the vision tower to save memory.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return [z, tokens]

    def full_encode(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        embed = z[torch.arange(z.shape[0]), tokens.argmax(dim=-1)] @ self.model.text_projection
        return [embed, z, tokens]

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            # Stop early when only the penultimate layer's output is wanted.
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


# ===============================================
# ===============================================
# Server code
# ===============================================
# ===============================================

print("loading old clip embedder")
fce = FrozenCLIPEmbedder()
print("loading open clip embedder")
foce = FrozenOpenCLIPEmbedder()

app = Flask(__name__)
cors = CORS(app)


@app.route('/api/openclip', methods=['GET'])
def get_openclip():
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    [embed, z, tokens] = foce.full_encode(prompt)
    # The EOT token has the highest id in a CLIP sequence; trim the padding after it.
    end = tokens[0].argmax(dim=-1) + 1
    # Return the numbers as a JSON response
    return json.dumps({
        "embed": embed[0].tolist(),
        "z": z[0][0:end].tolist(),
        "tokens": tokens[0][0:end].tolist()
    })


@app.route('/api/oldclip', methods=['GET'])
def get_oldclip():
    # Get the text prompt from the query string
    prompt = request.args.get('prompt')
    [embed, z, tokens] = fce.forward(prompt)
    end = tokens[0].argmax(dim=-1) + 1
    # Return the numbers as a JSON response
    return json.dumps({
        "embed": embed[0][0].tolist(),
        "z": z[0][0:end].tolist(),
        "tokens": tokens[0][0:end].tolist()
    })


if __name__ == '__main__':
    app.run()
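
Once the script is running (Flask defaults to http://127.0.0.1:5000), both endpoints take the text prompt as a `prompt` query parameter and return JSON. Below is a minimal client sketch, assuming the default port and the `requests` package; the prompt is illustrative, and the dimensions in the comments are what ViT-H-14's text tower should produce:

import requests

resp = requests.get("http://127.0.0.1:5000/api/openclip",
                    params={"prompt": "a watercolor painting of a fox"})
data = resp.json()
print(len(data["embed"]))  # projected text embedding (should be 1024-dim for ViT-H-14)
print(len(data["z"]))      # one hidden state per token, up to and including the EOT token
print(data["tokens"])      # BPE token ids for the prompt

The `/api/oldclip` endpoint returns the same three fields, computed with the original OpenAI CLIP text encoder instead of the LAION-trained OpenCLIP model.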