# File: codeinterpreter/codeinterpreterapi/chains/modifications_check.py
# Shroominic — commit ecd2209: "📠 fix typing for python3.9"
import json
from typing import List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import AIMessage, OutputParserException
from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt
async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> Optional[List[str]]:
    """Ask the LLM which files the given code would modify.

    The model is prompted with ``determine_modifications_prompt`` and offered
    ``determine_modifications_function`` as a callable function; the
    ``modifications`` field of the resulting function-call arguments is
    returned.

    Args:
        code: Source code to inspect.
        llm: Language model used to answer the prompt.
        retry: Maximum number of model calls to attempt. Values below 1
            return ``None`` without calling the model.

    Returns:
        The ``modifications`` value from the model's function call, or
        ``None`` if no attempt produced a usable function call.

    Raises:
        OutputParserException: If the model returns something other than
            an ``AIMessage``.
    """
    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()
    for _ in range(retry):
        message = await llm.apredict_messages(
            messages, functions=[determine_modifications_function]
        )
        if not isinstance(message, AIMessage):
            raise OutputParserException("Expected an AIMessage")
        function_call = message.additional_kwargs.get("function_call")
        if function_call is None:
            # The model answered in plain text instead of calling the function.
            continue
        try:
            # The arguments are model-generated JSON and may be malformed or
            # lack the expected key; treat that like a missing call and retry.
            arguments = json.loads(function_call["arguments"])
            return arguments["modifications"]
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
    return None
async def test():
    """Manual smoke test: run ``get_file_modifications`` on a plotting snippet.

    Calls a real OpenAI chat model (so presumably needs network access and
    API credentials) and prints the returned list of strings, or ``None``.
    """
    # "gpt-3.5" is not a valid OpenAI model id; function calling is
    # supported by gpt-3.5-turbo.
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # type: ignore
    code = """
import matplotlib.pyplot as plt
x = list(range(1, 11))
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
plt.plot(x, y, marker='o')
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Data Plot')
plt.show()
"""
    print(await get_file_modifications(code, llm))
if __name__ == "__main__":
    from dotenv import load_dotenv

    import asyncio

    # Populate the environment from a local .env file first (presumably
    # holding the API key the OpenAI client needs), then drive the async
    # smoke test to completion on a fresh event loop.
    load_dotenv()
    smoke_test = test()
    asyncio.run(smoke_test)