Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions frontend/src/components/ai/__tests__/ai-utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ describe("ai-utils", () => {
describe("getConfiguredProvider", () => {
it("should return undefined when no AI config", () => {
const config: UserConfig = {} as UserConfig;
expect(getConfiguredProvider(config)).toBeUndefined();
expect(getConfiguredProvider(config.ai)).toBeUndefined();
});

it("should return undefined when AI config has no credentials", () => {
const config: UserConfig = {
ai: {},
} as UserConfig;
expect(getConfiguredProvider(config)).toBeUndefined();
expect(getConfiguredProvider(config.ai)).toBeUndefined();
});

it("should return openai when OpenAI API key is set", () => {
Expand All @@ -77,7 +77,7 @@ describe("ai-utils", () => {
open_ai: { api_key: "sk-test" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBe("openai");
expect(getConfiguredProvider(config.ai)).toBe("openai");
});

it("should return anthropic when Anthropic API key is set", () => {
Expand All @@ -86,7 +86,7 @@ describe("ai-utils", () => {
anthropic: { api_key: "sk-ant-test" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBe("anthropic");
expect(getConfiguredProvider(config.ai)).toBe("anthropic");
});

it("should return google when Google API key is set", () => {
Expand All @@ -95,7 +95,7 @@ describe("ai-utils", () => {
google: { api_key: "google-key" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBe("google");
expect(getConfiguredProvider(config.ai)).toBe("google");
});

it("should return ollama when Ollama base URL is set", () => {
Expand All @@ -104,7 +104,7 @@ describe("ai-utils", () => {
ollama: { base_url: "http://localhost:11434" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBe("ollama");
expect(getConfiguredProvider(config.ai)).toBe("ollama");
});

it("should return azure only when both API key and base URL are set", () => {
Expand All @@ -113,7 +113,7 @@ describe("ai-utils", () => {
azure: { api_key: "azure-key", base_url: "https://azure.com" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBe("azure");
expect(getConfiguredProvider(config.ai)).toBe("azure");
});

it("should return undefined for azure with only API key", () => {
Expand All @@ -122,7 +122,7 @@ describe("ai-utils", () => {
azure: { api_key: "azure-key" },
},
} as UserConfig;
expect(getConfiguredProvider(config)).toBeUndefined();
expect(getConfiguredProvider(config.ai)).toBeUndefined();
});

it("should return custom provider when configured", () => {
Expand All @@ -133,14 +133,14 @@ describe("ai-utils", () => {
},
},
} as unknown as UserConfig;
expect(getConfiguredProvider(config)).toBe("my_provider");
expect(getConfiguredProvider(config.ai)).toBe("my_provider");
});
});

describe("getRecommendedModel", () => {
it("should return undefined when no provider is configured", () => {
const config: UserConfig = {} as UserConfig;
expect(getRecommendedModel(config)).toBeUndefined();
expect(getRecommendedModel(config.ai)).toBeUndefined();
});

it("should return openai model when OpenAI is configured", () => {
Expand All @@ -149,7 +149,7 @@ describe("ai-utils", () => {
open_ai: { api_key: "sk-test" },
},
} as UserConfig;
expect(getRecommendedModel(config)).toBe("openai/gpt-4");
expect(getRecommendedModel(config.ai)).toBe("openai/gpt-4");
});

it("should return anthropic model when Anthropic is configured", () => {
Expand All @@ -158,7 +158,7 @@ describe("ai-utils", () => {
anthropic: { api_key: "sk-ant-test" },
},
} as UserConfig;
expect(getRecommendedModel(config)).toBe("anthropic/claude-3-sonnet");
expect(getRecommendedModel(config.ai)).toBe("anthropic/claude-3-sonnet");
});

it("should return google model when Google is configured", () => {
Expand All @@ -167,7 +167,7 @@ describe("ai-utils", () => {
google: { api_key: "google-key" },
},
} as UserConfig;
expect(getRecommendedModel(config)).toBe("google/gemini-pro");
expect(getRecommendedModel(config.ai)).toBe("google/gemini-pro");
});

it("should return ollama model when Ollama is configured", () => {
Expand All @@ -176,7 +176,7 @@ describe("ai-utils", () => {
ollama: { base_url: "http://localhost:11434" },
},
} as UserConfig;
expect(getRecommendedModel(config)).toBe("ollama/llama2");
expect(getRecommendedModel(config.ai)).toBe("ollama/llama2");
});
});

Expand All @@ -194,7 +194,7 @@ describe("ai-utils", () => {
},
} as unknown as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBeUndefined();
expect(result.editModel).toBeUndefined();
Expand All @@ -205,7 +205,7 @@ describe("ai-utils", () => {
ai: {},
} as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBeUndefined();
expect(result.editModel).toBeUndefined();
Expand All @@ -218,7 +218,7 @@ describe("ai-utils", () => {
},
} as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBe("openai/gpt-4");
expect(result.editModel).toBe("openai/gpt-4");
Expand All @@ -236,7 +236,7 @@ describe("ai-utils", () => {
},
} as unknown as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBe("openai/gpt-4");
expect(result.editModel).toBeUndefined();
Expand All @@ -254,7 +254,7 @@ describe("ai-utils", () => {
},
} as unknown as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBeUndefined();
expect(result.editModel).toBe("openai/gpt-4");
Expand All @@ -267,7 +267,7 @@ describe("ai-utils", () => {
},
} as UserConfig;

const result = autoPopulateModels(values);
const result = autoPopulateModels(values.ai);

expect(result.chatModel).toBe("anthropic/claude-3-sonnet");
expect(result.editModel).toBe("anthropic/claude-3-sonnet");
Expand Down
18 changes: 8 additions & 10 deletions frontend/src/components/ai/ai-model-dropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
} from "@/core/ai/ids/ids";
import { type AiModel, AiModelRegistry } from "@/core/ai/model-registry";
import { aiAtom, completionAtom } from "@/core/config/config";
import { useOpenSettingsToTab } from "../app-config/state";
import {
DropdownMenu,
DropdownMenuContent,
Expand Down Expand Up @@ -62,6 +63,7 @@ export const AIModelDropdown = ({
const ai = useAtomValue(aiAtom);
const completion = useAtomValue(completionAtom);
const { saveModelChange } = useModelChange();
const { handleClick } = useOpenSettingsToTab();

// Only include autocompleteModel if copilot is set to "custom"
const autocompleteModel =
Expand Down Expand Up @@ -188,16 +190,12 @@ export const AIModelDropdown = ({
{showAddCustomModelDocs && (
<>
<DropdownMenuSeparator />
<DropdownMenuItem className="flex items-center gap-2">
<a
className="flex items-center gap-1"
href="https://links.marimo.app/custom-models"
target="_blank"
rel="noreferrer"
>
<CircleHelpIcon className="h-3 w-3" />
<span>How to add a custom model</span>
</a>
<DropdownMenuItem
className="h-7 flex items-center gap-2"
onClick={() => handleClick("ai", "ai-models")}
>
<CircleHelpIcon className="h-3 w-3" />
<span className="cursor-pointer text-link">Add custom model</span>
</DropdownMenuItem>
</>
)}
Expand Down
20 changes: 11 additions & 9 deletions frontend/src/components/ai/ai-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,16 @@ const CREDENTIAL_CHECKERS: Record<KnownProviderId, CredentialChecker> = {
* Returns the first configured provider based on credentials.
*/
export function getConfiguredProvider(
config: UserConfig,
config: UserConfig["ai"],
): ProviderId | undefined {
const ai = config.ai;

for (const provider of KNOWN_PROVIDERS) {
if (CREDENTIAL_CHECKERS[provider](ai)) {
if (CREDENTIAL_CHECKERS[provider](config)) {
return provider;
}
}

// Check custom providers
const customProviders = ai?.custom_providers;
const customProviders = config?.custom_providers;
if (customProviders) {
const firstCustomProvider = Object.entries(customProviders).find(
([_, providerConfig]) => providerConfig?.base_url,
Expand All @@ -54,7 +52,9 @@ export function getConfiguredProvider(
}
}

export function getRecommendedModel(config: UserConfig): string | undefined {
export function getRecommendedModel(
config: UserConfig["ai"],
): string | undefined {
const provider = getConfiguredProvider(config);
if (!provider) {
return undefined;
Expand All @@ -73,14 +73,16 @@ export interface AutoPopulateResult {
*
* @param values - The full form values
*/
export function autoPopulateModels(values: UserConfig): AutoPopulateResult {
export function autoPopulateModels(
values: UserConfig["ai"],
): AutoPopulateResult {
const result: AutoPopulateResult = {
chatModel: undefined,
editModel: undefined,
};

const needsChatModel = !values.ai?.models?.chat_model;
const needsEditModel = !values.ai?.models?.edit_model;
const needsChatModel = !values?.models?.chat_model;
const needsEditModel = !values?.models?.edit_model;

if (!needsChatModel && !needsEditModel) {
return result;
Expand Down
Loading
Loading