File size: 1,578 Bytes
2ea8b71
ecd2209
2ea8b71
 
 
000213d
 
 
2ea8b71
 
 
e0cbbfe
2ea8b71
 
ecd2209
e0cbbfe
2ea8b71
000213d
 
e0cbbfe
2ea8b71
 
e0cbbfe
2ea8b71
e0cbbfe
2ea8b71
e0cbbfe
 
2ea8b71
 
e0cbbfe
2ea8b71
 
000213d
e0cbbfe
2ea8b71
 
 
 
 
 
 
 
 
 
 
 
 
 
000213d
2ea8b71
 
 
 
 
 
e0cbbfe
2ea8b71
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
from typing import List, Optional

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import AIMessage, OutputParserException

from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt


async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> Optional[List[str]]:
    """Ask the LLM which files the given code modifies.

    The prompt comes from ``determine_modifications_prompt``. The model is
    offered ``determine_modifications_function`` as a callable function. The
    JSON ``arguments`` of the model's function call are parsed, and their
    ``"modifications"`` entry is returned.

    Args:
        code: Source code to analyse.
        llm: Language model used to produce the function call.
        retry: Maximum number of LLM calls to make. Values below 1 make no
            call and return ``None``.

    Returns:
        The ``"modifications"`` value from the first usable function call, or
        ``None`` if no attempt produced one (no function call, or arguments
        that are not valid JSON / lack a ``"modifications"`` key).

    Raises:
        OutputParserException: If the model returns something other than an
            ``AIMessage``.
    """
    if retry < 1:
        return None

    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()

    # Iterate instead of recursing so retries don't grow the call stack.
    for _ in range(retry):
        message = await llm.apredict_messages(messages, functions=[determine_modifications_function])

        if not isinstance(message, AIMessage):
            raise OutputParserException("Expected an AIMessage")

        function_call = message.additional_kwargs.get("function_call")
        if function_call is None:
            # The model answered without calling the function; ask again.
            continue

        try:
            arguments = json.loads(function_call["arguments"])
            modifications = arguments["modifications"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # Malformed or incomplete arguments are as unusable as a missing
            # function call, so they consume an attempt instead of crashing.
            continue
        return modifications

    return None


async def test():
    """Manual smoke test: ask the model about a small matplotlib script and print the answer."""
    import textwrap

    # "gpt-3.5" is not an OpenAI model id; the function-calling capable chat
    # model is "gpt-3.5-turbo".
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # type: ignore

    # Dedent so the sample is valid top-level Python rather than code that
    # carries this function's indentation.
    code = textwrap.dedent(
        """
        import matplotlib.pyplot as plt

        x = list(range(1, 11))
        y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]

        plt.plot(x, y, marker='o')
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.title('Data Plot')

        plt.show()
        """
    )

    print(await get_file_modifications(code, llm))


if __name__ == "__main__":
    # Script entry point: load environment settings, then drive the async smoke test.
    import asyncio
    from dotenv import load_dotenv

    # Reads a local .env into the process environment; presumably this is
    # where the OpenAI credentials come from — confirm against deployment docs.
    load_dotenv()

    smoke_test = test()
    asyncio.run(smoke_test)