Spaces:
Running
on
Inf2
Running
on
Inf2
fix(endpoints): fix for tool calling on hf inference with openai endpoint type (#1754)
Browse files
* fix(endpoints): fix for tool calling on hf inference with openai endpoint type
* moar fix
* fix: typechecks
src/lib/server/endpoints/openai/openAIChatToTextGenerationStream.ts
CHANGED
|
@@ -49,10 +49,47 @@ export async function* openAIChatToTextGenerationStream(
|
|
| 49 |
let generatedText = "";
|
| 50 |
let tokenId = 0;
|
| 51 |
const toolCalls: ToolCallWithParameters[] = [];
|
|
|
|
|
|
|
| 52 |
for await (const completion of completionStream) {
|
| 53 |
const { choices } = completion;
|
| 54 |
const content = choices[0]?.delta?.content ?? "";
|
| 55 |
const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if (content) {
|
| 57 |
generatedText = generatedText + content;
|
| 58 |
}
|
|
|
|
| 49 |
let generatedText = "";
|
| 50 |
let tokenId = 0;
|
| 51 |
const toolCalls: ToolCallWithParameters[] = [];
|
| 52 |
+
let toolBuffer = ""; // XXX: hack because tools seem broken on tgi openai endpoints?
|
| 53 |
+
|
| 54 |
for await (const completion of completionStream) {
|
| 55 |
const { choices } = completion;
|
| 56 |
const content = choices[0]?.delta?.content ?? "";
|
| 57 |
const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
|
| 58 |
+
|
| 59 |
+
// if the last token is a stop and the tool buffer is not empty, yield it as a generated_text
|
| 60 |
+
if (choices[0]?.finish_reason === "stop" && toolBuffer.length > 0) {
|
| 61 |
+
yield {
|
| 62 |
+
token: {
|
| 63 |
+
id: tokenId++,
|
| 64 |
+
special: true,
|
| 65 |
+
logprob: 0,
|
| 66 |
+
text: "",
|
| 67 |
+
},
|
| 68 |
+
generated_text: toolBuffer,
|
| 69 |
+
details: null,
|
| 70 |
+
} as TextGenerationStreamOutput;
|
| 71 |
+
break;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// weird bug where the parameters are streamed in like this
|
| 75 |
+
if (choices[0]?.delta?.tool_calls) {
|
| 76 |
+
const calls = Array.isArray(choices[0].delta.tool_calls)
|
| 77 |
+
? choices[0].delta.tool_calls
|
| 78 |
+
: [choices[0].delta.tool_calls];
|
| 79 |
+
|
| 80 |
+
if (
|
| 81 |
+
calls.length === 1 &&
|
| 82 |
+
calls[0].index === 0 &&
|
| 83 |
+
calls[0].id === "" &&
|
| 84 |
+
calls[0].type === "function" &&
|
| 85 |
+
!!calls[0].function &&
|
| 86 |
+
calls[0].function.name === null
|
| 87 |
+
) {
|
| 88 |
+
toolBuffer += calls[0].function.arguments;
|
| 89 |
+
continue;
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
if (content) {
|
| 94 |
generatedText = generatedText + content;
|
| 95 |
}
|
src/lib/server/textGeneration/tools.ts
CHANGED
|
@@ -314,6 +314,22 @@ function isValidCallObject(call: unknown): call is Record<string, unknown> {
|
|
| 314 |
}
|
| 315 |
|
| 316 |
function parseExternalCall(callObj: Record<string, unknown>) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
const nameFields = ["tool_name", "name"] as const;
|
| 318 |
const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
|
| 319 |
|
|
@@ -323,14 +339,14 @@ function parseExternalCall(callObj: Record<string, unknown>) {
|
|
| 323 |
};
|
| 324 |
|
| 325 |
for (const name of nameFields) {
|
| 326 |
-
if (
|
| 327 |
-
groupedCall.tool_name =
|
| 328 |
}
|
| 329 |
}
|
| 330 |
|
| 331 |
for (const name of parametersFields) {
|
| 332 |
-
if (
|
| 333 |
-
groupedCall.parameters =
|
| 334 |
}
|
| 335 |
}
|
| 336 |
|
|
|
|
| 314 |
}
|
| 315 |
|
| 316 |
function parseExternalCall(callObj: Record<string, unknown>) {
|
| 317 |
+
let toolCall = callObj;
|
| 318 |
+
if (
|
| 319 |
+
isValidCallObject(callObj) &&
|
| 320 |
+
"function" in callObj &&
|
| 321 |
+
isValidCallObject(callObj.function) &&
|
| 322 |
+
"_name" in callObj.function
|
| 323 |
+
) {
|
| 324 |
+
toolCall = {
|
| 325 |
+
tool_name: callObj["function"]["_name"],
|
| 326 |
+
parameters: {
|
| 327 |
+
...callObj["function"],
|
| 328 |
+
_name: undefined,
|
| 329 |
+
},
|
| 330 |
+
};
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
const nameFields = ["tool_name", "name"] as const;
|
| 334 |
const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
|
| 335 |
|
|
|
|
| 339 |
};
|
| 340 |
|
| 341 |
for (const name of nameFields) {
|
| 342 |
+
if (toolCall[name]) {
|
| 343 |
+
groupedCall.tool_name = toolCall[name] as string;
|
| 344 |
}
|
| 345 |
}
|
| 346 |
|
| 347 |
for (const name of parametersFields) {
|
| 348 |
+
if (toolCall[name]) {
|
| 349 |
+
groupedCall.parameters = toolCall[name] as Record<string, string>;
|
| 350 |
}
|
| 351 |
}
|
| 352 |
|