// image/src/services/imageGenerator.js
// Last commit: bc4b939 "image: Refac B/F." (hadadrjt)
//
// SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
// SPDX-License-Identifier: Apache-2.0
//
import axios from 'axios';
import config from '../../config.js';
import { generateId } from '../utils/idGenerator.js';
import {
getStorage,
setStorage,
getActiveGeneration,
setActiveGeneration,
deleteActiveGeneration
} from './storageManager.js';
import {
sendToSession
} from './websocketManager.js';
/**
 * Records a new progress value for a session and pushes it to the
 * session's client as a `progressUpdate` message.
 * Sessions without stored data are ignored: nothing is written or sent.
 *
 * @param sessionId - Key of the session in storage.
 * @param progress - Progress value to store and broadcast.
 */
const updateProgress = (sessionId, progress) => {
  const sessionData = getStorage(sessionId);
  if (!sessionData) return;

  sessionData.progress = progress;
  setStorage(sessionId, sessionData);

  const message = { type: 'progressUpdate', progress };
  sendToSession(sessionId, message);
};
/**
 * Starts a timer that advances a session's displayed progress.
 * Each step is a random amount in [0, 8), capped at
 * `config.generation.maxProgress`. The timer fires every
 * `config.generation.progressInterval` ms.
 *
 * The step is random and does not come from the API, so this only
 * animates progress while a request is pending. A tick does nothing
 * unless the session exists, is flagged `isGenerating`, and is still
 * below the cap.
 *
 * @param sessionId - Session whose progress is advanced.
 * @returns The interval handle; the caller must `clearInterval` it.
 */
const createProgressUpdater = (sessionId) => {
  const advance = () => {
    const { maxProgress } = config.generation;
    const current = getStorage(sessionId);
    if (!current || !current.isGenerating || !(current.progress < maxProgress)) {
      return;
    }
    const step = Math.random() * 8;
    const next = Math.min(maxProgress, current.progress + step);
    updateProgress(sessionId, next);
  };

  return setInterval(advance, config.generation.progressInterval);
};
/**
 * Builds the record stored for one generated image.
 *
 * @param base64Data - Image payload as returned by the API (`b64_json`).
 * @param prompt - Prompt used to generate the image.
 * @param model - Model that produced the image.
 * @param size - Requested image size.
 * @returns An object `{ id, base64, prompt, model, size }`; `id` comes from `generateId()`.
 */
const createImageObject = (base64Data, prompt, model, size) => {
  const image = {
    id: generateId(),
    base64: base64Data,
    prompt,
    model,
    size,
  };
  return image;
};
/**
 * POSTs a single-image generation request to `config.api.baseUrl`.
 * The image is requested back as base64 (`response_format: 'b64_json'`, `n: 1`).
 *
 * Uses the API key as a Bearer token and `config.api.timeout` as the
 * request timeout. `config.limits.maxContentLength` caps both the request
 * body and the response size.
 *
 * @param prompt - Text prompt.
 * @param model - Model name.
 * @param size - Requested image size.
 * @param signal - AbortSignal passed to axios so the request can be cancelled.
 * @returns {Promise<object>} The axios response object.
 */
const callImageApi = async (prompt, model, size, signal) => {
  const { baseUrl, key, timeout } = config.api;
  const sizeLimit = config.limits.maxContentLength;

  const payload = { model, prompt, size, response_format: 'b64_json', n: 1 };
  const requestOptions = {
    headers: {
      'Authorization': `Bearer ${key}`,
      'Content-Type': 'application/json',
    },
    signal,
    timeout,
    maxBodyLength: sizeLimit,
    maxContentLength: sizeLimit,
  };

  const response = await axios.post(baseUrl, payload, requestOptions);
  return response;
};
/**
 * Starts an image generation for a session in the background.
 *
 * Steps:
 * 1. Registers a fresh AbortController as the session's active generation.
 * 2. Starts the progress ticker.
 * 3. After `config.generation.startDelay` ms, calls the image API.
 *
 * Outcomes:
 * - Success: the returned image (if any) is prepended to the session's
 *   `images`, and a `generationComplete` message with the full list is sent.
 * - Failure (other than a cancellation): a `generationError` message
 *   naming the model is sent.
 * - Either way, `isGenerating` is cleared and `progress` is reset to 0.
 *
 * A run only touches session state while its controller is still the
 * session's registered generation. If it was cancelled (cancelGeneration
 * removes the entry and resets state itself), or a newer call replaced
 * it, the run stops quietly. It therefore cannot reset the flags/progress
 * of, or evict the controller of, a generation started after it.
 *
 * The returned promise resolves immediately; it does not report the outcome.
 *
 * @param sessionId - Session whose storage and socket are updated.
 * @param prompt - Prompt forwarded to the API and stored with the image.
 * @param model - Model forwarded to the API; also used in the error text.
 * @param size - Image size forwarded to the API.
 * @returns {Promise<void>}
 */
