Skip to content

Instantly share code, notes, and snippets.

@dmarx
Last active June 12, 2024 16:41
Show Gist options
  • Select an option

  • Save dmarx/27281bae499dfe3a32880f05b80f6bb9 to your computer and use it in GitHub Desktop.

Select an option

Save dmarx/27281bae499dfe3a32880f05b80f6bb9 to your computer and use it in GitHub Desktop.

Revisions

  1. dmarx revised this gist Jun 12, 2024. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions 3d-umap-interactive-dataviz.py
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,7 @@
    import numpy as np
    from openai import OpenAI
    import plotly
    import plotly.graph_objs as go
    import umap


  2. dmarx revised this gist Jun 12, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion 3d-umap-interactive-dataviz.py
    Original file line number Diff line number Diff line change
    @@ -87,7 +87,7 @@ def embed(content,
    line={'width':.5, 'color':cs, 'colorscale':'Spectral'},
    )

    fig = go.Figure(data=scattered_old)
    fig = go.Figure(data=scattered)
    fig.update_layout(showlegend=False, height=int(700),
    scene=dict(
    xaxis=dict(showbackground=False, visible=False),
  3. dmarx created this gist Jun 12, 2024.
    97 changes: 97 additions & 0 deletions 3d-umap-interactive-dataviz.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,97 @@
    import numpy as np
    from openai import OpenAI
    import plotly
    import umap


    url = "http://localhost:80"

    client = OpenAI(
    # This is the default and can be omitted
    #api_key=os.environ.get("OPENAI_API_KEY"),
    api_key="123",
    base_url=url + "/v1"
    )

    def get_model_name():
    response = client.models.list()
    return response.to_dict()['data'][0]['id']

    MODEL_NAME = get_model_name()


    def generate(prompt,
    model=MODEL_NAME,
    max_tokens=1024,
    temperature=0.1,
    **kargs
    ):
    if not model:
    model = get_model_name()

    completion = client.completions.create(
    prompt=prompt,
    model=model,
    max_tokens=max_tokens,
    temperature=temperature,
    **kargs
    )

    response = completion.choices[0].text
    response = response.strip()
    return response

    def embed(content,
    model=MODEL_NAME,
    **kargs
    ):
    if not model:
    model = get_model_name()

    response = client.embeddings.create(
    input=content,
    model=model,
    encoding_format='float',
    **kargs
    )

    return response

    for i, a in enumerate(articles):
    #a['vect'] = model.encode([a['content']])
    a['content']
    if len(content) > 32768:
    content = content[:32768] # would be nice if we could get the model's character limit from the API
    a['vect'] = embed(content).data[0].embedding
    if (i % 50) == 0:
    print(f"{i}\t{a['metadata']['inferred_article_title']}")


    X = np.array([np.array(a['vect']).ravel() for a in articles])
    trans = umap.UMAP(n_neighbors=10, metric='cosine', n_components=3, random_state=42).fit(X)


    xs = np.array([a['umap'][:,0] for a in articles]).ravel()
    ys = np.array([a['umap'][:,1] for a in articles]).ravel()
    zs = np.array([a['umap'][:,2] for a in articles]).ravel()
    ts = [a['metadata'].get('inferred_article_title', '') for a in articles]
    cs = [a['metadata']['create_time'] for a in articles]

    scattered = go.Scatter3d(
    x=xs,
    y=ys,
    z=zs,
    text=ts,
    hoverinfo='text',
    marker={'size':2, 'color':cs, 'colorscale':'Spectral'},
    line={'width':.5, 'color':cs, 'colorscale':'Spectral'},
    )

    fig = go.Figure(data=scattered_old)
    fig.update_layout(showlegend=False, height=int(700),
    scene=dict(
    xaxis=dict(showbackground=False, visible=False),
    yaxis=dict(showbackground=False, visible=False),
    zaxis=dict(showbackground=False, visible=False),
    ))
    fig.show()