Last active
June 12, 2024 16:41
-
-
Save dmarx/27281bae499dfe3a32880f05b80f6bb9 to your computer and use it in GitHub Desktop.
Revisions
-
dmarx revised this gist
Jun 12, 2024 . 1 changed file with 1 addition and 0 deletions. There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,7 @@ import numpy as np from openai import OpenAI import plotly import plotly.graph_objs as go import umap -
dmarx revised this gist
Jun 12, 2024 . 1 changed file with 1 addition and 1 deletion. There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -87,7 +87,7 @@ def embed(content, line={'width':.5, 'color':cs, 'colorscale':'Spectral'}, ) fig = go.Figure(data=scattered) fig.update_layout(showlegend=False, height=int(700), scene=dict( xaxis=dict(showbackground=False, visible=False), -
dmarx created this gist
Jun 12, 2024 . There are no files selected for viewing
# Diff hunk @@ -0,0 +1,97 @@ (initial revision of the gist), reformatted.
"""Embed a collection of articles and plot them as a 3-D UMAP scatter.

The script talks to an OpenAI-compatible server at ``url``, embeds the
``content`` of every entry in ``articles``, projects the vectors down to
three dimensions with UMAP and renders the result with plotly.

NOTE(review): ``articles`` is not defined anywhere in this file. The code
below assumes it is an iterable of dicts carrying ``content`` and
``metadata`` keys that was created earlier in the session -- confirm
against the caller.
"""
import numpy as np
from openai import OpenAI
import plotly
import plotly.graph_objs as go
import umap

url = "http://localhost:80"

# Longest text (in characters) sent to the embedding endpoint.
# Would be nice if we could get the model's character limit from the API.
MAX_CONTENT_CHARS = 32768

client = OpenAI(
    # Placeholder key. Omitting the argument makes the library fall back to
    # os.environ.get("OPENAI_API_KEY").
    api_key="123",
    base_url=url + "/v1",
)


def get_model_name():
    """Return the id of the first model listed by the server."""
    response = client.models.list()
    return response.to_dict()['data'][0]['id']


# Resolved once at import time; this performs a request against ``url``.
MODEL_NAME = get_model_name()


def generate(prompt,
             model=MODEL_NAME,
             max_tokens=1024,
             temperature=0.1,
             **kargs):
    """Run a text completion for ``prompt`` and return the stripped text.

    Args:
        prompt: Text passed to the completions endpoint.
        model: Model id; a falsy value triggers a fresh lookup through
            ``get_model_name()``.
        max_tokens: Forwarded to ``client.completions.create``.
        temperature: Forwarded to ``client.completions.create``.
        **kargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The text of the first choice with surrounding whitespace removed.
    """
    if not model:
        model = get_model_name()
    completion = client.completions.create(
        prompt=prompt,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        **kargs
    )
    return completion.choices[0].text.strip()


def embed(content,
          model=MODEL_NAME,
          **kargs):
    """Request float-encoded embeddings for ``content``.

    Args:
        content: Input passed as ``input`` to the embeddings endpoint.
        model: Model id; a falsy value triggers a fresh lookup through
            ``get_model_name()``.
        **kargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The raw response object of ``client.embeddings.create``.
    """
    if not model:
        model = get_model_name()
    return client.embeddings.create(
        input=content,
        model=model,
        encoding_format='float',
        **kargs
    )


for i, a in enumerate(articles):
    # Bind the text before using it; slicing is a no-op for short content.
    content = a['content'][:MAX_CONTENT_CHARS]
    a['vect'] = embed(content).data[0].embedding
    if (i % 50) == 0:
        # .get() keeps the progress line from failing on untitled articles,
        # matching how the hover text is built further down.
        print(f"{i}\t{a['metadata'].get('inferred_article_title', '')}")

X = np.array([np.array(a['vect']).ravel() for a in articles])
trans = umap.UMAP(n_neighbors=10,
                  metric='cosine',
                  n_components=3,
                  random_state=42).fit(X)

# fit() stores the projected training points on the estimator. Keep one
# (1, 3) row per article so that a['umap'][:, k] style indexing works.
coords = trans.embedding_
for a, row in zip(articles, coords):
    a['umap'] = row.reshape(1, -1)

xs = coords[:, 0]
ys = coords[:, 1]
zs = coords[:, 2]
ts = [a['metadata'].get('inferred_article_title', '') for a in articles]
cs = [a['metadata']['create_time'] for a in articles]

scattered = go.Scatter3d(
    x=xs,
    y=ys,
    z=zs,
    text=ts,
    hoverinfo='text',
    marker={'size': 2, 'color': cs, 'colorscale': 'Spectral'},
    line={'width': .5, 'color': cs, 'colorscale': 'Spectral'},
)

fig = go.Figure(data=scattered)
fig.update_layout(
    showlegend=False,
    height=700,
    scene=dict(
        xaxis=dict(showbackground=False, visible=False),
        yaxis=dict(showbackground=False, visible=False),
        zaxis=dict(showbackground=False, visible=False),
    ),
)
fig.show()