Spaces:
Sleeping
Sleeping
| import requests | |
| import json | |
| import sys | |
| from pathlib import Path | |
| sys.path.append(str(Path(__file__).parent.parent.parent)) | |
| from configs import BING_SUBSCRIPTION_KEY | |
| from server.utils import api_address | |
| from pprint import pprint | |
| api_base_url = api_address() | |
| def dump_input(d, title): | |
| print("\n") | |
| print("=" * 30 + title + " input " + "="*30) | |
| pprint(d) | |
| def dump_output(r, title): | |
| print("\n") | |
| print("=" * 30 + title + " output" + "="*30) | |
| for line in r.iter_content(None, decode_unicode=True): | |
| print(line, end="", flush=True) | |
| headers = { | |
| 'accept': 'application/json', | |
| 'Content-Type': 'application/json', | |
| } | |
| data = { | |
| "query": "请用100字左右的文字介绍自己", | |
| "history": [ | |
| { | |
| "role": "user", | |
| "content": "你好" | |
| }, | |
| { | |
| "role": "assistant", | |
| "content": "你好,我是人工智能大模型" | |
| } | |
| ], | |
| "stream": True, | |
| "temperature": 0.7, | |
| } | |
| def test_chat_chat(api="/chat/chat"): | |
| url = f"{api_base_url}{api}" | |
| dump_input(data, api) | |
| response = requests.post(url, headers=headers, json=data, stream=True) | |
| dump_output(response, api) | |
| assert response.status_code == 200 | |
| def test_knowledge_chat(api="/chat/knowledge_base_chat"): | |
| url = f"{api_base_url}{api}" | |
| data = { | |
| "query": "如何提问以获得高质量答案", | |
| "knowledge_base_name": "samples", | |
| "history": [ | |
| { | |
| "role": "user", | |
| "content": "你好" | |
| }, | |
| { | |
| "role": "assistant", | |
| "content": "你好,我是 ChatGLM" | |
| } | |
| ], | |
| "stream": True | |
| } | |
| dump_input(data, api) | |
| response = requests.post(url, headers=headers, json=data, stream=True) | |
| print("\n") | |
| print("=" * 30 + api + " output" + "="*30) | |
| for line in response.iter_content(None, decode_unicode=True): | |
| data = json.loads(line[6:]) | |
| if "answer" in data: | |
| print(data["answer"], end="", flush=True) | |
| pprint(data) | |
| assert "docs" in data and len(data["docs"]) > 0 | |
| assert response.status_code == 200 | |
| def test_search_engine_chat(api="/chat/search_engine_chat"): | |
| global data | |
| data["query"] = "室温超导最新进展是什么样?" | |
| url = f"{api_base_url}{api}" | |
| for se in ["bing", "duckduckgo"]: | |
| data["search_engine_name"] = se | |
| dump_input(data, api + f" by {se}") | |
| response = requests.post(url, json=data, stream=True) | |
| if se == "bing" and not BING_SUBSCRIPTION_KEY: | |
| data = response.json() | |
| assert data["code"] == 404 | |
| assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`" | |
| print("\n") | |
| print("=" * 30 + api + f" by {se} output" + "="*30) | |
| for line in response.iter_content(None, decode_unicode=True): | |
| data = json.loads(line[6:]) | |
| if "answer" in data: | |
| print(data["answer"], end="", flush=True) | |
| assert "docs" in data and len(data["docs"]) > 0 | |
| pprint(data["docs"]) | |
| assert response.status_code == 200 | |