Skip to content

Instantly share code, notes, and snippets.

@f1shy-dev
Created May 5, 2025 17:34
Show Gist options
  • Save f1shy-dev/540eeca44270e1ae595e21d4545e3a7b to your computer and use it in GitHub Desktop.
Save f1shy-dev/540eeca44270e1ae595e21d4545e3a7b to your computer and use it in GitHub Desktop.

Revisions

  1. f1shy-dev created this gist May 5, 2025.
    245 changes: 245 additions & 0 deletions with-retry.ts
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,245 @@
    import { DelayedPromise } from "@/lib/utils/delayed-promise";
    import type {
    generateText,
    LanguageModel,
    LanguageModelUsage,
    TextStreamPart,
    ToolSet,
    } from "ai";
    import { streamText } from "ai";

    const retry = async <T>(
    fn: (attempt: number) => Promise<T>,
    tries = 3,
    ): Promise<T> => {
    let attempts = 1;

    const _wrap = async () => {
    if (attempts > tries) {
    throw new Error("Max retries reached");
    }
    try {
    return await fn(attempts);
    } catch (error) {
    attempts++;
    return await _wrap();
    }
    };

    return await _wrap();
    };

    export type VertexSafetySettings = Array<{
    category:
    | "HARM_CATEGORY_UNSPECIFIED"
    | "HARM_CATEGORY_HATE_SPEECH"
    | "HARM_CATEGORY_DANGEROUS_CONTENT"
    | "HARM_CATEGORY_HARASSMENT"
    | "HARM_CATEGORY_SEXUALLY_EXPLICIT";
    threshold:
    | "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
    | "BLOCK_LOW_AND_ABOVE"
    | "BLOCK_MEDIUM_AND_ABOVE"
    | "BLOCK_ONLY_HIGH"
    | "BLOCK_NONE";
    }>;
    type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

    export type SingularLMArray<T extends typeof generateText | typeof streamText> =
    | [
    LanguageModel,
    Omit<Omit<Parameters<T>[0], "model">, "schema"> & {
    safetySettings?: VertexSafetySettings;
    },
    ]
    | [
    LanguageModel,
    Omit<Omit<Parameters<T>[0], "model">, "schema"> & {
    safetySettings?: VertexSafetySettings;
    },
    string,
    ];

    type RetryAIGenerateReturnType<
    T extends typeof generateText | typeof streamText,
    > = ReturnType<T> & {
    attempts: number;
    resolvedModel: Promise<SingularLMArray<T>>;
    } & (T extends typeof streamText
    ? { fullStreamWithoutErrors: AsyncIterableStream<TextStreamPart<ToolSet>> }
    : unknown);

    export const retryAIGenerate = async <
    T extends typeof generateText | typeof streamText,
    >(
    fn: T,
    {
    models,
    sharedOptions,
    }: {
    models: SingularLMArray<T>[];
    sharedOptions: Omit<Parameters<T>[0], "model">;
    },
    ): Promise<RetryAIGenerateReturnType<T>> => {
    let attempts = 1;

    let persistentOutputTransform: TransformStream | null = null;
    let textAccumulator = "";

    let activeUsagePromise: Promise<LanguageModelUsage> | null = null;
    const textPromise = new DelayedPromise<string>();

    const usagePromise = new DelayedPromise<LanguageModelUsage>();
    const resolvedModelPromise = new DelayedPromise<SingularLMArray<T>>();

    const resolveModel = () => {
    let key = models[attempts - 1];
    if (!key) key = models[models.length - 1];
    resolvedModelPromise.resolve(key);
    };

    if (fn === streamText) {
    persistentOutputTransform = new TransformStream({
    transform(chunk, controller) {
    if (chunk.type === "text") {
    textAccumulator += chunk.text;
    }

    controller.enqueue(chunk);
    },
    async flush() {
    textPromise.resolve(textAccumulator);

    if (activeUsagePromise) {
    usagePromise.resolve(await activeUsagePromise);
    }
    resolveModel();
    },
    });
    }

    const _wrap = async () => {
    if (attempts > models.length) {
    if (persistentOutputTransform) {
    const writer = persistentOutputTransform.writable.getWriter();
    await writer.abort(new Error("Max retries reached"));

    textPromise.resolve(textAccumulator);
    usagePromise.reject(new Error("Max retries reached"));
    resolveModel();
    }
    throw new Error("Max retries reached");
    }

    try {
    const options_combined = {
    ...(models[attempts - 1][1] || {}),
    providerOptions: {
    ...(sharedOptions.providerOptions || {}),
    ...(models[attempts - 1][1]?.providerOptions || {}),
    },
    } as Parameters<T>[0];
    console.log(
    "Attempting to generate text with model",
    models[attempts - 1][0].modelId,
    "with options",
    options_combined,
    );

    const final = await fn({
    ...sharedOptions,
    ...options_combined,
    model: models[attempts - 1][0],
    maxRetries: 1,
    } as Parameters<T>[0]);

    Object.assign(final, {
    attempts,
    });

    if ("fullStream" in final && persistentOutputTransform) {
    activeUsagePromise = final.usage;

    const processCurrentModelStream = async () => {
    const reader = final.fullStream.getReader();
    const writer = persistentOutputTransform!.writable.getWriter();

    try {
    while (true) {
    const { done, value } = await reader.read();
    if (done) {
    await writer.close();
    break;
    }

    if (value.type === "error") {
    console.error("gen[error] <- in retryAIGenerate stream");
    // Instead of aborting, we'll retry with the next model
    throw value.error;
    }

    await writer.write(value);
    }

    usagePromise.resolve(await final.usage);
    resolveModel();
    } catch (error) {
    console.error("Stream processing error:", error);

    writer.releaseLock();

    attempts++;
    if (attempts <= models.length) {
    await _wrap();
    } else {
    const finalWriter =
    persistentOutputTransform!.writable.getWriter();
    await finalWriter.abort(error);

    textPromise.resolve(textAccumulator);
    usagePromise.reject(error);
    resolveModel();
    }
    }
    };

    processCurrentModelStream();

    Object.defineProperty(final, "text", {
    enumerable: true,
    configurable: true,
    get() {
    return textPromise.value;
    },
    });

    Object.defineProperty(final, "usage", {
    enumerable: true,
    configurable: true,
    get() {
    return usagePromise.value;
    },
    });

    Object.defineProperty(final, "resolvedModel", {
    enumerable: true,
    configurable: true,
    get() {
    return resolvedModelPromise.value;
    },
    });

    Object.assign(final, {
    fullStreamWithoutErrors: persistentOutputTransform.readable,
    });
    }

    return final as RetryAIGenerateReturnType<T>;
    } catch (error) {
    attempts++;
    return await _wrap();
    }
    };

    return await _wrap();
    };