forked from CopilotKit/CopilotKit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
91 lines (81 loc) · 2.82 KB
/
Copy pathagent.py
File metadata and controls
91 lines (81 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
This is the main entry point for the AI.
It defines the workflow graph and the entry point for the agent.
"""
# pylint: disable=line-too-long, unused-import
from typing import cast, TypedDict, Any
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, ToolMessage, AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState
from copilotkit.langchain import copilotkit_customize_config
class Translations(TypedDict):
    """Contains the translations in three different languages (Spanish, French and German)."""
    # NOTE: this class is bound as a tool in translate_node, so its name and
    # docstring are sent to the model as the tool's name and description.
    translation_es: str  # Spanish translation
    translation_fr: str  # French translation
    translation_de: str  # German translation
class AgentState(MessagesState):
    """Contains the state of the agent.

    Extends ``MessagesState`` (which supplies the ``messages`` list that
    translate_node reads and appends to) with translation-specific keys.
    """
    # Latest translations; translate_node sets this from the arguments of the
    # model's Translations tool call.
    translations: Translations
    # Text the user is currently translating. translate_node reads it via
    # state.get("input"), so it may be absent or empty.
    input: str
def _build_system_prompt(current_input: str) -> str:
    """Build the system prompt for the translation model.

    Args:
        current_input: Text the user is currently translating, or an empty
            string when there is none.

    Returns:
        The system prompt. It mentions ``current_input`` only when that
        value is non-empty.
    """
    prompt = (
        "You are a helpful assistant that translates text to different languages\n"
        "(Spanish, French and German).\n"
        "Don't ask for confirmation before translating.\n"
    )
    if current_input:
        prompt += (
            'The user is currently working on translating this text: "'
            + current_input
            + '"\n'
        )
    return prompt


async def translate_node(state: AgentState, config: RunnableConfig):
    """Run one turn of the translation chatbot.

    Calls the model with the Translations schema bound as a tool. If the model
    calls that tool, its arguments become the new ``translations`` state, and
    a ToolMessage answering the call is appended to the history.

    Args:
        state: Current agent state; reads ``messages`` and the optional
            ``input``.
        config: Runnable config, customized for CopilotKit before the model
            call.

    Returns:
        A partial state update containing the new ``messages``, plus
        ``translations`` when the model called the tool.
    """
    # The tool name must match the bound tool, which is named after the
    # Translations TypedDict. A mismatched name means intermediate state is
    # never streamed. (Messages are emitted by default, so emit_messages is
    # not set here.)
    tool_name = Translations.__name__
    config = copilotkit_customize_config(
        config,
        emit_intermediate_state=[
            {
                "state_key": "translations",
                "tool": tool_name,
            }
        ],
    )

    messages = state["messages"]
    # Let the model answer freely when replying to a fresh user message.
    # Otherwise (no messages yet, or the last one is not from the user),
    # force it to call the Translations tool.
    last_is_human = bool(messages) and isinstance(messages[-1], HumanMessage)
    model = ChatOpenAI(model="gpt-4o").bind_tools(
        [Translations],
        parallel_tool_calls=False,
        tool_choice=None if last_is_human else tool_name,
    )

    system_prompt = _build_system_prompt(state.get("input") or "")
    response = await model.ainvoke(
        [SystemMessage(content=system_prompt), *messages],
        config,
    )

    if not getattr(response, "tool_calls", None):
        return {"messages": [response]}

    # parallel_tool_calls=False above, so only the first tool call is handled.
    tool_call = cast(AIMessage, response).tool_calls[0]
    return {
        "messages": [
            response,
            # Answer the tool call so the history stays well-formed for the
            # next model call.
            ToolMessage(content="Translated!", tool_call_id=tool_call["id"]),
        ],
        "translations": tool_call["args"],
    }
# Checkpointer the compiled graph uses to store state between invocations.
memory = MemorySaver()

# Single-node graph: every run starts at translate_node and finishes right
# after it.
workflow = StateGraph(AgentState)
workflow.add_node("translate_node", cast(Any, translate_node))
workflow.set_entry_point("translate_node")
workflow.set_finish_point("translate_node")

graph = workflow.compile(checkpointer=memory)