Expose sampling controls in assistants (#955) (#959)

* Expose sampling controls in assistants (#955)
* Make sure all labels have the same font size
* styling
* Add better tooltips
* better padding & wrapping
* Revert "better padding & wrapping"
This reverts commit 1b44086465040f2cb6bc906983cfc8d95820d6fe.
* ui update
* tooltip on mobile
* lint
* Update src/lib/components/AssistantSettings.svelte
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
---------
Co-authored-by: Victor Mustar <victor.mustar@gmail.com>
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
- src/lib/components/AssistantSettings.svelte +128 -17
- src/lib/components/HoverTooltip.svelte +12 -0
- src/lib/server/endpoints/anthropic/endpointAnthropic.ts +9 -6
- src/lib/server/endpoints/aws/endpointAws.ts +2 -2
- src/lib/server/endpoints/endpoints.ts +2 -0
- src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts +9 -7
- src/lib/server/endpoints/ollama/endpointOllama.ts +9 -7
- src/lib/server/endpoints/openai/endpointOai.ts +16 -12
- src/lib/server/endpoints/tgi/endpointTgi.ts +2 -2
- src/lib/types/Assistant.ts +6 -0
- src/routes/conversation/[id]/+server.ts +10 -3
- src/routes/settings/(nav)/assistants/[assistantId]/edit/+page.server.ts +20 -0
- src/routes/settings/(nav)/assistants/new/+page.server.ts +20 -0
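Taken together, the diff wires four sampling controls (temperature, top_p, repetition_penalty, top_k) end to end: the new form inputs in AssistantSettings.svelte are validated by matching zod schemas in both +page.server.ts actions, persisted on the assistant document as generateSettings, forwarded by the conversation route, and finally merged over the model defaults inside every endpoint. The one-line merge that each endpoint repeats is the key pattern to keep in mind while reading the diffs below (a condensed sketch, not a new helper in the codebase):

const parameters = { ...model.parameters, ...generateSettings };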
src/lib/components/AssistantSettings.svelte
CHANGED
@@ -9,11 +9,14 @@
 	import { base } from "$app/paths";
 	import CarbonPen from "~icons/carbon/pen";
 	import CarbonUpload from "~icons/carbon/upload";
+	import CarbonHelpFilled from "~icons/carbon/help";
+	import CarbonSettingsAdjust from "~icons/carbon/settings-adjust";
 
 	import { useSettingsStore } from "$lib/stores/settings";
 	import { isHuggingChat } from "$lib/utils/isHuggingChat";
 	import IconInternet from "./icons/IconInternet.svelte";
 	import TokensCounter from "./TokensCounter.svelte";
+	import HoverTooltip from "./HoverTooltip.svelte";
 
 	type ActionData = {
 		error: boolean;
@@ -31,16 +34,22 @@
 
 	let files: FileList | null = null;
 	const settings = useSettingsStore();
-	let modelId =
-		assistant?.modelId ?? models.find((_model) => _model.id === $settings.activeModel)?.name;
+	let modelId = "";
 	let systemPrompt = assistant?.preprompt ?? "";
 	let dynamicPrompt = assistant?.dynamicPrompt ?? false;
+	let showModelSettings = Object.values(assistant?.generateSettings ?? {}).some((v) => !!v);
 
 	let compress: typeof readAndCompressImage | null = null;
 
 	onMount(async () => {
 		const module = await import("browser-image-resizer");
 		compress = module.readAndCompressImage;
+
+		if (assistant) {
+			modelId = assistant.modelId;
+		} else {
+			modelId = models.find((model) => model.id === $settings.activeModel)?.id ?? models[0].id;
+		}
 	});
 
 	let inputMessage1 = assistant?.exampleInputs[0] ?? "";
@@ -89,11 +98,12 @@
 
 	const regex = /{{\s?url=(.+?)\s?}}/g;
 	$: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
+	$: selectedModel = models.find((m) => m.id === modelId);
 </script>
 
 <form
 	method="POST"
-	class="flex h-full flex-col overflow-y-auto p-4 md:p-8"
+	class="relative flex h-full flex-col overflow-y-auto p-4 md:p-8"
 	enctype="multipart/form-data"
 	use:enhance={async ({ formData }) => {
 		loading = true;
@@ -246,21 +256,122 @@
 
 			<label>
 				<div class="mb-1 font-semibold">Model</div>
+				<div class="flex gap-2">
+					<select
+						name="modelId"
+						class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
+						bind:value={modelId}
+					>
+						{#each models.filter((model) => !model.unlisted) as model}
+							<option value={model.id}>{model.displayName}</option>
+						{/each}
+						<p class="text-xs text-red-500">{getError("modelId", form)}</p>
+					</select>
+					<button
+						type="button"
+						class="flex aspect-square items-center gap-2 whitespace-nowrap rounded-lg border px-3 {showModelSettings
+							? 'border-blue-500/20 bg-blue-50 text-blue-600'
+							: ''}"
+						on:click={() => (showModelSettings = !showModelSettings)}
+						><CarbonSettingsAdjust class="text-xs" /></button
+					>
+				</div>
+				<div
+					class="mt-2 rounded-lg border border-blue-500/20 bg-blue-500/5 px-2 py-0.5"
+					class:hidden={!showModelSettings}
 				>
+					<p class="text-xs text-red-500">{getError("inputMessage1", form)}</p>
+					<div class="my-2 grid grid-cols-1 gap-2.5 sm:grid-cols-2 sm:grid-rows-2">
+						<label for="temperature" class="flex justify-between">
+							<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+								Temperature
+								<HoverTooltip
+									label="Temperature: Controls creativity, higher values allow more variety."
+								>
+									<CarbonHelpFilled
+										class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+									/>
+								</HoverTooltip>
+							</span>
+							<input
+								type="number"
+								name="temperature"
+								min="0.1"
+								max="2"
+								step="0.1"
+								class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+								placeholder={selectedModel?.parameters?.temperature?.toString() ?? "1"}
+								value={assistant?.generateSettings?.temperature ?? ""}
+							/>
+						</label>
+						<label for="top_p" class="flex justify-between">
+							<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+								Top P
+								<HoverTooltip
+									label="Top P: Sets word choice boundaries, lower values tighten focus."
+								>
+									<CarbonHelpFilled
+										class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+									/>
+								</HoverTooltip>
+							</span>
+							<input
+								type="number"
+								name="top_p"
+								class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+								min="0.05"
+								max="1"
+								step="0.05"
+								placeholder={selectedModel?.parameters?.top_p?.toString() ?? "1"}
+								value={assistant?.generateSettings?.top_p ?? ""}
+							/>
+						</label>
+						<label for="repetition_penalty" class="flex justify-between">
+							<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+								Repetition penalty
+								<HoverTooltip
+									label="Repetition penalty: Prevents reuse, higher values decrease repetition."
+								>
+									<CarbonHelpFilled
+										class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+									/>
+								</HoverTooltip>
+							</span>
+							<input
+								type="number"
+								name="repetition_penalty"
+								min="0.1"
+								max="2"
+								class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+								placeholder={selectedModel?.parameters?.repetition_penalty?.toString() ?? "1.0"}
+								value={assistant?.generateSettings?.repetition_penalty ?? ""}
+							/>
+						</label>
+						<label for="top_k" class="flex justify-between">
+							<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+								Top K <HoverTooltip
+									label="Top K: Restricts word options, lower values for predictability."
+								>
+									<CarbonHelpFilled
+										class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+									/>
+								</HoverTooltip>
+							</span>
+							<input
+								type="number"
+								name="top_k"
+								min="5"
+								max="100"
+								step="5"
+								class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+								placeholder={selectedModel?.parameters?.top_k?.toString() ?? "50"}
+								value={assistant?.generateSettings?.top_k ?? ""}
+							/>
+						</label>
+					</div>
+				</div>
 			</label>
 
 			<label>
src/lib/components/HoverTooltip.svelte
ADDED
@@ -0,0 +1,12 @@
+<script lang="ts">
+	export let label = "";
+</script>
+
+<div class="group/tooltip md:relative">
+	<slot />
+	<div
+		class="invisible absolute z-10 w-64 whitespace-normal rounded-md bg-black p-2 text-center text-white group-hover/tooltip:visible group-active/tooltip:visible max-sm:left-1/2 max-sm:-translate-x-1/2"
+	>
+		{label}
+	</div>
+</div>
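For reference, this is how AssistantSettings.svelte above consumes the new component: the icon in the slot becomes the hover target and the label renders inside the black popover, which also stays visible on tap thanks to group-active:

<HoverTooltip label="Temperature: Controls creativity, higher values allow more variety.">
	<CarbonHelpFilled class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600" />
</HoverTooltip>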
src/lib/server/endpoints/anthropic/endpointAnthropic.ts
CHANGED
@@ -32,7 +32,7 @@ export async function endpointAnthropic(
 		defaultQuery,
 	});
 
-	return async ({ messages, preprompt }) => {
+	return async ({ messages, preprompt, generateSettings }) => {
 		let system = preprompt;
 		if (messages?.[0]?.from === "system") {
 			system = messages[0].content;
@@ -49,15 +49,18 @@ export async function endpointAnthropic(
 	}[];
 
 	let tokenId = 0;
+
+	const parameters = { ...model.parameters, ...generateSettings };
+
 	return (async function* () {
 		const stream = anthropic.messages.stream({
 			model: model.id ?? model.name,
 			messages: messagesFormatted,
-			max_tokens: model.parameters?.max_new_tokens,
-			temperature: model.parameters?.temperature,
-			top_p: model.parameters?.top_p,
-			top_k: model.parameters?.top_k,
-			stop_sequences: model.parameters?.stop,
+			max_tokens: parameters?.max_new_tokens,
+			temperature: parameters?.temperature,
+			top_p: parameters?.top_p,
+			top_k: parameters?.top_k,
+			stop_sequences: parameters?.stop,
 			system,
 		});
 		while (true) {
src/lib/server/endpoints/aws/endpointAws.ts
CHANGED
@@ -36,7 +36,7 @@ export async function endpointAws(
 		region,
 	});
 
-	return async ({ messages, preprompt, continueMessage }) => {
+	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
 		const prompt = await buildPrompt({
 			messages,
 			continueMessage,
@@ -46,7 +46,7 @@ export async function endpointAws(
 
 		return textGenerationStream(
 			{
-				parameters: { ...model.parameters, return_full_text: false },
+				parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
 				model: url,
 				inputs: prompt,
 			},
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -10,12 +10,14 @@ import {
 	endpointAnthropic,
 	endpointAnthropicParametersSchema,
 } from "./anthropic/endpointAnthropic";
+import type { Model } from "$lib/types/Model";
 
 // parameters passed when generating text
 export interface EndpointParameters {
 	messages: Omit<Conversation["messages"][0], "id">[];
 	preprompt?: Conversation["preprompt"];
 	continueMessage?: boolean; // used to signal that the last message will be extended
+	generateSettings?: Partial<Model["parameters"]>;
 }
 
 interface CommonEndpoint {
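Because generateSettings is a Partial, endpoints can spread it over model.parameters and let any value the assistant author set win, while untouched fields keep the model defaults. A minimal standalone sketch of that precedence (sample values are hypothetical):

// Later spreads win in object literals.
type GenerateSettings = { temperature?: number; top_p?: number; top_k?: number };

const modelDefaults: GenerateSettings = { temperature: 0.9, top_p: 0.95, top_k: 50 };
const assistantOverrides: GenerateSettings = { temperature: 0.4 }; // author set only this

const parameters = { ...modelDefaults, ...assistantOverrides };
// => { temperature: 0.4, top_p: 0.95, top_k: 50 }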
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts
CHANGED
@@ -19,7 +19,7 @@ export function endpointLlamacpp(
 	input: z.input<typeof endpointLlamacppParametersSchema>
 ): Endpoint {
 	const { url, model } = endpointLlamacppParametersSchema.parse(input);
-	return async ({ messages, preprompt, continueMessage }) => {
+	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
 		const prompt = await buildPrompt({
 			messages,
 			continueMessage,
@@ -27,6 +27,8 @@ export function endpointLlamacpp(
 			model,
 		});
 
+		const parameters = { ...model.parameters, ...generateSettings };
+
 		const r = await fetch(`${url}/completion`, {
 			method: "POST",
 			headers: {
@@ -35,12 +37,12 @@ export function endpointLlamacpp(
 			body: JSON.stringify({
 				prompt,
 				stream: true,
-				temperature: model.parameters.temperature,
-				top_p: model.parameters.top_p,
-				top_k: model.parameters.top_k,
-				stop: model.parameters.stop,
-				repeat_penalty: model.parameters.repetition_penalty,
-				n_predict: model.parameters.max_new_tokens,
+				temperature: parameters.temperature,
+				top_p: parameters.top_p,
+				top_k: parameters.top_k,
+				stop: parameters.stop,
+				repeat_penalty: parameters.repetition_penalty,
+				n_predict: parameters.max_new_tokens,
 				cache_prompt: true,
 			}),
 		});
src/lib/server/endpoints/ollama/endpointOllama.ts
CHANGED
@@ -14,7 +14,7 @@ export const endpointOllamaParametersSchema = z.object({
 export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
 	const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
 
-	return async ({ messages, preprompt, continueMessage }) => {
+	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
 		const prompt = await buildPrompt({
 			messages,
 			continueMessage,
@@ -22,6 +22,8 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
 		model,
 	});
 
+	const parameters = { ...model.parameters, ...generateSettings };
+
 	const r = await fetch(`${url}/api/generate`, {
 		method: "POST",
 		headers: {
@@ -32,12 +34,12 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
 			model: ollamaName ?? model.name,
 			raw: true,
 			options: {
-				top_p: model.parameters.top_p,
-				top_k: model.parameters.top_k,
-				temperature: model.parameters.temperature,
-				repeat_penalty: model.parameters.repetition_penalty,
-				stop: model.parameters.stop,
-				num_predict: model.parameters.max_new_tokens,
+				top_p: parameters.top_p,
+				top_k: parameters.top_k,
+				temperature: parameters.temperature,
+				repeat_penalty: parameters.repetition_penalty,
+				stop: parameters.stop,
+				num_predict: parameters.max_new_tokens,
 			},
 		}),
 	});
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
@@ -38,7 +38,7 @@ export async function endpointOai(
 	});
 
 	if (completion === "completions") {
-		return async ({ messages, preprompt, continueMessage }) => {
+		return async ({ messages, preprompt, continueMessage, generateSettings }) => {
 			const prompt = await buildPrompt({
 				messages,
 				continueMessage,
@@ -46,21 +46,23 @@ export async function endpointOai(
 				model,
 			});
 
+			const parameters = { ...model.parameters, ...generateSettings };
+
 			return openAICompletionToTextGenerationStream(
 				await openai.completions.create({
 					model: model.id ?? model.name,
 					prompt,
 					stream: true,
-					max_tokens: model.parameters?.max_new_tokens,
-					stop: model.parameters?.stop,
-					temperature: model.parameters?.temperature,
-					top_p: model.parameters?.top_p,
-					frequency_penalty: model.parameters?.repetition_penalty,
+					max_tokens: parameters?.max_new_tokens,
+					stop: parameters?.stop,
+					temperature: parameters?.temperature,
+					top_p: parameters?.top_p,
+					frequency_penalty: parameters?.repetition_penalty,
 				})
 			);
 		};
 	} else if (completion === "chat_completions") {
-		return async ({ messages, preprompt }) => {
+		return async ({ messages, preprompt, generateSettings }) => {
 			let messagesOpenAI = messages.map((message) => ({
 				role: message.from,
 				content: message.content,
@@ -74,16 +76,18 @@ export async function endpointOai(
 				messagesOpenAI[0].content = preprompt ?? "";
 			}
 
+			const parameters = { ...model.parameters, ...generateSettings };
+
 			return openAIChatToTextGenerationStream(
 				await openai.chat.completions.create({
 					model: model.id ?? model.name,
 					messages: messagesOpenAI,
 					stream: true,
-					max_tokens: model.parameters?.max_new_tokens,
-					stop: model.parameters?.stop,
-					temperature: model.parameters?.temperature,
-					top_p: model.parameters?.top_p,
-					frequency_penalty: model.parameters?.repetition_penalty,
+					max_tokens: parameters?.max_new_tokens,
+					stop: parameters?.stop,
+					temperature: parameters?.temperature,
+					top_p: parameters?.top_p,
+					frequency_penalty: parameters?.repetition_penalty,
 				})
 			);
 		};
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
@@ -16,7 +16,7 @@ export const endpointTgiParametersSchema = z.object({
 export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
 	const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
 
-	return async ({ messages, preprompt, continueMessage }) => {
+	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
 		const prompt = await buildPrompt({
 			messages,
 			preprompt,
@@ -26,7 +26,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
 
 		return textGenerationStream(
 			{
-				parameters: { ...model.parameters, return_full_text: false },
+				parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
 				model: url,
 				inputs: prompt,
 				accessToken,
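Note the naming differences the endpoint diffs above absorb: the shared repetition_penalty becomes frequency_penalty for OpenAI-compatible servers and repeat_penalty for llama.cpp and Ollama, while max_new_tokens becomes max_tokens, n_predict, or num_predict (TGI and AWS pass the merged parameters through unchanged). Collected as a sketch for reference (assembled from the diffs, not a table exported by the codebase):

// Shared setting -> backend-specific field, per the endpoint diffs above.
const parameterAliases = {
	anthropic: { max_new_tokens: "max_tokens", stop: "stop_sequences" },
	llamacpp: { max_new_tokens: "n_predict", repetition_penalty: "repeat_penalty" },
	ollama: { max_new_tokens: "num_predict", repetition_penalty: "repeat_penalty" },
	openai: { max_new_tokens: "max_tokens", repetition_penalty: "frequency_penalty" },
} as const;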
src/lib/types/Assistant.ts
CHANGED
@@ -19,6 +19,12 @@ export interface Assistant extends Timestamps {
 		allowedDomains: string[];
 		allowedLinks: string[];
 	};
+	generateSettings?: {
+		temperature?: number;
+		top_p?: number;
+		repetition_penalty?: number;
+		top_k?: number;
+	};
 	dynamicPrompt?: boolean;
 	searchTokens: string[];
 }
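Every field of generateSettings is itself optional, so a stored assistant may carry any subset of the four controls. A hypothetical document value:

// Hypothetical: the author tuned two controls and left the rest to the model defaults.
const generateSettings: Assistant["generateSettings"] = {
	temperature: 0.4,
	top_k: 30,
};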
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -338,8 +338,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
 
 	// check if assistant has a rag
 	const assistant = await collections.assistants.findOne<
-		Pick<Assistant, "rag" | "dynamicPrompt">
-	>({ _id: conv.assistantId }, { projection: { rag: 1, dynamicPrompt: 1 } });
+		Pick<Assistant, "rag" | "dynamicPrompt" | "generateSettings">
+	>(
+		{ _id: conv.assistantId },
+		{ projection: { rag: 1, dynamicPrompt: 1, generateSettings: 1 } }
+	);
 
 	const assistantHasRAG =
 		ENABLE_ASSISTANTS_RAG === "true" &&
@@ -403,12 +406,15 @@ export async function POST({ request, locals, params, getClientAddress }) {
 
 	const previousText = messageToWriteTo.content;
 
+	let hasError = false;
+
 	try {
 		const endpoint = await model.getEndpoint();
 		for await (const output of await endpoint({
 			messages: processedMessages,
 			preprompt,
 			continueMessage: isContinue,
+			generateSettings: assistant?.generateSettings,
 		})) {
 			// if not generated_text is here it means the generation is not done
 			if (!output.generated_text) {
@@ -448,10 +454,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
 			}
 		}
 	} catch (e) {
+		hasError = true;
 		update({ type: "status", status: "error", message: (e as Error).message });
 	} finally {
 		// check if no output was generated
-		if (messageToWriteTo.content === previousText) {
+		if (!hasError && messageToWriteTo.content === previousText) {
 			update({
 				type: "status",
 				status: "error",
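The hasError flag fixes a misleading double report: when generation throws, messageToWriteTo.content still equals previousText, so the finally block used to emit a second "no output was generated" error on top of the real one. Condensed from the diff above:

let hasError = false;
try {
	// stream tokens into messageToWriteTo.content
} catch (e) {
	hasError = true;
	update({ type: "status", status: "error", message: (e as Error).message });
} finally {
	// an error-free run with unchanged content is the only true "no output" case
	if (!hasError && messageToWriteTo.content === previousText) {
		// report "no output was generated"
	}
}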
src/routes/settings/(nav)/assistants/[assistantId]/edit/+page.server.ts
CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
 	ragDomainList: z.preprocess(parseStringToList, z.string().array()),
 	ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
 	dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
+	temperature: z
+		.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
+		.transform((v) => (v === "" ? undefined : v)),
+	top_p: z
+		.union([z.literal(""), z.coerce.number().min(0.05).max(1)])
+		.transform((v) => (v === "" ? undefined : v)),
+
+	repetition_penalty: z
+		.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
+		.transform((v) => (v === "" ? undefined : v)),
+
+	top_k: z
+		.union([z.literal(""), z.coerce.number().min(5).max(100)])
+		.transform((v) => (v === "" ? undefined : v)),
 });
 
 const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -143,6 +157,12 @@
 				},
 				dynamicPrompt: parse.data.dynamicPrompt,
 				searchTokens: generateSearchTokens(parse.data.name),
+				generateSettings: {
+					temperature: parse.data.temperature,
+					top_p: parse.data.top_p,
+					repetition_penalty: parse.data.repetition_penalty,
+					top_k: parse.data.top_k,
+				},
 			},
 		}
 	);
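The sampling fields arrive as form-data strings, so each schema entry unions the empty string with a range-checked coerced number and then maps "" to undefined, letting a blank input mean "inherit the model default" instead of failing validation. A quick standalone check of that behavior (illustrative values):

import { z } from "zod";

// Same shape as the temperature field above.
const temperature = z
	.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
	.transform((v) => (v === "" ? undefined : v));

temperature.parse(""); // => undefined (field left blank)
temperature.parse("0.7"); // => 0.7 (string coerced to number)
temperature.parse("3"); // throws: above the max of 2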
src/routes/settings/(nav)/assistants/new/+page.server.ts
CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
 	ragDomainList: z.preprocess(parseStringToList, z.string().array()),
 	ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
 	dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
+	temperature: z
+		.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
+		.transform((v) => (v === "" ? undefined : v)),
+	top_p: z
+		.union([z.literal(""), z.coerce.number().min(0.05).max(1)])
+		.transform((v) => (v === "" ? undefined : v)),
+
+	repetition_penalty: z
+		.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
+		.transform((v) => (v === "" ? undefined : v)),
+
+	top_k: z
+		.union([z.literal(""), z.coerce.number().min(5).max(100)])
+		.transform((v) => (v === "" ? undefined : v)),
 });
 
 const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -125,6 +139,12 @@
 			},
 			dynamicPrompt: parse.data.dynamicPrompt,
 			searchTokens: generateSearchTokens(parse.data.name),
+			generateSettings: {
+				temperature: parse.data.temperature,
+				top_p: parse.data.top_p,
+				repetition_penalty: parse.data.repetition_penalty,
+				top_k: parse.data.top_k,
+			},
 		});
 
 		// add insertedId to user settings