Spaces:
Paused
Paused
:gem: [Feature] Support calling the HF API with an api_key via HTTP Bearer
Browse files- apis/chat_api.py +16 -2
- networks/message_streamer.py +8 -0
apis/chat_api.py
CHANGED
|
@@ -2,7 +2,8 @@ import argparse
|
|
| 2 |
import uvicorn
|
| 3 |
import sys
|
| 4 |
|
| 5 |
-
from fastapi import FastAPI
|
|
|
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
| 8 |
from utils.logger import logger
|
|
@@ -38,6 +39,16 @@ class ChatAPIApp:
|
|
| 38 |
]
|
| 39 |
return self.available_models
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
class ChatCompletionsPostItem(BaseModel):
|
| 42 |
model: str = Field(
|
| 43 |
default="mixtral-8x7b",
|
|
@@ -60,7 +71,9 @@ class ChatAPIApp:
|
|
| 60 |
description="(bool) Stream",
|
| 61 |
)
|
| 62 |
|
| 63 |
-
def chat_completions(
|
|
|
|
|
|
|
| 64 |
streamer = MessageStreamer(model=item.model)
|
| 65 |
composer = MessageComposer(model=item.model)
|
| 66 |
composer.merge(messages=item.messages)
|
|
@@ -70,6 +83,7 @@ class ChatAPIApp:
|
|
| 70 |
prompt=composer.merged_str,
|
| 71 |
temperature=item.temperature,
|
| 72 |
max_new_tokens=item.max_tokens,
|
|
|
|
| 73 |
)
|
| 74 |
if item.stream:
|
| 75 |
event_source_response = EventSourceResponse(
|
|
|
|
| 2 |
import uvicorn
|
| 3 |
import sys
|
| 4 |
|
| 5 |
+
from fastapi import FastAPI, Depends
|
| 6 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
| 9 |
from utils.logger import logger
|
|
|
|
| 39 |
]
|
| 40 |
return self.available_models
|
| 41 |
|
| 42 |
+
def extract_api_key(
    credentials: HTTPAuthorizationCredentials = Depends(
        HTTPBearer(auto_error=False)
    ),
):
    """Return the Bearer token from the Authorization header, or None when absent.

    `auto_error=False` makes HTTPBearer hand us None instead of raising a 403
    when the client sends no Authorization header, so anonymous calls still work.
    """
    return credentials.credentials if credentials else None
|
| 51 |
+
|
| 52 |
class ChatCompletionsPostItem(BaseModel):
|
| 53 |
model: str = Field(
|
| 54 |
default="mixtral-8x7b",
|
|
|
|
| 71 |
description="(bool) Stream",
|
| 72 |
)
|
| 73 |
|
| 74 |
+
def chat_completions(
|
| 75 |
+
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
| 76 |
+
):
|
| 77 |
streamer = MessageStreamer(model=item.model)
|
| 78 |
composer = MessageComposer(model=item.model)
|
| 79 |
composer.merge(messages=item.messages)
|
|
|
|
| 83 |
prompt=composer.merged_str,
|
| 84 |
temperature=item.temperature,
|
| 85 |
max_new_tokens=item.max_tokens,
|
| 86 |
+
api_key=api_key,
|
| 87 |
)
|
| 88 |
if item.stream:
|
| 89 |
event_source_response = EventSourceResponse(
|
networks/message_streamer.py
CHANGED
|
@@ -36,6 +36,7 @@ class MessageStreamer:
|
|
| 36 |
prompt: str = None,
|
| 37 |
temperature: float = 0.01,
|
| 38 |
max_new_tokens: int = 8192,
|
|
|
|
| 39 |
):
|
| 40 |
# https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
|
| 41 |
# curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_token":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
|
|
@@ -45,6 +46,13 @@ class MessageStreamer:
|
|
| 45 |
self.request_headers = {
|
| 46 |
"Content-Type": "application/json",
|
| 47 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# References:
|
| 49 |
# huggingface_hub/inference/_client.py:
|
| 50 |
# class InferenceClient > def text_generation()
|
|
|
|
| 36 |
prompt: str = None,
|
| 37 |
temperature: float = 0.01,
|
| 38 |
max_new_tokens: int = 8192,
|
| 39 |
+
api_key: str = None,
|
| 40 |
):
|
| 41 |
# https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
|
| 42 |
# curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_token":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
|
|
|
|
| 46 |
self.request_headers = {
|
| 47 |
"Content-Type": "application/json",
|
| 48 |
}
|
| 49 |
+
|
| 50 |
+
if api_key:
|
| 51 |
+
logger.note(
|
| 52 |
+
f"Using API Key: {api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"
|
| 53 |
+
)
|
| 54 |
+
self.request_headers["Authorization"] = f"Bearer {api_key}"
|
| 55 |
+
|
| 56 |
# References:
|
| 57 |
# huggingface_hub/inference/_client.py:
|
| 58 |
# class InferenceClient > def text_generation()
|