Spaces:
Runtime error
Runtime error
| import json | |
| from typing import List, Optional | |
| from langchain.base_language import BaseLanguageModel | |
| from langchain.chat_models.openai import ChatOpenAI | |
| from langchain.schema import AIMessage, OutputParserException | |
| from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt | |
async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> Optional[List[str]]:
    """Ask the LLM, via function calling, which files ``code`` modifies.

    The model is prompted with ``determine_modifications_prompt`` and offered
    ``determine_modifications_function``. The ``"modifications"`` entry of the
    resulting function-call arguments is returned.

    Args:
        code: Source code to analyse.
        llm: Language model used for the prediction. Its
            ``apredict_messages`` must accept a ``functions`` argument.
        retry: Maximum number of model calls. Values below 1 make no call.

    Returns:
        The ``"modifications"`` value from the model's function call, or
        ``None`` if no attempt produced a usable function call.

    Raises:
        OutputParserException: If the model returns something other than an
            ``AIMessage``.
    """
    if retry < 1:
        return None

    # The prompt does not change between attempts, so build it once.
    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()

    for _ in range(retry):
        message = await llm.apredict_messages(
            messages, functions=[determine_modifications_function]
        )
        if not isinstance(message, AIMessage):
            raise OutputParserException("Expected an AIMessage")

        function_call = message.additional_kwargs.get("function_call", None)
        if function_call is None:
            # The model answered in plain text instead of calling the function.
            continue

        try:
            arguments = json.loads(function_call["arguments"])
            return arguments["modifications"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # The model's arguments were not valid JSON, lacked the expected key,
            # or had an unexpected shape. Spend another attempt instead of crashing.
            continue

    return None
async def test():
    """Manual smoke test: print the modifications reported for a sample script.

    Requires access to the OpenAI API through ``ChatOpenAI``.
    """
    # "gpt-3.5" is not a valid OpenAI model ID; "gpt-3.5-turbo" is, and it
    # supports the function calling used by get_file_modifications.
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # type: ignore
    code = """
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""
    print(await get_file_modifications(code, llm))
if __name__ == "__main__":
    # Script entry point: load variables from a .env file (presumably the
    # OpenAI credentials ChatOpenAI needs; confirm), then run the smoke test.
    from dotenv import load_dotenv
    import asyncio

    load_dotenv()
    asyncio.run(test())