Spaces:
Runtime error
Runtime error
File size: 1,578 Bytes
2ea8b71 ecd2209 2ea8b71 000213d 2ea8b71 e0cbbfe 2ea8b71 ecd2209 e0cbbfe 2ea8b71 000213d e0cbbfe 2ea8b71 e0cbbfe 2ea8b71 e0cbbfe 2ea8b71 e0cbbfe 2ea8b71 e0cbbfe 2ea8b71 000213d e0cbbfe 2ea8b71 000213d 2ea8b71 e0cbbfe 2ea8b71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import json
from typing import List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import AIMessage, OutputParserException
from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt
async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> Optional[List[str]]:
    """Ask the LLM which files the given code will modify.

    The code is rendered into ``determine_modifications_prompt`` and sent to
    the model with ``determine_modifications_function`` attached as a callable
    function. The JSON ``arguments`` of the model's function call are parsed
    and their ``"modifications"`` entry is returned.

    Args:
        code: Source code to analyse.
        llm: Language model used to answer the prompt. It must support
            ``apredict_messages`` with a ``functions`` keyword.
        retry: Maximum number of LLM calls to make. A value below 1 means no
            call is made and ``None`` is returned immediately.

    Returns:
        The ``"modifications"`` value from the model's function-call
        arguments, or ``None`` if no attempt produced a usable function call.

    Raises:
        OutputParserException: If the LLM returns something other than an
            ``AIMessage``.
    """
    if retry < 1:
        return None

    # The prompt does not change between attempts, so build it only once.
    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()

    for _ in range(retry):
        message = await llm.apredict_messages(
            messages, functions=[determine_modifications_function]
        )
        if not isinstance(message, AIMessage):
            raise OutputParserException("Expected an AIMessage")

        function_call = message.additional_kwargs.get("function_call")
        if function_call is None:
            # The model answered in plain text instead of calling the function.
            continue

        try:
            arguments = json.loads(function_call["arguments"])
            return arguments["modifications"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # Malformed or incomplete function-call payload (bad JSON, missing
            # "arguments"/"modifications"). Spend another attempt instead of
            # crashing, the same as when the call is missing entirely.
            continue

    return None
async def test():
    """Manual smoke test: print the modifications the LLM detects for a sample script.

    Requires network access and OpenAI credentials, which ``ChatOpenAI``
    presumably reads from the environment (TODO confirm).
    """
    # "gpt-3.5" is not a served OpenAI model ID. Use "gpt-3.5-turbo", which
    # supports the function calling that get_file_modifications relies on.
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # type: ignore
    code = """
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""
    print(await get_file_modifications(code, llm))
if __name__ == "__main__":
    # Script entry point: load variables from a local .env file (presumably
    # the OpenAI credentials the smoke test needs), then run the async test.
    from dotenv import load_dotenv
    import asyncio

    load_dotenv()
    asyncio.run(test())
|