From 0b6a41d000e05097d72dba2bf088dec41f964bc0 Mon Sep 17 00:00:00 2001 From: NGUYEN Duc Trung Date: Wed, 1 Apr 2026 13:09:20 +0200 Subject: [PATCH] feat: support custom tools in subagents All three SDKs (Go, Python, Node) use per-session tool handler lookup keyed to the exact session ID. When the CLI creates child sessions for subagents, those child session IDs are never registered in the SDK's sessions map. Tool calls arriving with a child session ID fail with "unknown session". This PR adds the ability to call custom tools for sub-agents in Go, following the proposal in https://github.com/github/copilot-sdk/issues/947 --- go/README.md | 52 ++ go/client.go | 260 ++++++- go/client_subagent_test.go | 653 ++++++++++++++++++ go/internal/e2e/subagent_tool_test.go | 140 ++++ go/session.go | 16 + ...ied_unlisted_tool_returns_unsupported.yaml | 18 + .../subagent_invokes_parent_custom_tool.yaml | 19 + 7 files changed, 1123 insertions(+), 35 deletions(-) create mode 100644 go/client_subagent_test.go create mode 100644 go/internal/e2e/subagent_tool_test.go create mode 100644 test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml create mode 100644 test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml diff --git a/go/README.md b/go/README.md index f29ef9fb7..3bea0be2c 100644 --- a/go/README.md +++ b/go/README.md @@ -359,6 +359,58 @@ safeLookup := copilot.DefineTool("safe_lookup", "A read-only lookup that needs n safeLookup.SkipPermission = true ``` +### Custom Tools with Subagents + +When a session is configured with both custom tools and custom agents (subagents), the +subagents can invoke the parent session's custom tools. The SDK automatically routes +tool calls from child sessions back to the parent session's tool handlers. + +#### Tool Access Control + +The `Tools` field on `CustomAgentConfig` controls which custom tools each subagent can access: + +| `Tools` value | Behavior | +|---------------|----------| +| `nil` (default) | Subagent can access **all** custom tools registered on the parent session | +| `[]string{}` (empty) | Subagent cannot access **any** custom tools | +| `[]string{"tool_a", "tool_b"}` | Subagent can only access the listed tools | + +#### Example + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + Tools: []copilot.Tool{ + copilot.DefineTool("save_output", "Saves output to storage", + func(params SaveParams, inv copilot.ToolInvocation) (string, error) { + // Handle tool call — works for both direct and subagent invocations + return saveToStorage(params.Content) + }), + copilot.DefineTool("get_data", "Retrieves data from storage", + func(params GetParams, inv copilot.ToolInvocation) (string, error) { + return getData(params.Key) + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "researcher", + Description: "Researches topics and saves findings", + Tools: []string{"save_output"}, // Can only use save_output, not get_data + Prompt: "You are a research assistant. Save your findings using save_output.", + }, + { + Name: "analyst", + Description: "Analyzes data from storage", + Tools: nil, // Can access ALL custom tools + Prompt: "You are a data analyst.", + }, + }, +}) +``` + +When `researcher` is invoked as a subagent, it can call `save_output` but not `get_data`. +When `analyst` is invoked, it can call both tools. If a subagent attempts to use a tool +not in its allowlist, the SDK returns a `"Tool '{name}' is not supported by this client instance."` response to the LLM. + ## Streaming Enable streaming to receive assistant response chunks as they're generated: diff --git a/go/client.go b/go/client.go index dbb5a3d8f..0601c2af5 100644 --- a/go/client.go +++ b/go/client.go @@ -53,6 +53,14 @@ import ( const noResultPermissionV2Error = "permission handlers cannot return 'no-result' when connected to a protocol v2 server" +// subagentInstance represents a single active subagent launch. +type subagentInstance struct { + agentName string + toolCallID string + childSessionID string // empty until child session ID is known + startedAt time.Time +} + // Client manages the connection to the Copilot CLI server and provides session management. // // The Client can either spawn a CLI server process or connect to an existing server. @@ -73,14 +81,30 @@ const noResultPermissionV2Error = "permission handlers cannot return 'no-result' // } // defer client.Stop() type Client struct { - options ClientOptions - process *exec.Cmd - client *jsonrpc2.Client - actualPort int - actualHost string - state ConnectionState - sessions map[string]*Session - sessionsMux sync.Mutex + options ClientOptions + process *exec.Cmd + client *jsonrpc2.Client + actualPort int + actualHost string + state ConnectionState + sessions map[string]*Session + sessionsMux sync.Mutex + + // childToParent maps childSessionID → parentSessionID. + // Populated exclusively from authoritative protocol signals. + // Protected by sessionsMux. + childToParent map[string]string + + // childToAgent maps childSessionID → agentName. + // Used for allowlist enforcement. Populated alongside childToParent. + // Protected by sessionsMux. + childToAgent map[string]string + + // subagentInstances tracks active subagent launches per parent session. + // Key: parentSessionID → map of toolCallID → subagentInstance. + // Protected by sessionsMux. + subagentInstances map[string]map[string]*subagentInstance + isExternalServer bool conn net.Conn // stores net.Conn for external TCP connections useStdio bool // resolved value from options @@ -127,13 +151,16 @@ func NewClient(options *ClientOptions) *Client { } client := &Client{ - options: opts, - state: StateDisconnected, - sessions: make(map[string]*Session), - actualHost: "localhost", - isExternalServer: false, - useStdio: true, - autoStart: true, // default + options: opts, + state: StateDisconnected, + sessions: make(map[string]*Session), + childToParent: make(map[string]string), + childToAgent: make(map[string]string), + subagentInstances: make(map[string]map[string]*subagentInstance), + actualHost: "localhost", + isExternalServer: false, + useStdio: true, + autoStart: true, // default } if options != nil { @@ -346,6 +373,9 @@ func (c *Client) Stop() error { c.sessionsMux.Lock() c.sessions = make(map[string]*Session) + c.childToParent = make(map[string]string) + c.childToAgent = make(map[string]string) + c.subagentInstances = make(map[string]map[string]*subagentInstance) c.sessionsMux.Unlock() c.startStopMux.Lock() @@ -586,6 +616,12 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.customAgents = config.CustomAgents + session.onDestroy = func() { + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked(session.SessionID) + c.sessionsMux.Unlock() + } session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) if config.OnUserInputRequest != nil { @@ -707,6 +743,12 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.customAgents = config.CustomAgents + session.onDestroy = func() { + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked(session.SessionID) + c.sessionsMux.Unlock() + } session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) if config.OnUserInputRequest != nil { @@ -860,6 +902,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { // Remove from local sessions map if present c.sessionsMux.Lock() delete(c.sessions, sessionID) + c.removeChildMappingsForParentLocked(sessionID) c.sessionsMux.Unlock() return nil @@ -1500,21 +1543,164 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { c.sessionsMux.Unlock() if ok { + // Intercept subagent lifecycle events for child tracking + c.handleSubagentEvent(req.SessionID, req.Event) session.dispatchEvent(req.Event) } } +// handleSubagentEvent intercepts subagent lifecycle events to manage child session tracking. +func (c *Client) handleSubagentEvent(parentSessionID string, event SessionEvent) { + switch event.Type { + case SessionEventTypeSubagentStarted: + c.onSubagentStarted(parentSessionID, event) + case SessionEventTypeSubagentCompleted, SessionEventTypeSubagentFailed: + c.onSubagentEnded(parentSessionID, event) + } +} + +// onSubagentStarted handles a subagent.started event by creating a subagent instance +// and mapping the child session to its parent. +func (c *Client) onSubagentStarted(parentSessionID string, event SessionEvent) { + toolCallID := derefStr(event.Data.ToolCallID) + agentName := derefStr(event.Data.AgentName) + childSessionID := derefStr(event.Data.RemoteSessionID) + + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + + // Track instance by toolCallID (unique per launch). + // Skip tracking when toolCallID is empty — multiple launches would collide + // on the empty-string key and overwrite each other. + if toolCallID != "" { + if c.subagentInstances[parentSessionID] == nil { + c.subagentInstances[parentSessionID] = make(map[string]*subagentInstance) + } + c.subagentInstances[parentSessionID][toolCallID] = &subagentInstance{ + agentName: agentName, + toolCallID: toolCallID, + childSessionID: childSessionID, + startedAt: event.Timestamp, + } + } + + // Eagerly map child→parent and child→agent + if childSessionID != "" { + c.childToParent[childSessionID] = parentSessionID + c.childToAgent[childSessionID] = agentName + } +} + +// onSubagentEnded handles subagent.completed and subagent.failed events +// by removing the subagent instance. Child-to-parent mappings are NOT removed +// here because in-flight requests may still arrive after the subagent completes. +func (c *Client) onSubagentEnded(parentSessionID string, event SessionEvent) { + toolCallID := derefStr(event.Data.ToolCallID) + + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + + if instances, ok := c.subagentInstances[parentSessionID]; ok { + delete(instances, toolCallID) + if len(instances) == 0 { + delete(c.subagentInstances, parentSessionID) + } + } +} + +// derefStr safely dereferences a string pointer, returning "" if nil. +func derefStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +// resolveSession looks up a session by ID. If the ID is not a directly +// registered session, it checks whether it is a known child session and +// returns the parent session instead. +// +// Returns (session, isChild, error). isChild=true means the request came +// from a child session and was resolved via parent lineage. +// +// Lock contract: acquires and releases sessionsMux internally. +// Does NOT hold sessionsMux when returning. +func (c *Client) resolveSession(sessionID string) (*Session, bool, error) { + c.sessionsMux.Lock() + // Direct lookup + if session, ok := c.sessions[sessionID]; ok { + c.sessionsMux.Unlock() + return session, false, nil + } + // Child→parent lookup (authoritative mapping only) + parentID, isChild := c.childToParent[sessionID] + if !isChild { + c.sessionsMux.Unlock() + return nil, false, fmt.Errorf("unknown session %s", sessionID) + } + session, ok := c.sessions[parentID] + c.sessionsMux.Unlock() + if !ok { + return nil, false, fmt.Errorf("parent session %s for child %s not found", parentID, sessionID) + } + return session, true, nil +} + +// removeChildMappingsForParentLocked removes all child mappings for a parent session. +// MUST be called with sessionsMux held. +func (c *Client) removeChildMappingsForParentLocked(parentSessionID string) { + for childID, parentID := range c.childToParent { + if parentID == parentSessionID { + delete(c.childToParent, childID) + delete(c.childToAgent, childID) + } + } + delete(c.subagentInstances, parentSessionID) +} + +// isToolAllowedForChild checks whether a tool is in the allowlist for the agent +// that owns the given child session. +func (c *Client) isToolAllowedForChild(childSessionID, toolName string) bool { + c.sessionsMux.Lock() + agentName, ok := c.childToAgent[childSessionID] + c.sessionsMux.Unlock() + if !ok { + return false // unknown child → deny + } + + session, _, _ := c.resolveSession(childSessionID) + if session == nil { + return false + } + + agentConfig := session.getAgentConfig(agentName) + if agentConfig == nil { + return false // agent not found → deny + } + + // nil Tools = all tools allowed + if agentConfig.Tools == nil { + return true + } + + // Explicit list — check membership + for _, t := range agentConfig.Tools { + if t == toolName { + return true + } + } + return false +} + // handleUserInputRequest handles a user input request from the CLI server. func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -1535,11 +1721,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -1610,11 +1794,19 @@ func (c *Client) handleToolCallRequestV2(req toolCallRequestV2) (*toolCallRespon return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, isChild, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + } + + // For child sessions, enforce tool allowlist + if isChild && !c.isToolAllowedForChild(req.SessionID, req.ToolName) { + return &toolCallResponseV2{Result: ToolResult{ + TextResultForLLM: fmt.Sprintf("Tool '%s' is not supported by this client instance.", req.ToolName), + ResultType: "failure", + Error: fmt.Sprintf("tool '%s' not supported", req.ToolName), + ToolTelemetry: map[string]any{}, + }}, nil } handler, ok := session.getToolHandler(req.ToolName) @@ -1656,11 +1848,9 @@ func (c *Client) handlePermissionRequestV2(req permissionRequestV2) (*permission return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } handler := session.getPermissionHandler() diff --git a/go/client_subagent_test.go b/go/client_subagent_test.go new file mode 100644 index 000000000..a368c44db --- /dev/null +++ b/go/client_subagent_test.go @@ -0,0 +1,653 @@ +package copilot + +import ( + "encoding/json" + "strings" + "sync" + "testing" + "time" +) + +// newTestClient creates a minimal test client with initialized maps. +func newTestClient() *Client { + return &Client{ + sessions: make(map[string]*Session), + childToParent: make(map[string]string), + childToAgent: make(map[string]string), + subagentInstances: make(map[string]map[string]*subagentInstance), + } +} + +// newSubagentTestSession creates a minimal test session with tools and agents. +func newSubagentTestSession(id string, tools []Tool, agents []CustomAgentConfig) *Session { + s := &Session{ + SessionID: id, + toolHandlers: make(map[string]ToolHandler), + customAgents: agents, + } + for _, t := range tools { + if t.Name != "" && t.Handler != nil { + s.toolHandlers[t.Name] = t.Handler + } + } + return s +} + +func strPtr(s string) *string { return &s } + +func testToolHandler(inv ToolInvocation) (ToolResult, error) { + return ToolResult{TextResultForLLM: "ok", ResultType: "success"}, nil +} + +// --------------------------------------------------------------------------- +// TestResolveSession +// --------------------------------------------------------------------------- + +func TestResolveSession(t *testing.T) { + t.Run("direct_session_returns_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + session, isChild, err := c.resolveSession("parent-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if isChild { + t.Fatal("expected isChild=false for direct session") + } + if session != parent { + t.Fatal("returned session does not match registered session") + } + }) + + t.Run("child_session_returns_parent", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + + session, isChild, err := c.resolveSession("child-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !isChild { + t.Fatal("expected isChild=true for child session") + } + if session != parent { + t.Fatal("returned session should be the parent session") + } + }) + + t.Run("unknown_session_returns_error", func(t *testing.T) { + c := newTestClient() + + session, isChild, err := c.resolveSession("nonexistent") + if err == nil { + t.Fatal("expected error for unknown session") + } + if !strings.Contains(err.Error(), "unknown session") { + t.Fatalf("error should contain 'unknown session', got: %v", err) + } + if isChild { + t.Fatal("expected isChild=false") + } + if session != nil { + t.Fatal("expected nil session") + } + }) + + t.Run("child_of_deleted_parent_returns_error", func(t *testing.T) { + c := newTestClient() + c.childToParent["child-1"] = "parent-1" + // parent-1 is NOT registered in c.sessions + + session, isChild, err := c.resolveSession("child-1") + if err == nil { + t.Fatal("expected error when parent session is missing") + } + if !strings.Contains(err.Error(), "parent session") { + t.Fatalf("error should contain 'parent session', got: %v", err) + } + if isChild { + t.Fatal("expected isChild=false on error path") + } + if session != nil { + t.Fatal("expected nil session") + } + }) +} + +// --------------------------------------------------------------------------- +// TestChildToolAllowlist +// --------------------------------------------------------------------------- + +func TestChildToolAllowlist(t *testing.T) { + setup := func(tools []string) *Client { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: tools}} + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "save_output", Handler: testToolHandler}, + {Name: "other_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + return c + } + + t.Run("nil_tools_allows_all", func(t *testing.T) { + c := setup(nil) // nil Tools = all allowed + if !c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("nil Tools should allow save_output") + } + if !c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("nil Tools should allow other_tool") + } + if !c.isToolAllowedForChild("child-1", "any_random_tool") { + t.Fatal("nil Tools should allow any tool") + } + }) + + t.Run("explicit_list_allows_listed_tool", func(t *testing.T) { + c := setup([]string{"save_output"}) + if !c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("save_output should be allowed") + } + }) + + t.Run("explicit_list_blocks_unlisted_tool", func(t *testing.T) { + c := setup([]string{"save_output"}) + if c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("other_tool should be blocked") + } + }) + + t.Run("empty_tools_blocks_all", func(t *testing.T) { + c := setup([]string{}) // empty = block all + if c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("empty Tools should block save_output") + } + if c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("empty Tools should block other_tool") + } + }) +} + +// --------------------------------------------------------------------------- +// TestSubagentInstanceTracking +// --------------------------------------------------------------------------- + +func TestSubagentInstanceTracking(t *testing.T) { + makeEvent := func(evType SessionEventType, toolCallID, agentName, childSessionID string) SessionEvent { + return SessionEvent{ + Type: evType, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(toolCallID), + AgentName: strPtr(agentName), + RemoteSessionID: strPtr(childSessionID), + }, + } + } + + t.Run("started_creates_instance", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + event := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-session-1") + c.onSubagentStarted("parent-1", event) + + // Verify subagentInstances + instances, ok := c.subagentInstances["parent-1"] + if !ok { + t.Fatal("expected subagentInstances entry for parent-1") + } + inst, ok := instances["tc-1"] + if !ok { + t.Fatal("expected instance with toolCallID tc-1") + } + if inst.agentName != "my-agent" { + t.Fatalf("expected agentName 'my-agent', got %q", inst.agentName) + } + if inst.childSessionID != "child-session-1" { + t.Fatalf("expected childSessionID 'child-session-1', got %q", inst.childSessionID) + } + + // Verify child mappings + if c.childToParent["child-session-1"] != "parent-1" { + t.Fatal("childToParent mapping not set") + } + if c.childToAgent["child-session-1"] != "my-agent" { + t.Fatal("childToAgent mapping not set") + } + }) + + t.Run("completed_removes_instance", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + startEvent := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-session-1") + c.onSubagentStarted("parent-1", startEvent) + + endEvent := makeEvent(SessionEventTypeSubagentCompleted, "tc-1", "my-agent", "child-session-1") + c.onSubagentEnded("parent-1", endEvent) + + // Instance removed + if instances, ok := c.subagentInstances["parent-1"]; ok && len(instances) > 0 { + t.Fatal("expected instance to be removed after completion") + } + + // Child mappings preserved for in-flight requests + if c.childToParent["child-session-1"] != "parent-1" { + t.Fatal("childToParent should be preserved after subagent completion") + } + if c.childToAgent["child-session-1"] != "my-agent" { + t.Fatal("childToAgent should be preserved after subagent completion") + } + }) + + t.Run("concurrent_same_agent_tracked_independently", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + // Two launches of the same agent with different toolCallIDs + event1 := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-1") + event2 := makeEvent(SessionEventTypeSubagentStarted, "tc-2", "my-agent", "child-2") + c.onSubagentStarted("parent-1", event1) + c.onSubagentStarted("parent-1", event2) + + instances := c.subagentInstances["parent-1"] + if len(instances) != 2 { + t.Fatalf("expected 2 instances, got %d", len(instances)) + } + + // Complete one + endEvent := makeEvent(SessionEventTypeSubagentCompleted, "tc-1", "my-agent", "child-1") + c.onSubagentEnded("parent-1", endEvent) + + instances = c.subagentInstances["parent-1"] + if len(instances) != 1 { + t.Fatalf("expected 1 instance remaining, got %d", len(instances)) + } + if _, ok := instances["tc-2"]; !ok { + t.Fatal("tc-2 should still be tracked") + } + }) +} + +// --------------------------------------------------------------------------- +// TestRequestHandlerResolution +// --------------------------------------------------------------------------- + +func TestRequestHandlerResolution(t *testing.T) { + t.Run("tool_call_resolves_child_session", func(t *testing.T) { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: nil}} // nil = all tools + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "my_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleToolCallRequestV2(toolCallRequestV2{ + SessionID: "child-1", + ToolCallID: "tc-1", + ToolName: "my_tool", + Arguments: map[string]any{}, + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Result.ResultType != "success" { + t.Fatalf("expected success result, got %q", resp.Result.ResultType) + } + if resp.Result.TextResultForLLM != "ok" { + t.Fatalf("expected 'ok', got %q", resp.Result.TextResultForLLM) + } + }) + + t.Run("permission_request_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.permissionHandler = func(req PermissionRequest, inv PermissionInvocation) (PermissionRequestResult, error) { + return PermissionRequestResult{Kind: "approved"}, nil + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handlePermissionRequestV2(permissionRequestV2{ + SessionID: "child-1", + Request: PermissionRequest{Kind: "file_write"}, + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Result.Kind != "approved" { + t.Fatalf("expected 'approved', got %q", resp.Result.Kind) + } + }) + + t.Run("user_input_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.userInputHandler = func(req UserInputRequest, inv UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "test-answer"}, nil + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleUserInputRequest(userInputRequest{ + SessionID: "child-1", + Question: "What is your name?", + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Answer != "test-answer" { + t.Fatalf("expected 'test-answer', got %q", resp.Answer) + } + }) + + t.Run("hooks_invoke_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.hooks = &SessionHooks{ + OnPreToolUse: func(input PreToolUseHookInput, inv HookInvocation) (*PreToolUseHookOutput, error) { + return &PreToolUseHookOutput{PermissionDecision: "allow"}, nil + }, + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + hookInput, _ := json.Marshal(PreToolUseHookInput{ + Timestamp: time.Now().Unix(), + Cwd: "/tmp", + ToolName: "some_tool", + }) + + result, rpcErr := c.handleHooksInvoke(hooksInvokeRequest{ + SessionID: "child-1", + Type: "preToolUse", + Input: json.RawMessage(hookInput), + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if result == nil { + t.Fatal("expected non-nil result") + } + if result["output"] == nil { + t.Fatal("expected output in result") + } + }) + + t.Run("tool_call_child_denied_tool_returns_unsupported", func(t *testing.T) { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: []string{"allowed_tool"}}} + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "allowed_tool", Handler: testToolHandler}, + {Name: "denied_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleToolCallRequestV2(toolCallRequestV2{ + SessionID: "child-1", + ToolCallID: "tc-1", + ToolName: "denied_tool", + Arguments: map[string]any{}, + }) + // Should NOT return an RPC error — returns an unsupported tool result instead + if rpcErr != nil { + t.Fatalf("should not return RPC error for denied tool, got: %v", rpcErr.Message) + } + if resp.Result.ResultType != "failure" { + t.Fatalf("expected failure result, got %q", resp.Result.ResultType) + } + if !strings.Contains(resp.Result.TextResultForLLM, "not supported") { + t.Fatalf("expected 'not supported' message, got %q", resp.Result.TextResultForLLM) + } + }) +} + +// --------------------------------------------------------------------------- +// TestCleanup +// --------------------------------------------------------------------------- + +func TestCleanup(t *testing.T) { + t.Run("stop_clears_all_maps", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + c.subagentInstances["parent-1"] = map[string]*subagentInstance{ + "tc-1": {agentName: "test-agent", toolCallID: "tc-1"}, + } + + // Simulate cleanup (Stop() does RPC + map clearing; we test removeChildMappingsForParentLocked + manual clear) + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked("parent-1") + delete(c.sessions, "parent-1") + c.sessionsMux.Unlock() + + if len(c.childToParent) != 0 { + t.Fatal("childToParent should be empty") + } + if len(c.childToAgent) != 0 { + t.Fatal("childToAgent should be empty") + } + if len(c.subagentInstances) != 0 { + t.Fatal("subagentInstances should be empty") + } + if len(c.sessions) != 0 { + t.Fatal("sessions should be empty") + } + }) + + t.Run("delete_session_clears_only_target_children", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + parentB := newSubagentTestSession("parent-B", nil, nil) + c.sessions["parent-A"] = parentA + c.sessions["parent-B"] = parentB + c.childToParent["child-A1"] = "parent-A" + c.childToParent["child-A2"] = "parent-A" + c.childToParent["child-B1"] = "parent-B" + c.childToAgent["child-A1"] = "agent-a" + c.childToAgent["child-A2"] = "agent-a" + c.childToAgent["child-B1"] = "agent-b" + c.subagentInstances["parent-A"] = map[string]*subagentInstance{ + "tc-a1": {agentName: "agent-a"}, + } + c.subagentInstances["parent-B"] = map[string]*subagentInstance{ + "tc-b1": {agentName: "agent-b"}, + } + + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked("parent-A") + c.sessionsMux.Unlock() + + // parent-A children removed + if _, ok := c.childToParent["child-A1"]; ok { + t.Fatal("child-A1 should be removed") + } + if _, ok := c.childToParent["child-A2"]; ok { + t.Fatal("child-A2 should be removed") + } + if _, ok := c.subagentInstances["parent-A"]; ok { + t.Fatal("parent-A subagentInstances should be removed") + } + + // parent-B children intact + if c.childToParent["child-B1"] != "parent-B" { + t.Fatal("child-B1 mapping should still exist") + } + if c.childToAgent["child-B1"] != "agent-b" { + t.Fatal("child-B1 agent mapping should still exist") + } + if _, ok := c.subagentInstances["parent-B"]; !ok { + t.Fatal("parent-B subagentInstances should still exist") + } + }) + + t.Run("destroy_session_clears_children_via_callback", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + c.subagentInstances["parent-1"] = map[string]*subagentInstance{ + "tc-1": {agentName: "test-agent"}, + } + + // Set up onDestroy callback (mirrors Client's real onDestroy which only clears child mappings) + parent.onDestroy = func() { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + c.removeChildMappingsForParentLocked("parent-1") + } + + // Call onDestroy + parent.onDestroy() + + if len(c.childToParent) != 0 { + t.Fatal("childToParent should be cleared by onDestroy") + } + if len(c.childToAgent) != 0 { + t.Fatal("childToAgent should be cleared by onDestroy") + } + if len(c.subagentInstances) != 0 { + t.Fatal("subagentInstances should be cleared by onDestroy") + } + // Session itself is NOT removed by onDestroy (that's Destroy()'s job via RPC) + if _, ok := c.sessions["parent-1"]; !ok { + t.Fatal("session should still exist after onDestroy (only child mappings cleared)") + } + }) +} + +// --------------------------------------------------------------------------- +// TestSessionIsolation +// --------------------------------------------------------------------------- + +func TestSessionIsolation(t *testing.T) { + t.Run("child_cannot_reach_other_parent", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + parentB := newSubagentTestSession("parent-B", nil, nil) + c.sessions["parent-A"] = parentA + c.sessions["parent-B"] = parentB + c.childToParent["child-A"] = "parent-A" + + session, isChild, err := c.resolveSession("child-A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !isChild { + t.Fatal("expected isChild=true") + } + if session != parentA { + t.Fatal("child-A should resolve to parent-A, not parent-B") + } + if session == parentB { + t.Fatal("child-A must not resolve to parent-B") + } + }) + + t.Run("child_session_id_immutable_mapping", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + c.sessions["parent-A"] = parentA + c.childToParent["child-1"] = "parent-A" + + // Resolve multiple times — always gets parent-A + for i := 0; i < 5; i++ { + session, isChild, err := c.resolveSession("child-1") + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + if !isChild { + t.Fatalf("iteration %d: expected isChild=true", i) + } + if session != parentA { + t.Fatalf("iteration %d: mapping should consistently resolve to parent-A", i) + } + } + }) +} + +// --------------------------------------------------------------------------- +// TestConcurrency +// --------------------------------------------------------------------------- + +func TestConcurrency(t *testing.T) { + t.Run("concurrent_resolve_session_safe", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent", nil, nil) + c.sessions["parent"] = parent + c.childToParent["child-1"] = "parent" + c.childToAgent["child-1"] = "agent" + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c.resolveSession("parent") + c.resolveSession("child-1") + c.resolveSession("nonexistent") + }() + } + wg.Wait() + }) + + t.Run("concurrent_subagent_events_safe", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent", nil, nil) + c.sessions["parent"] = parent + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + tcID := "tc-" + strings.Repeat("x", idx%10) + childID := "child-" + strings.Repeat("x", idx%10) + event := SessionEvent{ + Type: SessionEventTypeSubagentStarted, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(tcID), + AgentName: strPtr("agent"), + RemoteSessionID: strPtr(childID), + }, + } + c.handleSubagentEvent("parent", event) + + // Also try resolving concurrently + c.resolveSession(childID) + c.isToolAllowedForChild(childID, "some_tool") + + endEvent := SessionEvent{ + Type: SessionEventTypeSubagentCompleted, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(tcID), + }, + } + c.handleSubagentEvent("parent", endEvent) + }(i) + } + wg.Wait() + }) +} diff --git a/go/internal/e2e/subagent_tool_test.go b/go/internal/e2e/subagent_tool_test.go new file mode 100644 index 000000000..08efa5746 --- /dev/null +++ b/go/internal/e2e/subagent_tool_test.go @@ -0,0 +1,140 @@ +//go:build integration + +package e2e + +import ( + "strings" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// TestSubagentCustomTools requires a real CLI to test the full round-trip of +// subagent child sessions invoking custom tools registered on parent sessions. +// +// Run with: +// +// cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +// +// Prerequisites: +// - Copilot CLI installed (or COPILOT_CLI_PATH set) +// - Valid GitHub authentication configured +func TestSubagentCustomTools(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("subagent invokes parent custom tool", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // Track tool invocations + toolInvoked := make(chan string, 1) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Tools: []copilot.Tool{ + copilot.DefineTool("save_result", "Saves a result string", + func(params struct { + Result string `json:"result" jsonschema:"The result to save"` + }, inv copilot.ToolInvocation) (string, error) { + select { + case toolInvoked <- params.Result: + default: + } + return "saved: " + params.Result, nil + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "helper-agent", + DisplayName: "Helper Agent", + Description: "A helper agent that can save results using the save_result tool", + Tools: []string{"save_result"}, + Prompt: "You are a helper agent. When asked to save something, use the save_result tool.", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Send a message that should trigger the subagent which invokes the custom tool + _, err = session.Send(t.Context(), copilot.MessageOptions{ + Prompt: "Use the helper-agent to save the result 'hello world'", + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + // Wait for the tool to be invoked (with timeout) + select { + case result := <-toolInvoked: + if !strings.Contains(strings.ToLower(result), "hello world") { + t.Errorf("Expected tool to receive 'hello world', got %q", result) + } + case <-time.After(30 * time.Second): + t.Fatal("Timeout waiting for save_result tool invocation from subagent") + } + + // Get the final response + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if answer.Data.Content == nil { + t.Fatal("Expected non-nil content in response") + } + t.Logf("Response: %s", *answer.Data.Content) + }) + + t.Run("subagent denied unlisted tool returns unsupported", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Tools: []copilot.Tool{ + copilot.DefineTool("allowed_tool", "An allowed tool", + func(params struct{}, inv copilot.ToolInvocation) (string, error) { + return "allowed", nil + }), + copilot.DefineTool("restricted_tool", "A restricted tool", + func(params struct{}, inv copilot.ToolInvocation) (string, error) { + t.Error("restricted_tool should not be invoked by subagent") + return "should not reach here", nil + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "restricted-agent", + DisplayName: "Restricted Agent", + Description: "An agent with limited tool access", + Tools: []string{"allowed_tool"}, // restricted_tool NOT listed + Prompt: "You are a restricted agent. Try to use both allowed_tool and restricted_tool.", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, err = session.Send(t.Context(), copilot.MessageOptions{ + Prompt: "Use the restricted-agent to invoke restricted_tool", + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if answer.Data.Content == nil { + t.Fatal("Expected assistant message content but got nil") + } + t.Logf("Response: %s", *answer.Data.Content) + // The primary assertion is the t.Error inside restricted_tool's handler above. + // We don't assert on response text because LLM output is non-deterministic. + }) +} diff --git a/go/session.go b/go/session.go index 5be626b52..28015e17c 100644 --- a/go/session.go +++ b/go/session.go @@ -66,6 +66,8 @@ type Session struct { hooksMux sync.RWMutex transformCallbacks map[string]SectionTransformFn transformMu sync.Mutex + onDestroy func() // set by Client when session is created; called by Disconnect() (Destroy() delegates to Disconnect()) + customAgents []CustomAgentConfig // agent configs from SessionConfig // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. @@ -83,6 +85,16 @@ func (s *Session) WorkspacePath() string { return s.workspacePath } +// getAgentConfig returns the CustomAgentConfig for the given agent name, or nil if not found. +func (s *Session) getAgentConfig(agentName string) *CustomAgentConfig { + for i := range s.customAgents { + if s.customAgents[i].Name == agentName { + return &s.customAgents[i] + } + } + return nil +} + // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ @@ -748,6 +760,10 @@ func (s *Session) Disconnect() error { s.permissionHandler = nil s.permissionMux.Unlock() + if s.onDestroy != nil { + s.onDestroy() + } + return nil } diff --git a/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml b/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml new file mode 100644 index 000000000..b963a2a1b --- /dev/null +++ b/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml @@ -0,0 +1,18 @@ +# Placeholder snapshot for subagent denied tool E2E test. +# This snapshot needs to be captured from a real CLI session. +# +# To capture: +# 1. Ensure COPILOT_CLI_PATH is set or CLI is installed +# 2. Run: cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +# 3. The proxy will capture the exchanges and write this file +# +# Expected flow: +# 1. Parent creates session with allowed_tool and restricted_tool, plus restricted-agent +# 2. restricted-agent only has allowed_tool in its Tools allowlist +# 3. User asks restricted-agent to invoke restricted_tool +# 4. CLI sends tool.call with child session ID for restricted_tool +# 5. SDK resolves child→parent, checks allowlist, returns "unsupported" error +# 6. restricted_tool handler is never invoked +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml b/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml new file mode 100644 index 000000000..c544f9ac1 --- /dev/null +++ b/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml @@ -0,0 +1,19 @@ +# Placeholder snapshot for subagent custom tool E2E test. +# This snapshot needs to be captured from a real CLI session. +# +# To capture: +# 1. Ensure COPILOT_CLI_PATH is set or CLI is installed +# 2. Run: cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +# 3. The proxy will capture the exchanges and write this file +# +# Expected flow: +# 1. Parent creates session with save_result tool and helper-agent subagent +# 2. User sends message triggering helper-agent +# 3. CLI creates child session, emits subagent.started with RemoteSessionID +# 4. Child LLM invokes save_result tool +# 5. CLI sends tool.call with child session ID +# 6. SDK resolves child→parent, dispatches to save_result handler +# 7. Tool result returned, subagent completes +models: + - claude-sonnet-4.5 +conversations: []