forked from CopilotKit/CopilotKit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsubagents_agent.py
More file actions
492 lines (416 loc) · 16.9 KB
/
Copy pathsubagents_agent.py
File metadata and controls
492 lines (416 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
"""MS Agent Framework agent backing the Sub-Agents demo.
Mirrors langgraph-python/src/agents/subagents.py and
google-adk/src/agents/subagents_agent.py:
A top-level supervisor LLM orchestrates three specialized sub-agents
exposed as tools:
- `research_agent` — gathers facts
- `writing_agent` — drafts prose
- `critique_agent` — reviews drafts
Each sub-agent is a real `agent_framework.Agent` with its own system
prompt. Each delegation appends an entry to the `delegations` slot in
AG-UI shared state via `state_update(...)`, so the UI can render a
live delegation log via `useAgent`.
Subagent invocation contract: each delegation tool returns
`state_update(...)` containing the FULL updated `delegations` list. We
read the prior list out of a per-request `ContextVar` populated by an
`agent_middleware` that captures the AG-UI session metadata
(specifically `current_state`, which the AG-UI runtime stuffs into
`session.metadata` on every turn) before the supervisor runs.
"""
# @region[supervisor-delegation-tools]
# @region[subagent-setup]
from __future__ import annotations
import asyncio
import contextvars
import json
import logging
import threading
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from textwrap import dedent
from typing import Annotated, Any
from ag_ui.core import BaseEvent
from agent_framework import (
Agent,
AgentContext,
BaseChatClient,
Content,
agent_middleware,
tool,
)
from agent_framework_ag_ui import AgentFrameworkAgent, state_update
from pydantic import Field
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Shared-state schema. The UI renders the `delegations` array as a live log;
# every entry carries the same five string fields.
# ---------------------------------------------------------------------------
STATE_SCHEMA: dict[str, object] = {
    "delegations": {
        "type": "array",
        "description": (
            "Append-only log of supervisor -> sub-agent delegations. "
            "Each entry is a Delegation = {id, sub_agent, task, status, result}."
        ),
        "items": {
            "type": "object",
            "properties": {
                field_name: {"type": "string"}
                for field_name in ("id", "sub_agent", "task", "status", "result")
            },
        },
    }
}
# ---------------------------------------------------------------------------
# Per-request current_state bridge
#
# Tools cannot directly receive `current_state` from the AG-UI runtime,
# but `agent_middleware` runs once per agent invocation with full
# session context. We snapshot the latest `delegations` list into a
# ContextVar before `call_next()`, so each delegation tool (running in
# the same task / contextvar scope) can read it back, append, and
# return the FULL list via `state_update`.
# ---------------------------------------------------------------------------
_current_delegations: contextvars.ContextVar[list[dict[str, Any]]] = (
contextvars.ContextVar("ms_subagents_current_delegations", default=[])
)
def _extract_delegations(raw: Any) -> list[dict[str, Any]]:
    """Return a sanitized copy of the `delegations` list from current_state.

    `raw` is the value of `session.metadata["current_state"]`. The AG-UI
    runtime may hand it over JSON-serialized (see `_build_safe_metadata`),
    so both a dict and its JSON string form are accepted.

    Returns an empty list when `raw` is invalid JSON (a warning is logged),
    is not a dict, or has no list under "delegations". Otherwise it returns
    shallow copies of the dict entries; non-dict entries are dropped.
    """
    state: Any = raw
    if isinstance(state, str):
        try:
            state = json.loads(state)
        except json.JSONDecodeError:
            logger.warning(
                "subagents: current_state was not valid JSON; "
                "starting from empty delegations list"
            )
            return []

    entries = state.get("delegations") if isinstance(state, dict) else None
    if not isinstance(entries, list):
        return []
    # Copy each entry so later appends/edits never touch the session payload.
    return [dict(item) for item in entries if isinstance(item, dict)]
@agent_middleware
async def capture_current_state(
    context: AgentContext, call_next: Callable[[], Awaitable[None]]
) -> None:
    """Expose the incoming `delegations` list to the delegation tools.

    Reads `session.metadata["current_state"]` (when the session has dict
    metadata) and stores the extracted delegations in
    `_current_delegations` for the duration of `call_next()`. The previous
    value is restored afterwards, even if the agent raises.
    """
    session = context.session
    metadata = getattr(session, "metadata", None) if session else None
    snapshot: list[dict[str, Any]] = (
        _extract_delegations(metadata.get("current_state"))
        if isinstance(metadata, dict)
        else []
    )

    reset_token = _current_delegations.set(snapshot)
    try:
        await call_next()
    finally:
        _current_delegations.reset(reset_token)
# ---------------------------------------------------------------------------
# Sub-agent prompts and factory
#
# Each sub-agent is a full `Agent(...)` with its own system prompt. It
# shares the chat client with the supervisor but has no tools and no
# shared memory: a sub-agent receives only the `task` string, and the
# supervisor only sees the final text it returns.
# ---------------------------------------------------------------------------

# Research sub-agent: short bulleted fact list.
_RESEARCH_INSTRUCTIONS = (
    "You are a research sub-agent. "
    "Given a topic, produce a concise bulleted list of 3-5 key facts. "
    "No preamble, no closing."
)

# Writing sub-agent: one polished paragraph from a brief (+ optional facts).
_WRITING_INSTRUCTIONS = (
    "You are a writing sub-agent. "
    "Given a brief and optional source facts, produce a polished 1-paragraph draft. "
    "Be clear and concrete. No preamble."
)

# Critique sub-agent: a few actionable editorial notes on a draft.
_CRITIQUE_INSTRUCTIONS = (
    "You are an editorial critique sub-agent. "
    "Given a draft, give 2-3 crisp, actionable critiques. "
    "No preamble."
)
def _make_sub_agent(chat_client: BaseChatClient, name: str, instructions: str) -> Agent:
    """Create one tool-less specialist sub-agent.

    The sub-agent runs on `chat_client` with its own `name` and system
    `instructions`, and is given an empty tool list.
    """
    agent_kwargs: dict[str, Any] = {
        "client": chat_client,
        "name": name,
        "instructions": instructions,
        "tools": [],
    }
    return Agent(**agent_kwargs)
# @endregion[subagent-setup]
# Registry of pre-built sub-agents, keyed by the same names the supervisor
# uses as tool names. `create_subagents_agent(...)` fills it in, so the
# delegation tools never rebuild an agent per call; until then it is empty
# and `_invoke_sub_agent_async` raises for every name.
_SUB_AGENTS: dict[str, Agent] = dict()
async def _invoke_sub_agent_async(sub_agent_name: str, task: str) -> str:
    """Await one run of a registered sub-agent and return its final text.

    The response's `.text` is preferred. If it is blank, the newest
    non-blank `text` found in the response messages' contents is used
    instead, scanning messages and contents newest-message-first.

    Raises:
        RuntimeError: if `sub_agent_name` is not in `_SUB_AGENTS` (i.e.
            `create_subagents_agent` has not run yet), or if no non-blank
            text can be found in the response.
    """
    sub_agent = _SUB_AGENTS.get(sub_agent_name)
    if sub_agent is None:
        raise RuntimeError(
            f"sub-agent '{sub_agent_name}' is not registered; call "
            "create_subagents_agent(...) first"
        )

    response = await sub_agent.run(task)

    headline = (getattr(response, "text", "") or "").strip()
    if headline:
        return headline

    # `.text` can be empty, e.g. if the client only produced non-text
    # content, so fall back to the individual content pieces.
    messages = getattr(response, "messages", None) or []
    candidates = (
        str(raw).strip()
        for message in reversed(messages)
        for piece in (getattr(message, "contents", None) or [])
        if (raw := getattr(piece, "text", None))
    )
    fallback = next((text for text in candidates if text), None)
    if fallback is not None:
        return fallback

    raise RuntimeError(f"sub-agent '{sub_agent_name}' returned no text content")
def _invoke_sub_agent(sub_agent_name: str, task: str) -> str:
    """Run a sub-agent to completion from synchronous tool code.

    `@tool` reflects on the underlying callable's signature, so the tool
    entry points are sync and cannot `await` the sub-agent directly.

    * No event loop running on this thread: drive the coroutine with
      `asyncio.run`.
    * An event loop is already running here (e.g. the supervisor's chat
      client inside a FastAPI request handler): `asyncio.run` would refuse,
      so run the coroutine on a fresh loop in a worker thread and block
      until it finishes.

    The running-loop check happens BEFORE the coroutine is created. This
    means no never-awaited coroutine is left behind (which would trigger a
    RuntimeWarning on every delegation). It also means a RuntimeError
    raised by the sub-agent itself can never trigger a second invocation.

    Returns:
        The sub-agent's final text.

    Raises:
        Whatever the sub-agent invocation raised, re-raised on the calling
        thread. This includes `BaseException` subclasses such as
        `asyncio.CancelledError`; previously these surfaced as a
        misleading `KeyError('result')`.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop on this thread: safe to own a loop here.
        return asyncio.run(_invoke_sub_agent_async(sub_agent_name, task))

    outcome: dict[str, Any] = {}

    def _runner() -> None:
        try:
            outcome["result"] = asyncio.run(
                _invoke_sub_agent_async(sub_agent_name, task)
            )
        except BaseException as exc:  # noqa: BLE001 -- re-raised on caller thread
            outcome["error"] = exc

    # Run the worker inside a copy of the caller's context so context-local
    # values visible to the tool are also visible to the sub-agent run.
    caller_context = contextvars.copy_context()
    worker = threading.Thread(
        target=caller_context.run,
        args=(_runner,),
        name=f"subagent-{sub_agent_name}",
        daemon=True,
    )
    worker.start()
    # NOTE: join() blocks the calling thread, and therefore its event loop,
    # until the sub-agent finishes.
    # NOTE(review): the sub-agent shares the supervisor's chat client; if
    # that client holds loop-bound async resources, using it from a second
    # loop may misbehave -- confirm.
    worker.join()
    if "error" in outcome:
        raise outcome["error"]
    return str(outcome["result"])
def _delegate(sub_agent_name: str, task: str) -> Content:
    """Run one delegation and publish the updated delegation log.

    Invokes the named sub-agent on `task` and appends a log entry.
    * On success, the status is "completed" and the result is the
      sub-agent's text.
    * On failure, the status is "failed" and the result names only the
      exception class.

    It then stores the new list back into `_current_delegations` and
    returns a `state_update` carrying the FULL list. Sub-agent failures
    (`Exception` subclasses) are logged and recorded, not raised to the
    supervisor.
    """
    log = list(_current_delegations.get())  # copy; never mutate the snapshot
    record_id = str(uuid.uuid4())
    try:
        output = _invoke_sub_agent(sub_agent_name, task)
    except Exception as exc:
        logger.exception("subagents: %s delegation failed", sub_agent_name)
        status = "failed"
        # Surface only the exception class -- sub-agent error messages can
        # leak chat client URLs / quota details in deployed environments.
        recorded = f"sub-agent error: {exc.__class__.__name__}"
        reply = f"{sub_agent_name} failed; surfaced in delegation log."
    else:
        status, recorded, reply = "completed", output, output

    log.append(
        {
            "id": record_id,
            "sub_agent": sub_agent_name,
            "task": task,
            "status": status,
            "result": recorded,
        }
    )
    # Mirror into the ContextVar so a follow-up delegation within the same
    # supervisor turn builds on this entry.
    _current_delegations.set(log)
    return state_update(text=reply, state={"delegations": log})
# ---------------------------------------------------------------------------
# Supervisor delegation tools
#
# Each @tool below wraps one sub-agent. The supervisor LLM "calls" these
# tools to delegate work. Each call does three things via `_delegate`:
# it synchronously runs the matching sub-agent, appends an entry to the
# `delegations` shared-state slot, and returns a `state_update(...)`.
# That single return value hands the sub-agent's text back to the
# supervisor and carries the full updated delegations list for the UI's
# live delegation log.
#
# `@tool` reflects on the function signature. The
# `Annotated[..., Field(description=...)]` metadata on `task` therefore
# presumably feeds the tool's parameter schema, so treat it as part of
# the tool contract.
# ---------------------------------------------------------------------------
@tool(
    name="research_agent",
    description=(
        "Delegate a research task to the research sub-agent. Use for "
        "gathering facts, background, definitions, statistics. Returns "
        "a bulleted list of key facts."
    ),
)
def research_agent(
    task: Annotated[
        str,
        Field(description="The research question or topic to investigate."),
    ],
) -> Content:
    """Delegate a research task to the research sub-agent."""
    # Thin wrapper: the tool name and the registry key passed to
    # `_delegate` are the same string.
    return _delegate("research_agent", task)
# Drafting delegate. The writing sub-agent is run on the `task` string
# alone, so any facts from earlier research must be included in `task`
# (the tool description asks the supervisor to do exactly that).
@tool(
    name="writing_agent",
    description=(
        "Delegate a drafting task to the writing sub-agent. Use for "
        "producing a polished paragraph, draft, or summary. Pass any "
        "relevant facts from prior research inside `task`."
    ),
)
def writing_agent(
    task: Annotated[
        str,
        Field(
            description=(
                "The drafting brief, including any relevant source "
                "facts the writer should weave in."
            )
        ),
    ],
) -> Content:
    """Delegate a drafting task to the writing sub-agent."""
    # All bookkeeping (log entry, ContextVar mirror, state_update) happens
    # in `_delegate`; this wrapper only fixes the sub-agent name.
    return _delegate("writing_agent", task)
# Critique delegate. The critique sub-agent sees only `task`, so the full
# draft must be passed inside it.
@tool(
    name="critique_agent",
    description=(
        "Delegate a critique task to the critique sub-agent. Use for "
        "reviewing a draft and suggesting concrete improvements."
    ),
)
def critique_agent(
    task: Annotated[
        str,
        Field(
            description=(
                "The draft text to critique. Provide the full text -- "
                "the critique sub-agent has no other context."
            )
        ),
    ],
) -> Content:
    """Delegate a critique task to the critique sub-agent."""
    # Same pattern as the other delegation tools: `_delegate` runs the
    # sub-agent and returns the state_update for the supervisor/UI.
    return _delegate("critique_agent", task)
# @endregion[supervisor-delegation-tools]
# ---------------------------------------------------------------------------
# Supervisor agent factory
# ---------------------------------------------------------------------------
# System prompt for the supervisor. It names the three delegation tools
# defined above and asks for a single research -> writing -> critique pass
# per user request, with each tool called exactly once. The string is
# passed through `dedent` and `.strip()`, so the whitespace inside the
# literal is significant; edit it carefully.
SUPERVISOR_PROMPT = dedent(
"""
You are a supervisor agent that coordinates three specialized
sub-agents to produce high-quality deliverables.
Available sub-agents (call them as tools):
- research_agent: gathers facts on a topic.
- writing_agent: turns facts + a brief into a polished draft.
- critique_agent: reviews a draft and suggests improvements.
For every non-trivial user request, delegate in sequence:
research_agent -> writing_agent -> critique_agent.
IMPORTANT: call EACH sub-agent EXACTLY ONCE per user request.
After critique_agent returns, do NOT call any sub-agent again -- return
a concise final answer to the user that incorporates the critique.
Pass the relevant facts/draft through the `task` argument of each tool.
Keep your own messages short — explain the plan once, delegate,
then return a concise summary once done. The UI shows the user a
live log of every sub-agent delegation.
"""
).strip()
def _tool_call_ids(message: dict[str, Any]) -> set[str]:
tool_calls = message.get("tool_calls") or message.get("toolCalls") or []
if not isinstance(tool_calls, list):
return set()
ids: set[str] = set()
for call in tool_calls:
if isinstance(call, dict) and isinstance(call.get("id"), str):
ids.add(call["id"])
return ids
def _tool_result_ids(messages: list[dict[str, Any]], start_index: int) -> set[str]:
ids: set[str] = set()
for message in messages[start_index + 1 :]:
if message.get("role") == "user":
break
if message.get("role") != "tool":
continue
call_id = message.get("tool_call_id") or message.get("toolCallId")
if isinstance(call_id, str):
ids.add(call_id)
return ids
def _drop_orphan_assistant_tool_calls(messages: Any) -> list[dict[str, Any]]:
"""Remove historical assistant tool calls that lack tool result messages.
The MS Agent Framework AG-UI bridge can preserve the assistant tool-call
snapshot while omitting the corresponding tool-role results. OpenAI rejects
that history on the next turn, so keep the final assistant text/state but
omit malformed historical tool-call entries before the supervisor runs.
"""
if not isinstance(messages, list):
return []
clean: list[dict[str, Any]] = []
for index, message in enumerate(messages):
if not isinstance(message, dict):
continue
if message.get("role") == "assistant":
call_ids = _tool_call_ids(message)
if call_ids and not call_ids.issubset(_tool_result_ids(messages, index)):
continue
clean.append(message)
return clean
class SubagentsFrameworkAgent(AgentFrameworkAgent):
    """`AgentFrameworkAgent` that sanitizes the message history before each run.

    Historical assistant tool-call entries without matching tool results are
    removed via `_drop_orphan_assistant_tool_calls` before the input is
    handed to the base implementation.
    """

    async def run(  # type: ignore[override]
        self,
        input_data: dict[str, Any],
    ) -> AsyncGenerator[BaseEvent, None]:
        """Stream the base agent's events for a sanitized copy of `input_data`.

        The caller's dict is not mutated. Only the `messages` entry of a
        shallow copy is replaced; it becomes an empty list if missing or
        not a list.
        """
        history = _drop_orphan_assistant_tool_calls(input_data.get("messages"))
        sanitized_input = dict(input_data, messages=history)
        async for event in super().run(sanitized_input):
            yield event
def create_subagents_agent(chat_client: BaseChatClient) -> SubagentsFrameworkAgent:
    """Instantiate the Sub-Agents demo supervisor and its sub-agents.

    Side effect: (re)populates the module-level `_SUB_AGENTS` registry with
    the research, writing and critique sub-agents, so the delegation tools
    can find them by name. Every agent, supervisor included, uses
    `chat_client`.

    Returns:
        The supervisor wrapped in `SubagentsFrameworkAgent`, configured with
        `STATE_SCHEMA` and `require_confirmation=False`.
    """
    sub_agent_specs = (
        ("research_agent", _RESEARCH_INSTRUCTIONS),
        ("writing_agent", _WRITING_INSTRUCTIONS),
        ("critique_agent", _CRITIQUE_INSTRUCTIONS),
    )
    for agent_name, agent_instructions in sub_agent_specs:
        _SUB_AGENTS[agent_name] = _make_sub_agent(
            chat_client, agent_name, agent_instructions
        )

    supervisor = Agent(
        client=chat_client,
        name="subagents_supervisor",
        instructions=SUPERVISOR_PROMPT,
        tools=[research_agent, writing_agent, critique_agent],
        # Presumably limits the model to one tool call per turn, keeping
        # delegations sequential -- matches the prompt's one-at-a-time plan.
        default_options={"allow_multiple_tool_calls": False},
        # Snapshots `delegations` into the ContextVar the tools read.
        middleware=[capture_current_state],
    )
    return SubagentsFrameworkAgent(
        agent=supervisor,
        name="CopilotKitMSAgentSubagentsSupervisor",
        description=(
            "Supervisor agent. "
            "Delegates research / writing / critique to specialized sub-agents "
            "and surfaces the live delegation log to the UI via shared state."
        ),
        state_schema=STATE_SCHEMA,
        require_confirmation=False,
    )