Merged
52 changes: 31 additions & 21 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -215,25 +215,27 @@ function convertMessagesToMistralMessages(
         .map((toolCall) => ({ ...toolCall, id: "null" }))
         .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[];
     }
-    if (message.additional_kwargs.tool_calls === undefined) {
+    if (!message.additional_kwargs.tool_calls?.length) {
       return undefined;
     }
     const toolCalls: Omit<OpenAIToolCall, "index">[] =
       message.additional_kwargs.tool_calls;
-    return (
-      toolCalls?.map((toolCall) => ({
-        id: "null",
-        type: "function" as ToolType.function,
-        function: toolCall.function,
-      })) || []
-    );
+    return toolCalls?.map((toolCall) => ({
+      id: "null",
+      type: "function" as ToolType.function,
+      function: toolCall.function,
+    }));
   };

-  return messages.map((message) => ({
-    role: getRole(message._getType()),
-    content: getContent(message.content),
-    tool_calls: getTools(message),
-  }));
+  return messages.map((message) => {
+    const toolCalls = getTools(message);
+    const content = toolCalls === undefined ? getContent(message.content) : "";
+    return {
+      role: getRole(message._getType()),
+      content,
+      tool_calls: toolCalls,
+    };
+  });
 }

 function mistralAIResponseToChatMessage(
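
The hunk above makes two related changes: the guard now treats an empty `tool_calls` array the same as an absent one, and the rewritten mapper sends `content: ""` whenever tool calls are present, presumably because the Mistral API rejects messages that carry both text content and tool calls. A minimal standalone sketch of the resulting mapping rule (the type and helper below are simplified illustrations for this note, not the library's actual definitions):

```ts
// Sketch of the new mapping rule; SketchToolCall and toMistralMessage are
// hypothetical names, not part of langchain-mistralai.
type SketchToolCall = {
  id: string;
  type: "function";
  function: { name: string; arguments: string };
};

function toMistralMessage(
  role: string,
  content: string,
  toolCalls?: SketchToolCall[]
) {
  // Empty arrays count as "no tool calls", mirroring the ?.length guard above.
  const tools = toolCalls?.length ? toolCalls : undefined;
  // When tool calls are present, content is sent as an empty string.
  return { role, content: tools === undefined ? content : "", tool_calls: tools };
}

// A plain message keeps its text:
console.log(toMistralMessage("user", "Hello"));
// A tool-call message drops its text (content becomes ""):
console.log(
  toMistralMessage("assistant", "ignored", [
    { id: "null", type: "function", function: { name: "search", arguments: "{}" } },
  ])
);
```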
@@ -505,15 +505,23 @@ export class ChatMistralAI<
     const client = new MistralClient(this.apiKey, this.endpoint);

     return this.caller.call(async () => {
-      let res:
-        | ChatCompletionResponse
-        | AsyncGenerator<ChatCompletionResponseChunk>;
-      if (streaming) {
-        res = client.chatStream(input);
-      } else {
-        res = await client.chat(input);
+      try {
+        let res:
+          | ChatCompletionResponse
+          | AsyncGenerator<ChatCompletionResponseChunk>;
+        if (streaming) {
+          res = client.chatStream(input);
+        } else {
+          res = await client.chat(input);
+        }
+        return res;
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      } catch (e: any) {
+        if (e.message?.includes("status: 400")) {
+          e.status = 400;
+        }
+        throw e;
       }
-      return res;
     });
   }

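The try/catch above tags errors whose message contains "status: 400" with a numeric `e.status` before rethrowing, presumably so the surrounding `this.caller` retry machinery can recognize a client error and stop retrying a request that can never succeed. A hedged sketch of that interaction, with a simplified loop standing in for LangChain's actual AsyncCaller (this is an illustration of the idea, not its real implementation):

```ts
// Hypothetical retry wrapper: a stand-in for the behavior the tagging enables.
async function callWithRetry<T>(
  fn: () => Promise<T>,
  maxRetries = 3
): Promise<T> {
  for (let attempt = 0; ; attempt += 1) {
    try {
      return await fn();
    } catch (e) {
      const status = (e as { status?: number }).status;
      // A tagged 400 is a client error: retrying the same bad request
      // cannot succeed, so rethrow immediately instead of looping.
      if (status === 400 || attempt >= maxRetries) throw e;
    }
  }
}
```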