export const generateImage = async (sessionId, prompt, model, size) => {
  const controller = new AbortController();
  setActiveGeneration(sessionId, controller);

  const progressInterval = createProgressUpdater(sessionId);
  const stopProgress = () => clearInterval(progressInterval);
  // Stop the ticker as soon as this run is cancelled. Waiting for the
  // delayed callback to settle would let a stale ticker keep advancing the
  // progress of a generation started right afterwards.
  controller.signal.addEventListener('abort', stopProgress, { once: true });

  // True while this run is still the session's registered generation.
  // NOTE(review): relies on storageManager returning the same controller
  // instance it was given; cancelGeneration depends on that too, since it
  // calls abort() on the returned value.
  const isCurrent = () => getActiveGeneration(sessionId) === controller;

  setTimeout(async () => {
    try {
      // Cancelled during the start delay: don't send the request at all.
      if (controller.signal.aborted) return;

      const response = await callImageApi(
        prompt,
        model,
        size,
        controller.signal
      );
      if (!isCurrent()) return;

      const data = getStorage(sessionId);
      if (!data) return;

      updateProgress(sessionId, config.generation.maxProgress);
      if (response.data?.data?.length > 0) {
        const base64 = response.data.data[0].b64_json;
        data.images.unshift(createImageObject(base64, prompt, model, size));
      }
      data.isGenerating = false;
      data.progress = 0;
      setStorage(sessionId, data);
      sendToSession(sessionId, {
        type: 'generationComplete',
        images: data.images
      });
    } catch (error) {
      // A cancelled or superseded run must not overwrite state that now
      // belongs to another generation (cancelGeneration already reset it).
      if (!isCurrent()) return;

      const data = getStorage(sessionId);
      if (!data) return;

      const wasCanceled =
        error?.name === 'CanceledError' || error?.code === 'ERR_CANCELED';
      if (!wasCanceled) {
        data.error =
          `The request to the ${model} model was ` +
          `unsuccessful, possibly due to high ` +
          `server load. Please try again later.`;
        sendToSession(sessionId, {
          type: 'generationError',
          error: data.error
        });
      }
      data.isGenerating = false;
      data.progress = 0;
      setStorage(sessionId, data);
    } finally {
      stopProgress();
      // Release the slot only if it is still ours; a newer generation may
      // have registered its own controller in the meantime.
      if (isCurrent()) deleteActiveGeneration(sessionId);
    }
  }, config.generation.startDelay);
};
/**
 * Cancels the session's active generation, if there is one.
 *
 * Aborts the registered controller and removes it from the active
 * generations. If session data exists, it also clears `isGenerating` and
 * resets `progress` to 0. With no active generation, nothing is touched.
 *
 * @param sessionId - Session whose generation should be cancelled.
 */
export const cancelGeneration = (sessionId) => {
  const activeController = getActiveGeneration(sessionId);
  if (!activeController) return;

  activeController.abort();
  deleteActiveGeneration(sessionId);

  const sessionData = getStorage(sessionId);
  if (!sessionData) return;

  Object.assign(sessionData, { isGenerating: false, progress: 0 });
  setStorage(sessionId, sessionData);
};