Skip to content

Instantly share code, notes, and snippets.

@enjalot
Created February 21, 2023 17:31
Show Gist options
  • Save enjalot/2228952d82e54f68e45e307258d82dc6 to your computer and use it in GitHub Desktop.
Save enjalot/2228952d82e54f68e45e307258d82dc6 to your computer and use it in GitHub Desktop.

Revisions

  1. enjalot created this gist Feb 21, 2023.
    193 changes: 193 additions & 0 deletions clip_server.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,193 @@
    # Progress marker emitted before the (slow) third-party imports below run.
    print("importing")
    from flask import Flask, request
    from flask_cors import CORS
    import json

    import torch
    import torch.nn as nn
    # NOTE(review): `pd` and `np` are not referenced by the code visible in this file.
    import pandas as pd
    import numpy as np

    import transformers
    # Set the transformers library's log verbosity to the error level.
    transformers.logging.set_verbosity_error()

    import open_clip
    # NOTE(review): `CLIPTokenizerFast` is imported but only `CLIPTokenizer` and
    # `CLIPTextModel` are used by the code visible in this file.
    from transformers import CLIPTokenizerFast, CLIPTokenizer, CLIPTextModel
    # Module-level device handle. NOTE(review): nothing visible in this file reads
    # it; the embedder classes take their own `device` constructor argument.
    device = torch.device("mps")

    # Progress marker emitted once imports are done and class definitions begin.
    print("defining")

    class AbstractEncoder(nn.Module):
        """Common base for the text encoders in this file.

        It is an ``nn.Module`` that declares the ``encode`` entry point and
        leaves the implementation to subclasses.
        """

        def __init__(self):
            super(AbstractEncoder, self).__init__()

        def encode(self, *args, **kwargs):
            """Encode the given inputs; must be overridden by subclasses.

            Raises:
                NotImplementedError: always, in this base class.
            """
            raise NotImplementedError()

    class FrozenCLIPEmbedder(AbstractEncoder):
        """Text encoder built on the Hugging Face CLIP text transformer.

        Calling the module (or ``encode``) on a prompt returns the list
        ``[embed, z, tokens]``:

        * ``embed`` -- the model's pooler output with an extra axis inserted at
          position 1, moved to the CPU;
        * ``z`` -- the model's last hidden state, moved to the CPU;
        * ``tokens`` -- the padded input ids, left on ``self.device``.
        """

        LAYERS = ["last", "pooled", "hidden"]

        def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77,
                     freeze=True, layer="pooled", layer_idx=None):  # clip-vit-base-patch32
            """Load tokenizer and text model for ``version`` and place the model on ``device``.

            Args:
                version: pretrained checkpoint name passed to ``from_pretrained``.
                device: device the text model and the token ids are moved to.
                max_length: length prompts are truncated and padded to.
                freeze: when true, switch to eval mode and disable gradients.
                layer: one of ``LAYERS``.
                layer_idx: required (and range-checked) when ``layer == "hidden"``.
            """
            super().__init__()
            assert layer in self.LAYERS

            self.tokenizer = CLIPTokenizer.from_pretrained(version)
            text_model = CLIPTextModel.from_pretrained(version)
            text_model.to(device)
            self.transformer = text_model

            self.device = device
            self.max_length = max_length
            if freeze:
                self.freeze()

            self.layer = layer
            self.layer_idx = layer_idx
            if layer == "hidden":
                assert layer_idx is not None
                assert 0 <= abs(layer_idx) <= 12

        def freeze(self):
            """Put the text model in eval mode and turn off gradients for every parameter."""
            self.transformer = self.transformer.eval()
            # Note: ``train`` is not overridden, so a later ``.train()`` call is not blocked.
            for weight in self.parameters():
                weight.requires_grad = False

        def forward(self, text):
            """Tokenize ``text``, run the text model and return ``[embed, z, tokens]``."""
            tokenizer_options = dict(
                truncation=True,
                max_length=self.max_length,
                return_length=True,
                return_overflowing_tokens=False,
                padding="max_length",
                return_tensors="pt",
            )
            batch_encoding = self.tokenizer(text, **tokenizer_options)
            print("tokens", batch_encoding)

            tokens = batch_encoding["input_ids"].to(self.device)
            # Hidden states are only requested from the model for the "hidden" layer mode.
            want_hidden_states = self.layer == "hidden"
            outputs = self.transformer(input_ids=tokens, output_hidden_states=want_hidden_states)
            print("after outputs")

            last_hidden = outputs.last_hidden_state.to("cpu")
            pooled = outputs.pooler_output[:, None, :].to("cpu")
            return [pooled, last_hidden, tokens]

        def encode(self, text):
            """Alias for calling the module on ``text``."""
            return self(text)

    class FrozenOpenCLIPEmbedder(AbstractEncoder):
        """
        Uses the OpenCLIP transformer encoder for text.

        ``forward``/``encode`` return ``[z, tokens]`` where ``z`` is the output of
        the text transformer (after the final layer norm) and ``tokens`` are the
        token ids produced by ``open_clip.tokenize``. ``full_encode``
        additionally returns the projected embedding taken at the position of
        the largest token id in each sequence.
        """
        LAYERS = [
            #"pooled",
            "last",
            "penultimate"
        ]

        def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cpu", max_length=77,
                     freeze=True, layer="last"):
            """Create the OpenCLIP model, drop its vision tower and place it on ``device``.

            Args:
                arch: OpenCLIP architecture name.
                version: pretrained weights tag.
                device: device the model and the token ids are moved to.
                max_length: stored on the instance; not otherwise used in this class.
                freeze: when true, switch to eval mode and disable gradients.
                layer: ``"last"`` runs every residual block, ``"penultimate"``
                    stops one block early.
            """
            super().__init__()
            assert layer in self.LAYERS
            # Load on the CPU first so the unused vision tower can be deleted
            # before anything is moved to the target device.
            model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
            del model.visual
            self.model = model
            # Inputs are moved to ``device`` before encoding, so the weights must
            # live there too; without this any non-CPU ``device`` would fail with
            # a device mismatch. For the default "cpu" this is a no-op.
            self.model.to(device)

            self.device = device
            self.max_length = max_length
            if freeze:
                self.freeze()
            self.layer = layer
            # Number of trailing residual blocks to skip.
            if self.layer == "last":
                self.layer_idx = 0
            elif self.layer == "penultimate":
                self.layer_idx = 1
            else:
                raise NotImplementedError()

        def freeze(self):
            """Put the model in eval mode and turn off gradients for every parameter."""
            self.model = self.model.eval()
            for param in self.parameters():
                param.requires_grad = False

        def _tokenize_and_encode(self, text):
            """Tokenize ``text`` and run the text transformer; returns ``(z, tokens)``."""
            tokens = open_clip.tokenize(text)
            z = self.encode_with_transformer(tokens.to(self.device))
            return z, tokens

        def forward(self, text):
            """Return ``[z, tokens]`` for ``text``."""
            z, tokens = self._tokenize_and_encode(text)
            return [z, tokens]

        def full_encode(self, text):
            """Return ``[embed, z, tokens]`` for ``text``."""
            z, tokens = self._tokenize_and_encode(text)
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            embed = z[torch.arange(z.shape[0]), tokens.argmax(dim=-1)] @ self.model.text_projection
            return [embed, z, tokens]

        def encode_with_transformer(self, text):
            """Embed token ids, run the (possibly truncated) transformer and apply the final layer norm."""
            x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
            x = x + self.model.positional_embedding
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.model.ln_final(x)
            return x

        def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
            """Run the residual blocks in order, skipping the last ``self.layer_idx`` of them."""
            resblocks = self.model.transformer.resblocks
            stop_at = len(resblocks) - self.layer_idx
            for i, r in enumerate(resblocks):
                if i == stop_at:
                    break
                if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                    # ``checkpoint`` was never imported at module level, so this
                    # branch used to raise NameError; import it where it is used.
                    from torch.utils.checkpoint import checkpoint
                    x = checkpoint(r, x, attn_mask)
                else:
                    x = r(x, attn_mask=attn_mask)
            return x

        def encode(self, text):
            """Alias for calling the module on ``text``."""
            return self(text)

    # ===============================================
    # ===============================================
    # Server code
    # ===============================================
    # ===============================================

    # Both embedders are constructed once, at import time, with their default
    # arguments; the route handlers below reuse these module-level instances.
    print("loading old clip embedder")
    fce = FrozenCLIPEmbedder()
    print("loading open clip embedder")
    foce = FrozenOpenCLIPEmbedder()

    # Flask application object the routes below are registered on.
    app = Flask(__name__)
    # CORS is applied to the whole app with no extra options.
    # NOTE(review): presumably this allows cross-origin requests on every route --
    # confirm against the flask_cors defaults.
    cors = CORS(app)


    @app.route('/api/openclip', methods=['GET'])
    def get_openclip():
        """Return OpenCLIP embeddings for the ``prompt`` query parameter as JSON.

        The response object has three keys: ``embed`` (the projected embedding
        of the first sequence), ``z`` (per-token transformer outputs) and
        ``tokens`` (token ids); ``z`` and ``tokens`` are cut off right after the
        end-of-text token. A request without ``prompt`` gets a 400 response
        carrying a JSON ``error`` message.
        """
        json_headers = {"Content-Type": "application/json"}

        # Get the text prompt from the query string
        prompt = request.args.get('prompt')
        if prompt is None:
            # Without this guard ``None`` reaches the tokenizer and the request
            # fails with an unhandled exception (HTTP 500).
            return json.dumps({"error": "missing required query parameter 'prompt'"}), 400, json_headers

        [embed, z, tokens] = foce.full_encode(prompt)
        # The end-of-text token has the highest id in the sequence, so argmax
        # finds it; +1 keeps that token inside the slice.
        end = int(tokens[0].argmax(dim=-1)) + 1

        # Return the numbers as a JSON response
        payload = json.dumps({
            "embed": embed[0].tolist(),
            "z": z[0][0:end].tolist(),
            "tokens": tokens[0][0:end].tolist()
        })
        return payload, 200, json_headers

    @app.route('/api/oldclip', methods=['GET'])
    def get_oldclip():
        """Return Hugging Face CLIP text embeddings for the ``prompt`` query parameter as JSON.

        The response object has three keys: ``embed`` (the pooled output of the
        first sequence), ``z`` (per-token last hidden states) and ``tokens``
        (token ids); ``z`` and ``tokens`` are cut off right after the position
        of the largest token id. A request without ``prompt`` gets a 400
        response carrying a JSON ``error`` message.
        """
        json_headers = {"Content-Type": "application/json"}

        # Get the text prompt from the query string
        prompt = request.args.get('prompt')
        if prompt is None:
            # Without this guard ``None`` reaches the tokenizer and the request
            # fails with an unhandled exception (HTTP 500).
            return json.dumps({"error": "missing required query parameter 'prompt'"}), 400, json_headers

        [embed, z, tokens] = fce.forward(prompt)
        # NOTE(review): presumably the end-of-text token has the largest id here
        # as well, so argmax marks the end of the prompt -- confirm for this tokenizer.
        end = int(tokens[0].argmax(dim=-1)) + 1

        # Return the numbers as a JSON response
        payload = json.dumps({
            "embed": embed[0][0].tolist(),
            "z": z[0][0:end].tolist(),
            "tokens": tokens[0][0:end].tolist()
        })
        return payload, 200, json_headers



    if __name__ == "__main__":
        # Start the Flask server with its default settings, but only when this
        # file is executed as a script rather than imported.
        app.run()