File size: 5,677 Bytes
97c8e77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
from http import HTTPStatus
from pprint import pformat
from typing import Dict, Iterator, List, Optional
import dashscope
from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.llm.schema import ASSISTANT, Message
from qwen_agent.log import logger
@register_llm('qwen_dashscope')
class QwenChatAtDS(BaseFnCallModel):
    """Chat model backed by ``dashscope.Generation``, registered as 'qwen_dashscope'."""

    def __init__(self, cfg: Optional[Dict] = None):
        """Initialise the base model, default the model name, and configure dashscope.

        Args:
            cfg: Optional configuration dict. It is forwarded to the base class and
                to ``initialize_dashscope``.
        """
        super().__init__(cfg)
        # Fall back to 'qwen-max' when the base class left no model name.
        self.model = self.model or 'qwen-max'
        initialize_dashscope(cfg)

    @staticmethod
    def _dump_messages_for_dashscope(messages: List[Message]) -> List[dict]:
        """Convert messages to plain dicts for the request and log them at debug level.

        If the last message has the assistant role, ``'partial': True`` is set on it.
        NOTE(review): presumably this asks the service to continue that assistant
        text rather than start a new turn -- confirm against the DashScope docs.

        An empty ``messages`` list is passed through unchanged instead of raising
        ``IndexError`` on the ``[-1]`` lookup.
        """
        dumped = [msg.model_dump() for msg in messages]
        if dumped and dumped[-1]['role'] == ASSISTANT:
            dumped[-1]['partial'] = True
        logger.debug(f'LLM Input:\n{pformat(dumped, indent=2)}')
        return dumped

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        """Call the model with ``stream=True`` and return an iterator of responses.

        Args:
            messages: Conversation to send.
            delta_stream: When true, each yielded item carries only the content of
                the current chunk; otherwise each item carries the text
                accumulated so far.
            generate_cfg: Extra keyword arguments forwarded to
                ``dashscope.Generation.call``.

        Returns:
            An iterator yielding one-element lists of assistant ``Message`` objects.

        Raises:
            ModelServiceError: While iterating, if a chunk's status is not OK.
        """
        request_messages = self._dump_messages_for_dashscope(messages)
        response = dashscope.Generation.call(
            self.model,
            messages=request_messages,  # noqa
            result_format='message',
            stream=True,
            **generate_cfg)
        if delta_stream:
            return self._delta_stream_output(response)
        return self._full_stream_output(response)

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        """Call the model with ``stream=False`` and return the single reply.

        Args:
            messages: Conversation to send.
            generate_cfg: Extra keyword arguments forwarded to
                ``dashscope.Generation.call``.

        Returns:
            A one-element list holding the assistant ``Message``.

        Raises:
            ModelServiceError: If the response status is not OK.
        """
        request_messages = self._dump_messages_for_dashscope(messages)
        response = dashscope.Generation.call(
            self.model,
            messages=request_messages,  # noqa
            result_format='message',
            stream=False,
            **generate_cfg)
        if response.status_code != HTTPStatus.OK:
            raise ModelServiceError(code=response.code,
                                    message=response.message,
                                    extra={'model_service_info': response})
        reply = response.output.choices[0].message
        return [
            Message(role=ASSISTANT,
                    content=reply.content,
                    # The field may be absent, so read it with a default.
                    reasoning_content=reply.get('reasoning_content', ''),
                    extra={'model_service_info': response})
        ]

    def _continue_assistant_response(
        self,
        messages: List[Message],
        generate_cfg: dict,
        stream: bool,
    ) -> Iterator[List[Message]]:
        """Delegate to ``self._chat`` with ``delta_stream=False``."""
        return self._chat(messages, stream=stream, delta_stream=False, generate_cfg=generate_cfg)

    @staticmethod
    def _delta_stream_output(response) -> Iterator[List[Message]]:
        """Yield each streamed chunk as its own assistant message.

        Raises:
            ModelServiceError: If a chunk's status is not OK.
        """
        for chunk in response:
            if chunk.status_code != HTTPStatus.OK:
                raise ModelServiceError(code=chunk.code, message=chunk.message, extra={'model_service_info': chunk})
            piece = chunk.output.choices[0].message
            yield [
                Message(role=ASSISTANT,
                        content=piece.content,
                        # Read with a default, as the non-stream and full-stream
                        # paths do, so chunks without this field do not fail.
                        reasoning_content=piece.get('reasoning_content', ''),
                        extra={'model_service_info': chunk})
            ]

    @staticmethod
    def _full_stream_output(response) -> Iterator[List[Message]]:
        """Yield, for every chunk, a message holding all text accumulated so far.

        Chunk contents are concatenated, so this assumes each chunk carries only
        new text -- TODO confirm how the caller configures incremental output.

        Raises:
            ModelServiceError: If a chunk's status is not OK.
        """
        full_content = ''
        full_reasoning_content = ''
        for chunk in response:
            if chunk.status_code != HTTPStatus.OK:
                raise ModelServiceError(code=chunk.code, message=chunk.message, extra={'model_service_info': chunk})
            piece = chunk.output.choices[0].message
            reasoning = piece.get('reasoning_content', '')
            if reasoning:
                full_reasoning_content += reasoning
            if piece.content:
                full_content += piece.content
            yield [
                Message(role=ASSISTANT,
                        content=full_content,
                        reasoning_content=full_reasoning_content,
                        extra={'model_service_info': chunk})
            ]
def initialize_dashscope(cfg: Optional[Dict] = None) -> None:
    """Apply the API key and endpoint URLs to the ``dashscope`` module.

    Each setting is read from ``cfg`` first ('api_key', 'base_http_api_url',
    'base_websocket_api_url'). A missing or falsy value falls back to the
    environment variables DASHSCOPE_API_KEY, DASHSCOPE_HTTP_URL and
    DASHSCOPE_WEBSOCKET_URL respectively.

    Args:
        cfg: Optional configuration dict; ``None`` is treated as empty.
    """
    settings = cfg or {}

    key = settings.get('api_key', '') or os.getenv('DASHSCOPE_API_KEY', 'EMPTY')
    http_url = settings.get('base_http_api_url', None) or os.getenv('DASHSCOPE_HTTP_URL', None)
    ws_url = settings.get('base_websocket_api_url', None) or os.getenv('DASHSCOPE_WEBSOCKET_URL', None)

    key = key.strip()
    placeholders = ('', 'EMPTY')  # values that mean "no real key was supplied"

    if key not in placeholders:
        # A usable key was supplied; only overwrite when it differs from the current one.
        if key != dashscope.api_key:
            logger.info('Setting the dashscope api_key.')
            dashscope.api_key = key
    elif dashscope.api_key is None or dashscope.api_key in placeholders:
        logger.warning('No valid dashscope api_key found in cfg, environment variable `DASHSCOPE_API_KEY` or dashscope.api_key, the model call may raise errors.')
    else:
        logger.info('No dashscope api_key found in cfg, using the dashscope.api_key that has already been set.')

    # Override the endpoints only when a URL was actually found.
    for attr_name, url in (
        ('base_http_api_url', http_url),
        ('base_websocket_api_url', ws_url),
    ):
        if url is not None:
            setattr(dashscope, attr_name, url.strip())
|