forked from CopilotKit/CopilotKit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcopilotkit_lg_middleware.py
More file actions
945 lines (817 loc) · 38.2 KB
/
Copy pathcopilotkit_lg_middleware.py
File metadata and controls
945 lines (817 loc) · 38.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
"""
CopilotKit Middleware for LangGraph agents.
Works with any agent (prebuilt or custom).
Example:
from langgraph.prebuilt import create_agent
from copilotkit import CopilotKitMiddleware
agent = create_agent(
model="openai:gpt-4o",
tools=[backend_tool],
middleware=[CopilotKitMiddleware()],
)
"""
import json
import re
from typing import Any, Callable, Awaitable, ClassVar, Iterable, Union
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from langchain.agents.middleware import (
AgentMiddleware,
AgentState,
ModelRequest,
ModelResponse,
)
from langgraph.runtime import Runtime
from .header_propagation import install_httpx_hook, set_forwarded_headers
from .langgraph import CopilotKitProperties
# Optional dependency: the A2UI subagent-tool factory ships in ag-ui-langgraph.
# Guarded so an older/skewed version without the factory degrades to
# "no auto-A2UI" instead of breaking the whole middleware import.
try: # pragma: no cover - exercised indirectly via the a2ui injection path
from ag_ui_langgraph import get_a2ui_tools
except Exception: # noqa: BLE001 - any import failure means the feature is off
get_a2ui_tools = None
# Track which httpx clients already have the header-propagation hook installed
# (by object id) so we never double-install on repeated model calls.
_hooked_clients: set[int] = set()
# ---------------------------------------------------------------------------
# Auto-A2UI: bridge the inferred model from the model-call hook to the
# tool-call hook
# ---------------------------------------------------------------------------
# The generate_a2ui tool drives a structured-output subagent and so needs a
# chat model. We "infer" that model from ``request.model`` in
# ``wrap_model_call`` (the only hook that exposes the bound model) and reuse it.
# But the tool actually *executes* later in ``wrap_tool_call``, whose request
# does NOT carry the model. ContextVars do not reliably survive LangGraph node
# boundaries, so we bridge the built tool across nodes via a module-level map
# keyed by the run's thread id.
_a2ui_tools_by_thread: dict[str, Any] = {}
# Fallback key for runs without a thread id (e.g. an in-memory invoke with no
# checkpointer). Collisions across concurrent context-less runs are an
# acceptable edge — the deployed path always carries a thread id.
_DEFAULT_THREAD_KEY = "__copilotkit_a2ui_default__"
def _current_thread_id() -> "str | None":
"""Best-effort read of the active run's thread id from the LangGraph config.
Returns ``None`` outside a runnable context (e.g. unit tests); callers then
fall back to ``_DEFAULT_THREAD_KEY``.
"""
try:
from langgraph.config import get_config
cfg = get_config() or {}
return (cfg.get("configurable") or {}).get("thread_id")
except Exception: # noqa: BLE001 - no active context / older langgraph
return None
def _extract_forwarded_headers_from_config() -> None:
"""Extract raw ``x-*`` headers from the current LangGraph RunnableConfig and
push them into the header-propagation ContextVar so the httpx hook can
forward them on outgoing LLM requests.
When an agent runs inside **langgraph-api** with
``LANGGRAPH_HTTP={"configurable_headers":{"include":["x-*"]}}``,
the server copies inbound HTTP ``x-*`` headers into
``config["configurable"]`` as individual keys (e.g.
``configurable["x-aimock-context"] = "value"``). This function reads those
keys and calls :func:`set_forwarded_headers` so they propagate to the
underlying LLM provider SDK via the httpx event hook.
Precedence: the wrapper dict ``copilotkit_forwarded_headers`` (if present)
takes priority over raw ``x-*`` keys. Raw keys are only used when the
wrapper dict is absent or does not contain a given header.
Safe to call outside a runnable context (e.g. in unit tests) — silently
returns without doing anything if ``get_config()`` raises.
"""
try:
from langgraph.config import (
get_config,
) # local import to avoid hard dep at module level
config = get_config()
except ImportError:
return
except RuntimeError:
# No active runnable context — clear the ContextVar so stale headers
# from a prior request in the same async context do not leak through.
set_forwarded_headers({})
return
try:
headers: dict[str, str] = {}
# Sources to scan: config["context"] (LangGraph >=0.6.0) and
# config["configurable"] (all versions).
context = config.get("context") or {}
configurable = config.get("configurable") or {}
# 1) Wrapper-dict path (highest priority): these are headers that
# CopilotKit explicitly bundled under a known key. Process context
# first with first-write-wins so context takes precedence over
# configurable (LangGraph >=0.6.0 introduced context as the newer
# preferred mechanism).
for src in (context, configurable):
if not isinstance(src, dict):
continue
wrapper = src.get("copilotkit_forwarded_headers")
if isinstance(wrapper, dict):
for k, v in wrapper.items():
lk = k.lower() if isinstance(k, str) else k
if isinstance(k, str) and isinstance(v, str) and lk not in headers:
headers[lk] = v
# 2) Raw x-* keys directly on context and configurable. These appear
# when langgraph-api's configurable_headers mechanism forwards inbound
# HTTP headers as individual configurable entries.
for src in (context, configurable):
if not isinstance(src, dict):
continue
for k, v in src.items():
if (
isinstance(k, str)
and k.lower().startswith("x-")
and isinstance(v, str)
):
# Don't overwrite wrapper-dict values (wrapper > raw).
# Lowercase at insertion so precedence checks are
# deterministic regardless of source casing.
lk = k.lower()
if lk not in headers:
headers[lk] = v
# Always set the ContextVar — even with an empty dict — so stale
# headers from previous calls in the same async context do not leak
# into this one.
set_forwarded_headers(headers)
except Exception as e:
# Header forwarding is best-effort. Never block the LLM call.
# Clear the ContextVar so stale headers from a prior request do not
# leak through on failure.
set_forwarded_headers({})
import logging
logging.getLogger(__name__).debug(
"Header forwarding extraction failed; continuing without forwarded headers: %s",
e,
)
def _ensure_httpx_hook(model: Any) -> None:
"""Install the header-propagation httpx hook on a LangChain chat model's
underlying HTTP client(s), if present. No-op for models that don't expose
an httpx transport (e.g. non-OpenAI/Anthropic providers).
"""
for attr in ("client", "async_client"):
client = getattr(model, attr, None)
if client is None:
continue
cid = id(client)
if cid not in _hooked_clients:
install_httpx_hook(client)
_hooked_clients.add(cid)
class StateSchema(AgentState):
    # Agent state extended with a ``copilotkit`` channel. This middleware
    # reads ``actions`` and ``context`` from it and writes the
    # ``intercepted_tool_calls`` / ``original_ai_message_id`` bookkeeping
    # into it.
    copilotkit: CopilotKitProperties
