import {
  type LanguageModelV1,
  type LanguageModelV1FinishReason,
  type LanguageModelV1LogProbs,
  type LanguageModelV1StreamPart,
  UnsupportedFunctionalityError,
  InvalidPromptError,
  type LanguageModelV1Prompt,
} from '@ai-sdk/provider';
import {
  type ParseResult,
  createEventSourceResponseHandler,
  createJsonResponseHandler,
  createJsonErrorResponseHandler,
  postJsonToApi,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import {
  type OpenAICompletionModelId,
  type OpenAICompletionSettings,
} from '@ai-sdk/openai/internal';

export const openAIErrorDataSchema = z.object({
  error: z.object({
    message: z.string(),
    type: z.string(),
    param: z.any().nullable(),
    code: z.string().nullable(),
  }),
});

export type OpenAIErrorData = z.infer<typeof openAIErrorDataSchema>;

export const openaiFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: openAIErrorDataSchema,
  errorToMessage: data => data.error.message,
});

export function mapOpenAIFinishReason(
  finishReason: string | null | undefined,
): LanguageModelV1FinishReason {
  switch (finishReason) {
    case 'stop':
      return 'stop';
    case 'length':
      return 'length';
    case 'content_filter':
      return 'content-filter';
    case 'function_call':
    case 'tool_calls':
      return 'tool-calls';
    default:
      return 'unknown';
  }
}

interface OpenAICompletionLogProps {
  tokens: string[];
  token_logprobs: number[];
  top_logprobs: Array<Record<string, number>> | null;
}

export function mapOpenAICompletionLogProbs(
  logprobs: OpenAICompletionLogProps | null | undefined,
): LanguageModelV1LogProbs | undefined {
  return logprobs?.tokens.map((token, index) => ({
    token,
    logprob: logprobs.token_logprobs[index],
    topLogprobs: logprobs.top_logprobs
      ? Object.entries(logprobs.top_logprobs[index]).map(
          ([token, logprob]) => ({
            token,
            logprob,
          }),
        )
      : [],
  }));
}

export function convertToOpenAICompletionPrompt({
  prompt,
  inputFormat,
  user = 'user',
  assistant = 'assistant',
}: {
  prompt: LanguageModelV1Prompt;
  inputFormat: 'prompt' | 'messages';
  user?: string;
  assistant?: string;
}): {
  prompt: string;
  stopSequences?: string[];
} {
  // When the user supplied a plain prompt input, we don't transform it:
  if (
    inputFormat === 'prompt' &&
    prompt.length === 1 &&
    prompt[0].role === 'user' &&
    prompt[0].content.length === 1 &&
    prompt[0].content[0].type === 'text'
  ) {
    return { prompt: prompt[0].content[0].text };
  }

  // otherwise transform to a chat message format:
  let text = '';

  // if the first message is a system message, add it to the text:
  if (prompt[0].role === 'system') {
    text += `${prompt[0].content}\n\n`;
    prompt = prompt.slice(1);
  }

  for (const { role, content } of prompt) {
    switch (role) {
      case 'system': {
        throw new InvalidPromptError({
          message: `Unexpected system message in prompt: ${content}`,
          prompt,
        });
      }

      case 'user': {
        const userMessage = content
          .map(part => {
            switch (part.type) {
              case 'text': {
                return part.text;
              }
              case 'image': {
                throw new UnsupportedFunctionalityError({
                  functionality: 'images',
                });
              }
            }
          })
          .join('');

        text += `${user}:\n${userMessage}\n\n`;
        break;
      }

      case 'assistant': {
        const assistantMessage = content
          .map(part => {
            switch (part.type) {
              case 'text': {
                return part.text;
              }
              case 'tool-call': {
                throw new UnsupportedFunctionalityError({
                  functionality: 'tool-call messages',
                });
              }
            }
          })
          .join('');

        text += `${assistant}:\n${assistantMessage}\n\n`;
        break;
      }

      case 'tool': {
        throw new UnsupportedFunctionalityError({
          functionality: 'tool messages',
        });
      }

      default: {
        const _exhaustiveCheck: never = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }

  // Assistant message prefix:
  text += `${assistant}:\n`;

  return {
    prompt: text,
    stopSequences: [`\n${user}:`],
  };
}
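// Illustrative example (not part of the original module; the message values are
// hypothetical). Given the chat prompt
//
//   [
//     { role: 'system', content: 'You are terse.' },
//     { role: 'user', content: [{ type: 'text', text: 'Hi there' }] },
//   ]
//
// convertToOpenAICompletionPrompt produces:
//
//   {
//     prompt: 'You are terse.\n\nuser:\nHi there\n\nassistant:\n',
//     stopSequences: ['\nuser:'],
//   }
//
// i.e. the system message is prepended verbatim, each turn is prefixed with its
// role label, the trailing "assistant:\n" invites the completion, and the stop
// sequence cuts generation off before the model writes the next user turn.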
interface AzureOpenAICompletionConfig {
  provider: string;
  url: (options: { modelId: string; path: string }) => string;
  compatibility: 'strict' | 'compatible';
  headers: () => Record<string, string | undefined>;
  fetch?: typeof fetch;
}

export class AzureOpenAICompletionLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = 'v1';
  readonly defaultObjectGenerationMode = undefined;

  readonly modelId: OpenAICompletionModelId;
  readonly settings: OpenAICompletionSettings;

  private readonly config: AzureOpenAICompletionConfig;

  constructor(
    modelId: OpenAICompletionModelId,
    settings: OpenAICompletionSettings,
    config: AzureOpenAICompletionConfig,
  ) {
    this.modelId = modelId;
    this.settings = settings;
    this.config = config;
  }

  get provider(): string {
    return this.config.provider;
  }

  private getArgs({
    mode,
    inputFormat,
    prompt,
    maxTokens,
    temperature,
    topP,
    frequencyPenalty,
    presencePenalty,
    seed,
  }: Parameters<LanguageModelV1['doGenerate']>[0]) {
    const type = mode.type;

    const { prompt: completionPrompt, stopSequences } =
      convertToOpenAICompletionPrompt({ prompt, inputFormat });

    const baseArgs = {
      // model id:
      model: this.modelId,

      // model specific settings:
      echo: this.settings.echo,
      logit_bias: this.settings.logitBias,
      logprobs:
        typeof this.settings.logprobs === 'number'
          ? this.settings.logprobs
          : typeof this.settings.logprobs === 'boolean'
            ? this.settings.logprobs
              ? 0
              : undefined
            : undefined,
      suffix: this.settings.suffix,
      user: this.settings.user,

      // standardized settings:
      max_tokens: maxTokens,
      temperature,
      top_p: topP,
      frequency_penalty: frequencyPenalty,
      presence_penalty: presencePenalty,
      seed,

      // prompt:
      prompt: completionPrompt,

      // stop sequences:
      stop: stopSequences,
    };

    switch (type) {
      case 'regular': {
        if (mode.tools?.length) {
          throw new UnsupportedFunctionalityError({
            functionality: 'tools',
          });
        }

        if (mode.toolChoice) {
          throw new UnsupportedFunctionalityError({
            functionality: 'toolChoice',
          });
        }

        return baseArgs;
      }

      case 'object-json': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-json mode',
        });
      }

      case 'object-tool': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-tool mode',
        });
      }

      case 'object-grammar': {
        throw new UnsupportedFunctionalityError({
          functionality: 'object-grammar mode',
        });
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  async doGenerate(
    options: Parameters<LanguageModelV1['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
    const args = this.getArgs(options);

    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: '/completions',
        modelId: this.modelId,
      }),
      headers: this.config.headers(),
      body: args,
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        openAICompletionResponseSchema,
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const { prompt: rawPrompt, ...rawSettings } = args;
    const choice = response.choices[0];

    return {
      text: choice.text,
      usage: {
        promptTokens: response.usage.prompt_tokens,
        completionTokens: response.usage.completion_tokens,
      },
      finishReason: mapOpenAIFinishReason(choice.finish_reason),
      logprobs: mapOpenAICompletionLogProbs(choice.logprobs),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings: [],
    };
  }
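  // Usage sketch (assumption, not from the original source; the option values
  // are hypothetical). doGenerate takes the standardized V1 call options and
  // returns the mapped text, usage, and finish reason:
  //
  //   const { text, usage, finishReason } = await model.doGenerate({
  //     inputFormat: 'prompt',
  //     mode: { type: 'regular' },
  //     prompt: [
  //       { role: 'user', content: [{ type: 'text', text: 'Write a haiku.' }] },
  //     ],
  //     maxTokens: 64,
  //     temperature: 0.7,
  //   });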
  async doStream(
    options: Parameters<LanguageModelV1['doStream']>[0],
  ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
    const args = this.getArgs(options);

    const { responseHeaders, value: response } = await postJsonToApi({
      url: this.config.url({
        path: '/completions',
        modelId: this.modelId,
      }),
      headers: this.config.headers(),
      body: {
        ...args,
        stream: true,

        // only include stream_options when in strict compatibility mode:
        stream_options:
          this.config.compatibility === 'strict'
            ? { include_usage: true }
            : undefined,
      },
      failedResponseHandler: openaiFailedResponseHandler,
      successfulResponseHandler: createEventSourceResponseHandler(
        openaiCompletionChunkSchema,
      ),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const { prompt: rawPrompt, ...rawSettings } = args;

    let finishReason: LanguageModelV1FinishReason = 'other';
    let usage: { promptTokens: number; completionTokens: number } = {
      promptTokens: Number.NaN,
      completionTokens: Number.NaN,
    };
    let logprobs: LanguageModelV1LogProbs;

    return {
      stream: response.pipeThrough(
        new TransformStream<
          ParseResult<z.infer<typeof openaiCompletionChunkSchema>>,
          LanguageModelV1StreamPart
        >({
          transform(chunk, controller) {
            // handle failed chunk parsing / validation:
            if (!chunk.success) {
              finishReason = 'error';
              controller.enqueue({ type: 'error', error: chunk.error });
              return;
            }

            const value = chunk.value;

            // handle error chunks:
            if ('error' in value) {
              finishReason = 'error';
              controller.enqueue({ type: 'error', error: value.error });
              return;
            }

            if (value.usage != null) {
              usage = {
                promptTokens: value.usage.prompt_tokens,
                completionTokens: value.usage.completion_tokens,
              };
            }

            const choice = value.choices[0];

            if (choice?.finish_reason != null) {
              finishReason = mapOpenAIFinishReason(choice.finish_reason);
            }

            if (choice?.text != null) {
              controller.enqueue({
                type: 'text-delta',
                textDelta: choice.text,
              });
            }

            const mappedLogprobs = mapOpenAICompletionLogProbs(
              choice?.logprobs,
            );
            if (mappedLogprobs?.length) {
              if (logprobs === undefined) logprobs = [];
              logprobs.push(...mappedLogprobs);
            }
          },

          flush(controller) {
            controller.enqueue({
              type: 'finish',
              finishReason,
              logprobs,
              usage,
            });
          },
        }),
      ),
      rawCall: { rawPrompt, rawSettings },
      rawResponse: { headers: responseHeaders },
      warnings: [],
    };
  }
}

// limited version of the schema, focused on what is needed for the implementation;
// this approach limits breakages when the API changes and increases efficiency
const openAICompletionResponseSchema = z.object({
  choices: z.array(
    z.object({
      text: z.string(),
      finish_reason: z.string(),
      logprobs: z
        .object({
          tokens: z.array(z.string()),
          token_logprobs: z.array(z.number()),
          top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
        })
        .nullable()
        .optional(),
    }),
  ),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
  }),
});

// limited version of the schema, focused on what is needed for the implementation;
// this approach limits breakages when the API changes and increases efficiency
const openaiCompletionChunkSchema = z.union([
  z.object({
    choices: z.array(
      z.object({
        text: z.string(),
        finish_reason: z.string().nullish(),
        index: z.number(),
        logprobs: z
          .object({
            tokens: z.array(z.string()),
            token_logprobs: z.array(z.number()),
            top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
          })
          .nullable()
          .optional(),
      }),
    ),
    usage: z
      .object({
        prompt_tokens: z.number(),
        completion_tokens: z.number(),
      })
      .optional()
      .nullable(),
  }),
  openAIErrorDataSchema,
]);
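// Usage sketch (assumption, not from the original source): the resource name,
// deployment name, api-version, and header shape below are hypothetical Azure
// values; only the AzureOpenAICompletionConfig shape itself comes from this file.
//
//   const model = new AzureOpenAICompletionLanguageModel(
//     'gpt-35-turbo-instruct',
//     {},
//     {
//       provider: 'azure-openai.completion',
//       compatibility: 'strict',
//       url: ({ path }) =>
//         `https://my-resource.openai.azure.com/openai/deployments/my-deployment${path}?api-version=2024-02-01`,
//       headers: () => ({ 'api-key': process.env.AZURE_OPENAI_API_KEY }),
//     },
//   );
//
// With compatibility set to 'strict', doStream also sends
// `stream_options: { include_usage: true }` so token usage arrives in the
// final SSE chunk; 'compatible' omits it for OpenAI-compatible servers that
// reject unknown fields.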