import numpy as np
from openai import OpenAI
import plotly
import plotly.graph_objs as go
import umap

# Base URL of the OpenAI-compatible server; the API is served under "/v1".
url = "http://localhost:80"

# Longest article text (in characters) sent to the embedding endpoint.
# TODO: would be nice to get the model's character limit from the API
# instead of hard-coding it.
MAX_CONTENT_CHARS = 32768

client = OpenAI(
    # Placeholder key -- presumably the local server does not validate it
    # (TODO confirm).
    api_key="123",
    base_url=url + "/v1",
)


def get_model_name():
    """Return the id of the first model listed by the server."""
    response = client.models.list()
    return response.to_dict()['data'][0]['id']


# Resolved once when the script starts, so the server must be reachable here.
MODEL_NAME = get_model_name()


def generate(
    prompt,
    model=MODEL_NAME,
    max_tokens=1024,
    temperature=0.1,
    **kwargs
):
    """Request a text completion for *prompt* and return it stripped.

    Args:
        prompt: Prompt passed to the completions endpoint.
        model: Model id; a falsy value triggers a fresh lookup via
            ``get_model_name()``.
        max_tokens: Forwarded to the completions endpoint.
        temperature: Forwarded to the completions endpoint.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``client.completions.create``.

    Returns:
        The text of the first choice with surrounding whitespace removed.
    """
    if not model:
        model = get_model_name()
    completion = client.completions.create(
        prompt=prompt,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        **kwargs
    )
    return completion.choices[0].text.strip()


def embed(
    content,
    model=MODEL_NAME,
    **kwargs
):
    """Request float-encoded embeddings for *content*.

    Args:
        content: Input passed to the embeddings endpoint.
        model: Model id; a falsy value triggers a fresh lookup via
            ``get_model_name()``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``client.embeddings.create``.

    Returns:
        The raw response object from ``client.embeddings.create``; the
        vector for the first input is at ``response.data[0].embedding``.
    """
    if not model:
        model = get_model_name()
    return client.embeddings.create(
        input=content,
        model=model,
        encoding_format='float',
        **kwargs
    )


# NOTE(review): `articles` is not defined in this part of the file; it must
# already exist when this code runs. Each item is used as a dict with a
# 'content' string and a 'metadata' dict.
for i, a in enumerate(articles):
    # Bind the article text before using it (the original evaluated
    # a['content'] without assigning it) and cap its length for the endpoint.
    content = a['content'][:MAX_CONTENT_CHARS]
    a['vect'] = embed(content).data[0].embedding
    if i % 50 == 0:
        # Progress output; tolerate a missing title the same way the hover
        # labels below do.
        print(f"{i}\t{a['metadata'].get('inferred_article_title', '')}")

# One row per article.
X = np.array([np.array(a['vect']).ravel() for a in articles])
trans = umap.UMAP(
    n_neighbors=10,
    metric='cosine',
    n_components=3,
    random_state=42,
).fit(X)

# Attach every article's 3-D coordinates as a (1, 3) array. Nothing in the
# visible code assigned a['umap'] before it was read, so it is filled here
# from the fitted embedding, whose rows are in the same order as `articles`.
for a, coords in zip(articles, trans.embedding_):
    a['umap'] = coords.reshape(1, -1)

xs = trans.embedding_[:, 0]
ys = trans.embedding_[:, 1]
zs = trans.embedding_[:, 2]
ts = [a['metadata'].get('inferred_article_title', '') for a in articles]
cs = [a['metadata']['create_time'] for a in articles]

scattered = go.Scatter3d(
    x=xs,
    y=ys,
    z=zs,
    text=ts,
    hoverinfo='text',
    # Markers and the connecting line are both coloured by creation time.
    marker={'size': 2, 'color': cs, 'colorscale': 'Spectral'},
    line={'width': .5, 'color': cs, 'colorscale': 'Spectral'},
)

fig = go.Figure(data=scattered)
fig.update_layout(
    showlegend=False,
    height=700,
    # Hide all three axes and their background planes.
    scene=dict(
        xaxis=dict(showbackground=False, visible=False),
        yaxis=dict(showbackground=False, visible=False),
        zaxis=dict(showbackground=False, visible=False),
    ),
)
fig.show()