# Internal/framework keys that should never be surfaced to the LLM as
# user-facing state. These are either reducer-managed message buckets,
# CopilotKit/AG-UI plumbing, or graph-internal scaffolding.
_RESERVED_STATE_KEYS = frozenset(
{
"messages",
"copilotkit",
# Transport-layer plumbing: forwarded request headers conveyed via a
# separate ContextVar to the httpx hook. MUST never be rendered into
# the LLM prompt — neither via App Context nor via expose_state.
"copilotkit_forwarded_headers",
"ag-ui",
"tools",
"structured_response",
"thread_id",
"remaining_steps",
}
)
class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
    """CopilotKit Middleware for LangGraph agents.

    Handles:

    - frontend tool injection and interception for CopilotKit;
    - optional auto-injection of a ``generate_a2ui`` tool;
    - automatic exposure of agent state to the LLM, so values written via
      ``agent.setState`` on the frontend (or via ``Command(update=...)`` in a
      tool) are visible in the next model call without a custom
      ``get_state`` tool.

    Args:
        expose_state: Controls how user-defined state keys are surfaced into
            ``request.system_message`` on every model call. It is off by
            default to avoid leaking arbitrary state into prompts, so opt in
            explicitly.

            - ``False`` (default): never surface state.
            - ``True``: every state key that is not in the reserved
              internal set and does not start with an underscore is
              JSON-serialized into a "Current agent state:" note appended
              to the system message.
            - ``list``/``tuple``/``set[str]``: only surface the named keys.
              Use this when you want explicit control over what the LLM
              sees (e.g. ``["liked", "todos"]``). A single ``str`` is
              treated as one key name.
    """

    state_schema = StateSchema
    tools: ClassVar[list] = []

    # Content of the placeholder ToolMessage that patch_orphan_tool_calls
    # writes for a tool call whose result was missing from the checkpoint.
    # Compiled once instead of on every model call.
    _INTERRUPTED_TOOL_RESULT_RE = re.compile(
        r"^Tool call '.+' with id '.+' was interrupted before completion\.$"
    )

    def __init__(
        self,
        *,
        expose_state: Union[bool, Iterable[str]] = False,
    ):
        super().__init__()
        if isinstance(expose_state, bool):
            self._expose_state: Union[bool, frozenset[str]] = expose_state
        elif isinstance(expose_state, str):
            # A bare string names one key. frozenset("todos") would
            # otherwise split it into single characters.
            self._expose_state = frozenset((expose_state,))
        else:
            self._expose_state = frozenset(expose_state)

    @property
    def name(self) -> str:
        return "CopilotKitMiddleware"

    @staticmethod
    def _tool_name(tool: Any) -> Any:
        """Return the name of a tool spec or tool object, or ``None``.

        Dict specs carry the name either nested under a ``function`` key or
        flat under ``name``. Other objects are read via their ``name``
        attribute.
        """
        if isinstance(tool, dict):
            return (tool.get("function") or {}).get("name") or tool.get("name")
        return getattr(tool, "name", None)

    # ------------------------------------------------------------------
    # State-to-prompt surfacing
    # ------------------------------------------------------------------
    def _build_state_note(self, state: dict) -> str | None:
        """Serialize a snapshot of user state into a system-prompt note.

        Returns ``None`` when nothing should be appended: the feature is
        disabled, or no non-empty user keys are present.
        """
        if self._expose_state is False:
            return None
        if isinstance(self._expose_state, frozenset):
            # Allowlist branch. Explicitly named keys are honored even when
            # they are otherwise reserved (e.g. ``thread_id``), but
            # ``copilotkit_forwarded_headers`` is always excluded. Rendering
            # it would leak the raw forwarded request headers into the prompt.
            # ``state`` is iterated (not the frozenset) so key order follows
            # state order. Frozenset iteration order for strings depends on
            # per-process hash randomization and would make the prompt
            # nondeterministic.
            keys: list[str] = [
                k
                for k in state
                if k in self._expose_state and k != "copilotkit_forwarded_headers"
            ]
        else:
            keys = [
                k
                for k in state
                if k not in _RESERVED_STATE_KEYS and not str(k).startswith("_")
            ]
        snapshot: dict[str, Any] = {}
        for k in keys:
            v = state.get(k)
            # Skip empty / no-op values to keep the note tight.
            if v in (None, "", [], {}):
                continue
            snapshot[k] = v
        if not snapshot:
            return None
        try:
            body = json.dumps(snapshot, default=str, ensure_ascii=False, indent=2)
        except (TypeError, ValueError):
            body = str(snapshot)
        return f"Current agent state:\n{body}"

    def _apply_state_note(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the state note appended to its system message.

        A system message is created when none exists. String content gets the
        note after a blank line. List (content-block) content gets the note as
        an extra ``text`` block, so the existing blocks are preserved instead
        of being flattened into their Python ``repr``.
        """
        note = self._build_state_note(request.state or {})
        if not note:
            return request
        existing = request.system_message
        if existing is None:
            return request.override(system_message=SystemMessage(content=note))
        content = existing.content
        if isinstance(content, list):
            new_content: Any = [*content, {"type": "text", "text": note}]
        else:
            new_content = f"{content}\n\n{note}"
        return request.override(system_message=SystemMessage(content=new_content))

    # ------------------------------------------------------------------
    # Auto-A2UI tool injection
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_a2ui_catalog(state: dict) -> "tuple[str | None, str | None] | None":
        """Find the frontend-registered A2UI catalog wherever it was passed.

        Returns ``(component_schema, catalog_id)`` when a catalog is present,
        else ``None``, so the tool is never advertised when the client can't
        render A2UI. The catalog lands in different places depending on how
        the agent is served, so two delivery paths are supported:

        - **AG-UI native endpoint**: ``state["ag-ui"]["a2ui_schema"]``, a JSON
          string ``{"catalogId": ..., "components": [...]}``.
        - **CopilotKit runtime proxy**: a ``state["copilotkit"]["context"]``
          entry describing the A2UI catalog (catalog id + component schemas as
          text). Entries may be plain dicts or pydantic models (e.g. ag_ui
          ``Context``); the latter are ``model_dump()``-ed first.

        ``component_schema`` is the text/JSON the subagent should compose from.
        ``catalog_id`` binds generated surfaces to the frontend's catalog, so
        BYOC custom catalogs render their own components rather than the basic
        one.
        """
        # AG-UI native path.
        ag_ui = state.get("ag-ui") or {}
        a2ui_schema = ag_ui.get("a2ui_schema")
        if a2ui_schema:
            catalog_id = None
            try:
                parsed = (
                    json.loads(a2ui_schema)
                    if isinstance(a2ui_schema, str)
                    else a2ui_schema
                )
                if isinstance(parsed, dict):
                    catalog_id = parsed.get("catalogId")
            except (TypeError, ValueError):
                pass
            # The toolkit reads ``a2ui_schema`` from state itself, so no
            # composition_guide is needed; only the catalog id is surfaced.
            return None, catalog_id

        # CopilotKit runtime-proxy path: the catalog arrives as a context entry.
        context = (state.get("copilotkit") or {}).get("context") or []
        for entry in context:
            if hasattr(entry, "model_dump"):
                entry = entry.model_dump()
            if not isinstance(entry, dict):
                continue
            description = entry.get("description") or ""
            value = entry.get("value") or ""
            if not isinstance(description, str) or not isinstance(value, str):
                continue
            if "A2UI catalog" not in description or not value:
                continue
            # The value lists catalogs as "- <catalogId>" lines. The first one
            # is the custom catalog the client registered.
            match = re.search(r"(?m)^\s*-\s+(\S+)", value)
            catalog_id = match.group(1) if match else None
            return value, catalog_id
        return None

    @staticmethod
    def _a2ui_inject_decision(state: dict) -> "bool | str | None":
        """Return the A2UI ``injectA2UITool`` decision, or ``None``.

        ``@ag-ui/a2ui-middleware`` forwards its ``injectA2UITool`` setting on
        ``forwardedProps``, and ``ag-ui-langgraph`` surfaces it into agent state
        at ``state["ag-ui"]["inject_a2ui_tool"]``. The value is present only when
        the host turned the runtime A2UI tool on (truthy, or a custom tool-name
        string). ``None`` means no signal at all (off, or no A2UI middleware in
        the pipeline), and in that case nothing is auto-injected.
        """
        return (state.get("ag-ui") or {}).get("inject_a2ui_tool")

    def _maybe_build_a2ui_tool(self, request: ModelRequest) -> Any | None:
        """Build a ``generate_a2ui`` tool bound to the agent's own model when
        A2UI tool injection is turned on for this run.

        Gating, in order:

        1. **Opt-in.** Inject only when the A2UI ``injectA2UITool`` flag is
           truthy. The flag is forwarded by ``@ag-ui/a2ui-middleware`` and
           surfaced at ``state["ag-ui"]["inject_a2ui_tool"]``. No flag means
           no injection.
        2. **No double-inject.** If the agent already exposes a tool with the
           same name (e.g. a backend-defined ``generate_a2ui``), don't inject.
           The host owns that tool, and a duplicate would show the model two
           tools with one name. Both tool objects and dict tool specs are
           checked.

        The model is inferred from ``request.model`` (the bound agent model).
        The component schema and catalog id come from the registered catalog
        when present, so the subagent composes the right components and
        surfaces bind to the frontend's catalog. Otherwise the toolkit's basic
        catalog is used. The built tool is stashed for the tool-call hook to
        execute.

        Returns the tool, or ``None`` when A2UI is not applicable.
        """
        if get_a2ui_tools is None:
            return None
        state = request.state or {}
        # (1) Opt-in: only inject when the host turned the A2UI tool on.
        if not self._a2ui_inject_decision(state):
            return None
        # Bind to the frontend's catalog when one was registered (optional).
        resolved = self._resolve_a2ui_catalog(state)
        component_schema, catalog_id = resolved if resolved else (None, None)
        kwargs: dict[str, Any] = {}
        if catalog_id:
            kwargs["default_catalog_id"] = catalog_id
        # Feed the registered component schema to the subagent so it composes
        # only catalog components (the toolkit appends this to its prompt).
        if component_schema:
            kwargs["composition_guide"] = component_schema
        tool = get_a2ui_tools(request.model, **kwargs)
        # (2) Don't double-inject if the agent already defines this tool.
        existing_names = {self._tool_name(t) for t in (request.tools or [])}
        if tool.name in existing_names:
            return None
        _a2ui_tools_by_thread[_current_thread_id() or _DEFAULT_THREAD_KEY] = tool
        return tool

    # ------------------------------------------------------------------
    # Model-call hooks
    # ------------------------------------------------------------------
    def _prepare_model_request(self, request: ModelRequest) -> ModelRequest:
        """Apply every per-call adjustment shared by the sync and async hooks.

        The steps are:

        1. Refresh the forwarded headers and hook the model's HTTP clients.
        2. Repair the message history in place (see
           ``_fix_messages_for_bedrock``).
        3. Append the state note.
        4. Merge the optional ``generate_a2ui`` tool and the frontend actions
           into the request's tools.

        Returns the (possibly overridden) request to hand to the model handler.
        """
        _extract_forwarded_headers_from_config()
        _ensure_httpx_hook(request.model)
        # Applied on both the sync and async paths. The checkpoint artifacts it
        # repairs do not depend on how the agent is invoked.
        if request.messages:
            self._fix_messages_for_bedrock(request.messages)
        request = self._apply_state_note(request)
        a2ui_tool = self._maybe_build_a2ui_tool(request)

        state = request.state or {}
        frontend_tools = (state.get("copilotkit") or {}).get("actions") or []
        if a2ui_tool is not None:
            # Our generate_a2ui replaces the runtime's render tool, so don't
            # advertise both. Drop the render tool the A2UI middleware injected.
            decision = self._a2ui_inject_decision(state)
            drop = decision if isinstance(decision, str) else "render_a2ui"
            frontend_tools = [
                t for t in frontend_tools if self._tool_name(t) != drop
            ]

        if not frontend_tools and a2ui_tool is None:
            return request
        extra_tools = [a2ui_tool] if a2ui_tool is not None else []
        merged_tools = [*(request.tools or []), *extra_tools, *frontend_tools]
        return request.override(tools=merged_tools)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Inject frontend + A2UI tools and surface user state, then call the model."""
        return handler(self._prepare_model_request(request))

    @staticmethod
    def _fix_messages_for_bedrock(messages: list) -> list:
        """Repair a checkpoint-restored message history in place.

        CopilotKit's ``after_agent`` restores intercepted frontend tool calls
        onto the checkpointed AIMessage. That can leave the history in a shape
        Bedrock's Converse API rejects. The repairs run in this order:

        1. Deduplicate ToolMessages by ``tool_call_id``.
           ``patch_orphan_tool_calls`` injects a "...was interrupted before
           completion." placeholder (with a fresh message id) on every
           checkpoint load. When the real result is later appended as well,
           Bedrock rejects the duplicate toolResult ids. The real
           (non-placeholder) result is kept and moved into the first
           placeholder's slot. Among several real results the last one wins;
           among only placeholders the last one is kept.
        2. Drop ``tool_use`` content blocks that no longer correspond to an
           entry in ``msg.tool_calls`` (all of them when there are none).
        3. Strip tool calls not answered by a ToolMessage directly after the
           AIMessage, together with their ``tool_use`` blocks. Bedrock
           requires toolResult blocks in the turn right after the toolUse turn.
        4. Parse string ``args`` of tool calls as JSON (``{}`` when invalid).
        5. Parse string ``input`` of ``tool_use`` blocks as JSON (``{}`` when
           empty or invalid), and replace a ``None`` input with ``{}``.
        6. Remove ToolMessages whose ``tool_call_id`` no longer matches any
           remaining tool call.

        Both the list and the message objects are mutated in place. The same
        list is also returned.
        """
        interrupted_re = CopilotKitMiddleware._INTERRUPTED_TOOL_RESULT_RE

        # 1. Deduplicate ToolMessages by tool_call_id before anything else.
        tc_groups: dict[str, list[int]] = {}
        for i, msg in enumerate(messages):
            if isinstance(msg, ToolMessage):
                tc_id = getattr(msg, "tool_call_id", None)
                if tc_id:
                    tc_groups.setdefault(tc_id, []).append(i)

        drop_indices: set[int] = set()
        for indices in tc_groups.values():
            if len(indices) <= 1:
                continue
            real_indices = [
                i
                for i in indices
                if not (
                    isinstance(messages[i].content, str)
                    and interrupted_re.match(messages[i].content)
                )
            ]
            interrupted_indices = [i for i in indices if i not in real_indices]
            if real_indices and interrupted_indices:
                # Put the last real result into the first placeholder's slot
                # (adjacent to its AIMessage, where Bedrock expects it), then
                # drop every other copy, including all the original real ones.
                messages[interrupted_indices[0]] = messages[real_indices[-1]]
                drop_indices.update(interrupted_indices[1:])
                drop_indices.update(real_indices)
            elif real_indices:
                # No placeholders, several real results: keep only the last.
                drop_indices.update(real_indices[:-1])
            else:
                # Only placeholders: keep only the last.
                drop_indices.update(interrupted_indices[:-1])

        if drop_indices:
            messages[:] = [
                msg for i, msg in enumerate(messages) if i not in drop_indices
            ]

        for idx, msg in enumerate(messages):
            if not isinstance(msg, AIMessage):
                continue
            tool_calls = getattr(msg, "tool_calls", None) or []

            # 2. Sync tool_use content blocks with tool_calls (blocks can be
            # left behind in the checkpoint after tool_calls were stripped).
            if isinstance(msg.content, list):
                if tool_calls:
                    tc_ids = {tc.get("id") for tc in tool_calls}
                    msg.content = [
                        block
                        for block in msg.content
                        if not (
                            isinstance(block, dict)
                            and block.get("type") == "tool_use"
                            and block.get("id") not in tc_ids
                        )
                    ]
                else:
                    msg.content = [
                        block
                        for block in msg.content
                        if not (
                            isinstance(block, dict)
                            and block.get("type") == "tool_use"
                        )
                    ]

            if not tool_calls:
                continue

            # 3. Strip tool calls without an ADJACENT ToolMessage (those
            # immediately following this AIMessage, before the next
            # non-ToolMessage). A result at the wrong position does not
            # satisfy Bedrock.
            adjacent_tc_ids: set = set()
            j = idx + 1
            while j < len(messages) and isinstance(messages[j], ToolMessage):
                tc_id = getattr(messages[j], "tool_call_id", None)
                if tc_id:
                    adjacent_tc_ids.add(tc_id)
                j += 1
            unanswered_ids = {
                tc.get("id")
                for tc in tool_calls
                if tc.get("id") not in adjacent_tc_ids
            }
            if unanswered_ids:
                msg.tool_calls = [
                    tc for tc in tool_calls if tc.get("id") in adjacent_tc_ids
                ]
                if isinstance(msg.content, list):
                    msg.content = [
                        block
                        for block in msg.content
                        if not (
                            isinstance(block, dict)
                            and block.get("type") == "tool_use"
                            and block.get("id") in unanswered_ids
                        )
                    ]

            # 4. Tool-call args must be dicts, not JSON strings.
            for tc in msg.tool_calls or []:
                if isinstance(tc.get("args"), str):
                    try:
                        tc["args"] = json.loads(tc["args"])
                    except (json.JSONDecodeError, TypeError):
                        tc["args"] = {}

            # 5. tool_use block input must be a dict as well.
            if isinstance(msg.content, list):
                for block in msg.content:
                    if not (isinstance(block, dict) and block.get("type") == "tool_use"):
                        continue
                    inp = block.get("input")
                    if isinstance(inp, str):
                        try:
                            block["input"] = json.loads(inp) if inp else {}
                        except (json.JSONDecodeError, TypeError):
                            block["input"] = {}
                    elif inp is None:
                        block["input"] = {}

        # 6. Drop ToolMessages orphaned by the stripping above.
        remaining_tc_ids: set = set()
        for msg in messages:
            if isinstance(msg, AIMessage):
                for tc in getattr(msg, "tool_calls", None) or []:
                    tc_id = tc.get("id")
                    if tc_id:
                        remaining_tc_ids.add(tc_id)
        messages[:] = [
            msg
            for msg in messages
            if not isinstance(msg, ToolMessage)
            or getattr(msg, "tool_call_id", None) in remaining_tc_ids
        ]
        return messages

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async twin of :meth:`wrap_model_call`."""
        return await handler(self._prepare_model_request(request))

    # ------------------------------------------------------------------
    # Auto-A2UI tool execution
    # ------------------------------------------------------------------
    # The generate_a2ui tool is advertised dynamically in wrap_model_call and is
    # NOT in create_agent's static tool registry, so the tool node cannot
    # execute it on its own. These hooks supply the implementation (built with
    # the inferred model) for that one tool. Their presence also disables
    # create_agent's "unknown tool" guard for dynamically-advertised tools.
    def _resolve_a2ui_request(self, request: Any) -> Any:
        """Return a request overridden with the stashed A2UI tool when this
        tool call targets it, else the original request unchanged."""
        tool = _a2ui_tools_by_thread.get(_current_thread_id() or _DEFAULT_THREAD_KEY)
        if (
            tool is not None
            and getattr(request, "tool", None) is None
            and request.tool_call.get("name") == tool.name
        ):
            return request.override(tool=tool)
        return request

    def wrap_tool_call(
        self,
        request: Any,
        handler: Callable[[Any], Any],
    ) -> Any:
        return handler(self._resolve_a2ui_request(request))

    async def awrap_tool_call(
        self,
        request: Any,
        handler: Callable[[Any], Awaitable[Any]],
    ) -> Any:
        return await handler(self._resolve_a2ui_request(request))

    # ------------------------------------------------------------------
    # App context injection
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize_app_context(app_context: Any) -> Any:
        """Convert app context into JSON-friendly plain data.

        Pydantic models (e.g. ag_ui ``Context``) are ``model_dump()``-ed and
        dataclass instances go through ``dataclasses.asdict``. Lists are
        converted item by item. Anything else is returned unchanged.
        """
        import dataclasses

        def to_plain(value: Any) -> Any:
            if hasattr(value, "model_dump"):
                return value.model_dump()
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                return dataclasses.asdict(value)
            return value

        if isinstance(app_context, list):
            return [to_plain(item) for item in app_context]
        return to_plain(app_context)

    def before_agent(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Insert or refresh an "App Context:" system message before the run.

        The context comes from ``state["copilotkit"]["context"]``. When that is
        empty, it falls back to ``runtime.context``. A string is used verbatim;
        anything else is JSON-serialized. The message goes right after the
        first non-context system/developer message (or at position 0). An
        existing context message is replaced in place, reusing its id.

        Returns ``None`` when there are no messages or no usable context.
        Otherwise returns a state update carrying the new ``messages`` list.
        """
        messages = state.get("messages", [])
        if not messages:
            return None

        copilotkit_state = state.get("copilotkit") or {}
        app_context = self._normalize_app_context(
            copilotkit_state.get("context") or getattr(runtime, "context", None)
        )

        # Strip the reserved transport-layer key ``copilotkit_forwarded_headers``
        # so it never reaches the LLM prompt. langgraph-api auto-copies
        # ``config.configurable`` into ``runtime.context``, so the
        # forwarded-headers wrapper dict shows up here even though it is only
        # meant for the httpx hook (which reads it from a separate ContextVar
        # via ``_extract_forwarded_headers_from_config``).
        if isinstance(app_context, dict):
            app_context = {
                k: v
                for k, v in app_context.items()
                if k != "copilotkit_forwarded_headers"
            }

        if not app_context:
            return None
        if isinstance(app_context, str):
            if not app_context.strip():
                return None
            context_content = app_context
        else:
            # default=str keeps non-JSON values (datetimes, UUIDs, ...) from
            # aborting the run. They are rendered with str().
            context_content = json.dumps(app_context, indent=2, default=str)

        context_message_prefix = "App Context:\n"
        context_message_content = f"{context_message_prefix}{context_content}"

        def message_text(msg: Any) -> Any:
            # String content as-is. For block lists, the first block's text.
            content = getattr(msg, "content", None)
            if isinstance(content, str):
                return content
            if isinstance(content, list) and content and isinstance(content[0], dict):
                return content[0].get("text")
            return None

        # One pass finds both the first "real" system/developer message and an
        # existing context message of ours.
        first_system_index = -1
        existing_context_index = -1
        for i, msg in enumerate(messages):
            if getattr(msg, "type", None) not in ("system", "developer"):
                continue
            content = message_text(msg)
            if isinstance(content, str) and content.startswith(context_message_prefix):
                if existing_context_index == -1:
                    existing_context_index = i
            elif first_system_index == -1:
                first_system_index = i
            if first_system_index != -1 and existing_context_index != -1:
                break

        if existing_context_index != -1:
            # Reuse the existing id so the add_messages reducer updates the
            # message in place instead of appending a duplicate at the end.
            updated_messages = list(messages)
            updated_messages[existing_context_index] = SystemMessage(
                content=context_message_content,
                id=getattr(messages[existing_context_index], "id", None),
            )
        else:
            # Right after the first system message, or at 0 when there is none.
            # NOTE(review): with an id-merging reducer such as add_messages, a
            # message with a brand-new id may be appended at the end rather
            # than landing at insert_index. Confirm the intended placement.
            insert_index = first_system_index + 1 if first_system_index != -1 else 0
            updated_messages = [
                *messages[:insert_index],
                SystemMessage(content=context_message_content),
                *messages[insert_index:],
            ]

        return {
            **state,
            "messages": updated_messages,
        }

    async def abefore_agent(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        # Delegate to sync implementation
        return self.before_agent(state, runtime)

    # ------------------------------------------------------------------
    # Frontend tool-call interception
    # ------------------------------------------------------------------
    def after_model(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Split frontend tool calls off the latest AIMessage.

        Runs after the model returns and before the tool node executes. Tool
        calls whose name matches a frontend action are removed from the last
        AIMessage, so the backend tool node only sees backend calls. They are
        stashed in ``state["copilotkit"]["intercepted_tool_calls"]`` together
        with the message id, for ``after_agent`` to restore.

        Returns ``None`` when there is nothing to intercept.
        """
        frontend_tools = (state.get("copilotkit") or {}).get("actions") or []
        if not frontend_tools:
            return None
        frontend_tool_names = {self._tool_name(t) for t in frontend_tools}

        messages = state.get("messages", [])
        if not messages:
            return None
        last_message = messages[-1]
        if not isinstance(last_message, AIMessage):
            return None
        tool_calls = getattr(last_message, "tool_calls", None) or []
        if not tool_calls:
            return None

        backend_tool_calls = []
        frontend_tool_calls = []
        for call in tool_calls:
            if call.get("name") in frontend_tool_names:
                frontend_tool_calls.append(call)
            else:
                backend_tool_calls.append(call)
        if not frontend_tool_calls:
            return None

        # Same id, so the reducer replaces the message in place.
        updated_ai_message = AIMessage(
            content=last_message.content,
            tool_calls=backend_tool_calls,
            id=last_message.id,
        )
        return {
            "messages": [*messages[:-1], updated_ai_message],
            "copilotkit": {
                "intercepted_tool_calls": frontend_tool_calls,
                "original_ai_message_id": last_message.id,
            },
        }

    async def aafter_model(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        # Delegate to sync implementation
        return self.after_model(state, runtime)

    def after_agent(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        """Restore intercepted frontend tool calls before the agent exits.

        Also drops this run's stashed A2UI tool. Intercepted calls are appended
        back onto the AIMessage they were taken from, and the interception
        bookkeeping in ``state["copilotkit"]`` is cleared.

        Returns ``None`` when nothing was intercepted.
        """
        # All tool calls for the turn have executed by now. The next model call
        # re-stashes the A2UI tool if needed.
        _a2ui_tools_by_thread.pop(_current_thread_id() or _DEFAULT_THREAD_KEY, None)

        copilotkit_state = state.get("copilotkit") or {}
        intercepted_tool_calls = copilotkit_state.get("intercepted_tool_calls")
        original_message_id = copilotkit_state.get("original_ai_message_id")
        if not intercepted_tool_calls or not original_message_id:
            return None

        updated_messages = []
        for msg in state.get("messages", []):
            if isinstance(msg, AIMessage) and msg.id == original_message_id:
                existing_tool_calls = getattr(msg, "tool_calls", None) or []
                updated_messages.append(
                    AIMessage(
                        content=msg.content,
                        tool_calls=[*existing_tool_calls, *intercepted_tool_calls],
                        id=msg.id,
                    )
                )
            else:
                updated_messages.append(msg)

        return {
            "messages": updated_messages,
            "copilotkit": {
                "intercepted_tool_calls": None,
                "original_ai_message_id": None,
            },
        }

    async def aafter_agent(
        self,
        state: StateSchema,
        runtime: Runtime[Any],
    ) -> dict[str, Any] | None:
        # Delegate to sync implementation
        return self.after_agent(state, runtime)