diff --git a/go/README.md b/go/README.md index f29ef9fb7..3bea0be2c 100644 --- a/go/README.md +++ b/go/README.md @@ -359,6 +359,58 @@ safeLookup := copilot.DefineTool("safe_lookup", "A read-only lookup that needs n safeLookup.SkipPermission = true ``` +### Custom Tools with Subagents + +When a session is configured with both custom tools and custom agents (subagents), the +subagents can invoke the parent session's custom tools. The SDK automatically routes +tool calls from child sessions back to the parent session's tool handlers. + +#### Tool Access Control + +The `Tools` field on `CustomAgentConfig` controls which custom tools each subagent can access: + +| `Tools` value | Behavior | +|---------------|----------| +| `nil` (default) | Subagent can access **all** custom tools registered on the parent session | +| `[]string{}` (empty) | Subagent cannot access **any** custom tools | +| `[]string{"tool_a", "tool_b"}` | Subagent can only access the listed tools | + +#### Example + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + Tools: []copilot.Tool{ + copilot.DefineTool("save_output", "Saves output to storage", + func(params SaveParams, inv copilot.ToolInvocation) (string, error) { + // Handle tool call — works for both direct and subagent invocations + return saveToStorage(params.Content) + }), + copilot.DefineTool("get_data", "Retrieves data from storage", + func(params GetParams, inv copilot.ToolInvocation) (string, error) { + return getData(params.Key) + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "researcher", + Description: "Researches topics and saves findings", + Tools: []string{"save_output"}, // Can only use save_output, not get_data + Prompt: "You are a research assistant. Save your findings using save_output.", + }, + { + Name: "analyst", + Description: "Analyzes data from storage", + Tools: nil, // Can access ALL custom tools + Prompt: "You are a data analyst.", + }, + }, +}) +``` + +When `researcher` is invoked as a subagent, it can call `save_output` but not `get_data`. +When `analyst` is invoked, it can call both tools. If a subagent attempts to use a tool +not in its allowlist, the SDK returns a `"Tool '{name}' is not supported by this client instance."` response to the LLM. + ## Streaming Enable streaming to receive assistant response chunks as they're generated: diff --git a/go/client.go b/go/client.go index dbb5a3d8f..0601c2af5 100644 --- a/go/client.go +++ b/go/client.go @@ -53,6 +53,14 @@ import ( const noResultPermissionV2Error = "permission handlers cannot return 'no-result' when connected to a protocol v2 server" +// subagentInstance represents a single active subagent launch. +type subagentInstance struct { + agentName string + toolCallID string + childSessionID string // empty until child session ID is known + startedAt time.Time +} + // Client manages the connection to the Copilot CLI server and provides session management. // // The Client can either spawn a CLI server process or connect to an existing server. @@ -73,14 +81,30 @@ const noResultPermissionV2Error = "permission handlers cannot return 'no-result' // } // defer client.Stop() type Client struct { - options ClientOptions - process *exec.Cmd - client *jsonrpc2.Client - actualPort int - actualHost string - state ConnectionState - sessions map[string]*Session - sessionsMux sync.Mutex + options ClientOptions + process *exec.Cmd + client *jsonrpc2.Client + actualPort int + actualHost string + state ConnectionState + sessions map[string]*Session + sessionsMux sync.Mutex + + // childToParent maps childSessionID → parentSessionID. + // Populated exclusively from authoritative protocol signals. + // Protected by sessionsMux. + childToParent map[string]string + + // childToAgent maps childSessionID → agentName. + // Used for allowlist enforcement. Populated alongside childToParent. + // Protected by sessionsMux. + childToAgent map[string]string + + // subagentInstances tracks active subagent launches per parent session. + // Key: parentSessionID → map of toolCallID → subagentInstance. + // Protected by sessionsMux. + subagentInstances map[string]map[string]*subagentInstance + isExternalServer bool conn net.Conn // stores net.Conn for external TCP connections useStdio bool // resolved value from options @@ -127,13 +151,16 @@ func NewClient(options *ClientOptions) *Client { } client := &Client{ - options: opts, - state: StateDisconnected, - sessions: make(map[string]*Session), - actualHost: "localhost", - isExternalServer: false, - useStdio: true, - autoStart: true, // default + options: opts, + state: StateDisconnected, + sessions: make(map[string]*Session), + childToParent: make(map[string]string), + childToAgent: make(map[string]string), + subagentInstances: make(map[string]map[string]*subagentInstance), + actualHost: "localhost", + isExternalServer: false, + useStdio: true, + autoStart: true, // default } if options != nil { @@ -346,6 +373,9 @@ func (c *Client) Stop() error { c.sessionsMux.Lock() c.sessions = make(map[string]*Session) + c.childToParent = make(map[string]string) + c.childToAgent = make(map[string]string) + c.subagentInstances = make(map[string]map[string]*subagentInstance) c.sessionsMux.Unlock() c.startStopMux.Lock() @@ -586,6 +616,12 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.customAgents = config.CustomAgents + session.onDestroy = func() { + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked(session.SessionID) + c.sessionsMux.Unlock() + } session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) if config.OnUserInputRequest != nil { @@ -707,6 +743,12 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, // events emitted by the CLI (e.g. session.start) are not dropped. session := newSession(sessionID, c.client, "") + session.customAgents = config.CustomAgents + session.onDestroy = func() { + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked(session.SessionID) + c.sessionsMux.Unlock() + } session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) if config.OnUserInputRequest != nil { @@ -860,6 +902,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { // Remove from local sessions map if present c.sessionsMux.Lock() delete(c.sessions, sessionID) + c.removeChildMappingsForParentLocked(sessionID) c.sessionsMux.Unlock() return nil @@ -1500,21 +1543,164 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { c.sessionsMux.Unlock() if ok { + // Intercept subagent lifecycle events for child tracking + c.handleSubagentEvent(req.SessionID, req.Event) session.dispatchEvent(req.Event) } } +// handleSubagentEvent intercepts subagent lifecycle events to manage child session tracking. +func (c *Client) handleSubagentEvent(parentSessionID string, event SessionEvent) { + switch event.Type { + case SessionEventTypeSubagentStarted: + c.onSubagentStarted(parentSessionID, event) + case SessionEventTypeSubagentCompleted, SessionEventTypeSubagentFailed: + c.onSubagentEnded(parentSessionID, event) + } +} + +// onSubagentStarted handles a subagent.started event by creating a subagent instance +// and mapping the child session to its parent. +func (c *Client) onSubagentStarted(parentSessionID string, event SessionEvent) { + toolCallID := derefStr(event.Data.ToolCallID) + agentName := derefStr(event.Data.AgentName) + childSessionID := derefStr(event.Data.RemoteSessionID) + + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + + // Track instance by toolCallID (unique per launch). + // Skip tracking when toolCallID is empty — multiple launches would collide + // on the empty-string key and overwrite each other. + if toolCallID != "" { + if c.subagentInstances[parentSessionID] == nil { + c.subagentInstances[parentSessionID] = make(map[string]*subagentInstance) + } + c.subagentInstances[parentSessionID][toolCallID] = &subagentInstance{ + agentName: agentName, + toolCallID: toolCallID, + childSessionID: childSessionID, + startedAt: event.Timestamp, + } + } + + // Eagerly map child→parent and child→agent + if childSessionID != "" { + c.childToParent[childSessionID] = parentSessionID + c.childToAgent[childSessionID] = agentName + } +} + +// onSubagentEnded handles subagent.completed and subagent.failed events +// by removing the subagent instance. Child-to-parent mappings are NOT removed +// here because in-flight requests may still arrive after the subagent completes. +func (c *Client) onSubagentEnded(parentSessionID string, event SessionEvent) { + toolCallID := derefStr(event.Data.ToolCallID) + + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + + if instances, ok := c.subagentInstances[parentSessionID]; ok { + delete(instances, toolCallID) + if len(instances) == 0 { + delete(c.subagentInstances, parentSessionID) + } + } +} + +// derefStr safely dereferences a string pointer, returning "" if nil. +func derefStr(s *string) string { + if s == nil { + return "" + } + return *s +} + +// resolveSession looks up a session by ID. If the ID is not a directly +// registered session, it checks whether it is a known child session and +// returns the parent session instead. +// +// Returns (session, isChild, error). isChild=true means the request came +// from a child session and was resolved via parent lineage. +// +// Lock contract: acquires and releases sessionsMux internally. +// Does NOT hold sessionsMux when returning. +func (c *Client) resolveSession(sessionID string) (*Session, bool, error) { + c.sessionsMux.Lock() + // Direct lookup + if session, ok := c.sessions[sessionID]; ok { + c.sessionsMux.Unlock() + return session, false, nil + } + // Child→parent lookup (authoritative mapping only) + parentID, isChild := c.childToParent[sessionID] + if !isChild { + c.sessionsMux.Unlock() + return nil, false, fmt.Errorf("unknown session %s", sessionID) + } + session, ok := c.sessions[parentID] + c.sessionsMux.Unlock() + if !ok { + return nil, false, fmt.Errorf("parent session %s for child %s not found", parentID, sessionID) + } + return session, true, nil +} + +// removeChildMappingsForParentLocked removes all child mappings for a parent session. +// MUST be called with sessionsMux held. +func (c *Client) removeChildMappingsForParentLocked(parentSessionID string) { + for childID, parentID := range c.childToParent { + if parentID == parentSessionID { + delete(c.childToParent, childID) + delete(c.childToAgent, childID) + } + } + delete(c.subagentInstances, parentSessionID) +} + +// isToolAllowedForChild checks whether a tool is in the allowlist for the agent +// that owns the given child session. +func (c *Client) isToolAllowedForChild(childSessionID, toolName string) bool { + c.sessionsMux.Lock() + agentName, ok := c.childToAgent[childSessionID] + c.sessionsMux.Unlock() + if !ok { + return false // unknown child → deny + } + + session, _, _ := c.resolveSession(childSessionID) + if session == nil { + return false + } + + agentConfig := session.getAgentConfig(agentName) + if agentConfig == nil { + return false // agent not found → deny + } + + // nil Tools = all tools allowed + if agentConfig.Tools == nil { + return true + } + + // Explicit list — check membership + for _, t := range agentConfig.Tools { + if t == toolName { + return true + } + } + return false +} + // handleUserInputRequest handles a user input request from the CLI server. func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -1535,11 +1721,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -1610,11 +1794,19 @@ func (c *Client) handleToolCallRequestV2(req toolCallRequestV2) (*toolCallRespon return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, isChild, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + } + + // For child sessions, enforce tool allowlist + if isChild && !c.isToolAllowedForChild(req.SessionID, req.ToolName) { + return &toolCallResponseV2{Result: ToolResult{ + TextResultForLLM: fmt.Sprintf("Tool '%s' is not supported by this client instance.", req.ToolName), + ResultType: "failure", + Error: fmt.Sprintf("tool '%s' not supported", req.ToolName), + ToolTelemetry: map[string]any{}, + }}, nil } handler, ok := session.getToolHandler(req.ToolName) @@ -1656,11 +1848,9 @@ func (c *Client) handlePermissionRequestV2(req permissionRequestV2) (*permission return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, _, err := c.resolveSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } handler := session.getPermissionHandler() diff --git a/go/client_subagent_test.go b/go/client_subagent_test.go new file mode 100644 index 000000000..a368c44db --- /dev/null +++ b/go/client_subagent_test.go @@ -0,0 +1,653 @@ +package copilot + +import ( + "encoding/json" + "strings" + "sync" + "testing" + "time" +) + +// newTestClient creates a minimal test client with initialized maps. +func newTestClient() *Client { + return &Client{ + sessions: make(map[string]*Session), + childToParent: make(map[string]string), + childToAgent: make(map[string]string), + subagentInstances: make(map[string]map[string]*subagentInstance), + } +} + +// newSubagentTestSession creates a minimal test session with tools and agents. +func newSubagentTestSession(id string, tools []Tool, agents []CustomAgentConfig) *Session { + s := &Session{ + SessionID: id, + toolHandlers: make(map[string]ToolHandler), + customAgents: agents, + } + for _, t := range tools { + if t.Name != "" && t.Handler != nil { + s.toolHandlers[t.Name] = t.Handler + } + } + return s +} + +func strPtr(s string) *string { return &s } + +func testToolHandler(inv ToolInvocation) (ToolResult, error) { + return ToolResult{TextResultForLLM: "ok", ResultType: "success"}, nil +} + +// --------------------------------------------------------------------------- +// TestResolveSession +// --------------------------------------------------------------------------- + +func TestResolveSession(t *testing.T) { + t.Run("direct_session_returns_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + session, isChild, err := c.resolveSession("parent-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if isChild { + t.Fatal("expected isChild=false for direct session") + } + if session != parent { + t.Fatal("returned session does not match registered session") + } + }) + + t.Run("child_session_returns_parent", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + + session, isChild, err := c.resolveSession("child-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !isChild { + t.Fatal("expected isChild=true for child session") + } + if session != parent { + t.Fatal("returned session should be the parent session") + } + }) + + t.Run("unknown_session_returns_error", func(t *testing.T) { + c := newTestClient() + + session, isChild, err := c.resolveSession("nonexistent") + if err == nil { + t.Fatal("expected error for unknown session") + } + if !strings.Contains(err.Error(), "unknown session") { + t.Fatalf("error should contain 'unknown session', got: %v", err) + } + if isChild { + t.Fatal("expected isChild=false") + } + if session != nil { + t.Fatal("expected nil session") + } + }) + + t.Run("child_of_deleted_parent_returns_error", func(t *testing.T) { + c := newTestClient() + c.childToParent["child-1"] = "parent-1" + // parent-1 is NOT registered in c.sessions + + session, isChild, err := c.resolveSession("child-1") + if err == nil { + t.Fatal("expected error when parent session is missing") + } + if !strings.Contains(err.Error(), "parent session") { + t.Fatalf("error should contain 'parent session', got: %v", err) + } + if isChild { + t.Fatal("expected isChild=false on error path") + } + if session != nil { + t.Fatal("expected nil session") + } + }) +} + +// --------------------------------------------------------------------------- +// TestChildToolAllowlist +// --------------------------------------------------------------------------- + +func TestChildToolAllowlist(t *testing.T) { + setup := func(tools []string) *Client { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: tools}} + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "save_output", Handler: testToolHandler}, + {Name: "other_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + return c + } + + t.Run("nil_tools_allows_all", func(t *testing.T) { + c := setup(nil) // nil Tools = all allowed + if !c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("nil Tools should allow save_output") + } + if !c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("nil Tools should allow other_tool") + } + if !c.isToolAllowedForChild("child-1", "any_random_tool") { + t.Fatal("nil Tools should allow any tool") + } + }) + + t.Run("explicit_list_allows_listed_tool", func(t *testing.T) { + c := setup([]string{"save_output"}) + if !c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("save_output should be allowed") + } + }) + + t.Run("explicit_list_blocks_unlisted_tool", func(t *testing.T) { + c := setup([]string{"save_output"}) + if c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("other_tool should be blocked") + } + }) + + t.Run("empty_tools_blocks_all", func(t *testing.T) { + c := setup([]string{}) // empty = block all + if c.isToolAllowedForChild("child-1", "save_output") { + t.Fatal("empty Tools should block save_output") + } + if c.isToolAllowedForChild("child-1", "other_tool") { + t.Fatal("empty Tools should block other_tool") + } + }) +} + +// --------------------------------------------------------------------------- +// TestSubagentInstanceTracking +// --------------------------------------------------------------------------- + +func TestSubagentInstanceTracking(t *testing.T) { + makeEvent := func(evType SessionEventType, toolCallID, agentName, childSessionID string) SessionEvent { + return SessionEvent{ + Type: evType, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(toolCallID), + AgentName: strPtr(agentName), + RemoteSessionID: strPtr(childSessionID), + }, + } + } + + t.Run("started_creates_instance", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + event := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-session-1") + c.onSubagentStarted("parent-1", event) + + // Verify subagentInstances + instances, ok := c.subagentInstances["parent-1"] + if !ok { + t.Fatal("expected subagentInstances entry for parent-1") + } + inst, ok := instances["tc-1"] + if !ok { + t.Fatal("expected instance with toolCallID tc-1") + } + if inst.agentName != "my-agent" { + t.Fatalf("expected agentName 'my-agent', got %q", inst.agentName) + } + if inst.childSessionID != "child-session-1" { + t.Fatalf("expected childSessionID 'child-session-1', got %q", inst.childSessionID) + } + + // Verify child mappings + if c.childToParent["child-session-1"] != "parent-1" { + t.Fatal("childToParent mapping not set") + } + if c.childToAgent["child-session-1"] != "my-agent" { + t.Fatal("childToAgent mapping not set") + } + }) + + t.Run("completed_removes_instance", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + startEvent := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-session-1") + c.onSubagentStarted("parent-1", startEvent) + + endEvent := makeEvent(SessionEventTypeSubagentCompleted, "tc-1", "my-agent", "child-session-1") + c.onSubagentEnded("parent-1", endEvent) + + // Instance removed + if instances, ok := c.subagentInstances["parent-1"]; ok && len(instances) > 0 { + t.Fatal("expected instance to be removed after completion") + } + + // Child mappings preserved for in-flight requests + if c.childToParent["child-session-1"] != "parent-1" { + t.Fatal("childToParent should be preserved after subagent completion") + } + if c.childToAgent["child-session-1"] != "my-agent" { + t.Fatal("childToAgent should be preserved after subagent completion") + } + }) + + t.Run("concurrent_same_agent_tracked_independently", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + + // Two launches of the same agent with different toolCallIDs + event1 := makeEvent(SessionEventTypeSubagentStarted, "tc-1", "my-agent", "child-1") + event2 := makeEvent(SessionEventTypeSubagentStarted, "tc-2", "my-agent", "child-2") + c.onSubagentStarted("parent-1", event1) + c.onSubagentStarted("parent-1", event2) + + instances := c.subagentInstances["parent-1"] + if len(instances) != 2 { + t.Fatalf("expected 2 instances, got %d", len(instances)) + } + + // Complete one + endEvent := makeEvent(SessionEventTypeSubagentCompleted, "tc-1", "my-agent", "child-1") + c.onSubagentEnded("parent-1", endEvent) + + instances = c.subagentInstances["parent-1"] + if len(instances) != 1 { + t.Fatalf("expected 1 instance remaining, got %d", len(instances)) + } + if _, ok := instances["tc-2"]; !ok { + t.Fatal("tc-2 should still be tracked") + } + }) +} + +// --------------------------------------------------------------------------- +// TestRequestHandlerResolution +// --------------------------------------------------------------------------- + +func TestRequestHandlerResolution(t *testing.T) { + t.Run("tool_call_resolves_child_session", func(t *testing.T) { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: nil}} // nil = all tools + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "my_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleToolCallRequestV2(toolCallRequestV2{ + SessionID: "child-1", + ToolCallID: "tc-1", + ToolName: "my_tool", + Arguments: map[string]any{}, + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Result.ResultType != "success" { + t.Fatalf("expected success result, got %q", resp.Result.ResultType) + } + if resp.Result.TextResultForLLM != "ok" { + t.Fatalf("expected 'ok', got %q", resp.Result.TextResultForLLM) + } + }) + + t.Run("permission_request_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.permissionHandler = func(req PermissionRequest, inv PermissionInvocation) (PermissionRequestResult, error) { + return PermissionRequestResult{Kind: "approved"}, nil + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handlePermissionRequestV2(permissionRequestV2{ + SessionID: "child-1", + Request: PermissionRequest{Kind: "file_write"}, + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Result.Kind != "approved" { + t.Fatalf("expected 'approved', got %q", resp.Result.Kind) + } + }) + + t.Run("user_input_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.userInputHandler = func(req UserInputRequest, inv UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "test-answer"}, nil + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleUserInputRequest(userInputRequest{ + SessionID: "child-1", + Question: "What is your name?", + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if resp.Answer != "test-answer" { + t.Fatalf("expected 'test-answer', got %q", resp.Answer) + } + }) + + t.Run("hooks_invoke_resolves_child_session", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + parent.hooks = &SessionHooks{ + OnPreToolUse: func(input PreToolUseHookInput, inv HookInvocation) (*PreToolUseHookOutput, error) { + return &PreToolUseHookOutput{PermissionDecision: "allow"}, nil + }, + } + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + hookInput, _ := json.Marshal(PreToolUseHookInput{ + Timestamp: time.Now().Unix(), + Cwd: "/tmp", + ToolName: "some_tool", + }) + + result, rpcErr := c.handleHooksInvoke(hooksInvokeRequest{ + SessionID: "child-1", + Type: "preToolUse", + Input: json.RawMessage(hookInput), + }) + if rpcErr != nil { + t.Fatalf("unexpected RPC error: %v", rpcErr.Message) + } + if result == nil { + t.Fatal("expected non-nil result") + } + if result["output"] == nil { + t.Fatal("expected output in result") + } + }) + + t.Run("tool_call_child_denied_tool_returns_unsupported", func(t *testing.T) { + c := newTestClient() + agents := []CustomAgentConfig{{Name: "test-agent", Tools: []string{"allowed_tool"}}} + parent := newSubagentTestSession("parent-1", []Tool{ + {Name: "allowed_tool", Handler: testToolHandler}, + {Name: "denied_tool", Handler: testToolHandler}, + }, agents) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + + resp, rpcErr := c.handleToolCallRequestV2(toolCallRequestV2{ + SessionID: "child-1", + ToolCallID: "tc-1", + ToolName: "denied_tool", + Arguments: map[string]any{}, + }) + // Should NOT return an RPC error — returns an unsupported tool result instead + if rpcErr != nil { + t.Fatalf("should not return RPC error for denied tool, got: %v", rpcErr.Message) + } + if resp.Result.ResultType != "failure" { + t.Fatalf("expected failure result, got %q", resp.Result.ResultType) + } + if !strings.Contains(resp.Result.TextResultForLLM, "not supported") { + t.Fatalf("expected 'not supported' message, got %q", resp.Result.TextResultForLLM) + } + }) +} + +// --------------------------------------------------------------------------- +// TestCleanup +// --------------------------------------------------------------------------- + +func TestCleanup(t *testing.T) { + t.Run("stop_clears_all_maps", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + c.subagentInstances["parent-1"] = map[string]*subagentInstance{ + "tc-1": {agentName: "test-agent", toolCallID: "tc-1"}, + } + + // Simulate cleanup (Stop() does RPC + map clearing; we test removeChildMappingsForParentLocked + manual clear) + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked("parent-1") + delete(c.sessions, "parent-1") + c.sessionsMux.Unlock() + + if len(c.childToParent) != 0 { + t.Fatal("childToParent should be empty") + } + if len(c.childToAgent) != 0 { + t.Fatal("childToAgent should be empty") + } + if len(c.subagentInstances) != 0 { + t.Fatal("subagentInstances should be empty") + } + if len(c.sessions) != 0 { + t.Fatal("sessions should be empty") + } + }) + + t.Run("delete_session_clears_only_target_children", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + parentB := newSubagentTestSession("parent-B", nil, nil) + c.sessions["parent-A"] = parentA + c.sessions["parent-B"] = parentB + c.childToParent["child-A1"] = "parent-A" + c.childToParent["child-A2"] = "parent-A" + c.childToParent["child-B1"] = "parent-B" + c.childToAgent["child-A1"] = "agent-a" + c.childToAgent["child-A2"] = "agent-a" + c.childToAgent["child-B1"] = "agent-b" + c.subagentInstances["parent-A"] = map[string]*subagentInstance{ + "tc-a1": {agentName: "agent-a"}, + } + c.subagentInstances["parent-B"] = map[string]*subagentInstance{ + "tc-b1": {agentName: "agent-b"}, + } + + c.sessionsMux.Lock() + c.removeChildMappingsForParentLocked("parent-A") + c.sessionsMux.Unlock() + + // parent-A children removed + if _, ok := c.childToParent["child-A1"]; ok { + t.Fatal("child-A1 should be removed") + } + if _, ok := c.childToParent["child-A2"]; ok { + t.Fatal("child-A2 should be removed") + } + if _, ok := c.subagentInstances["parent-A"]; ok { + t.Fatal("parent-A subagentInstances should be removed") + } + + // parent-B children intact + if c.childToParent["child-B1"] != "parent-B" { + t.Fatal("child-B1 mapping should still exist") + } + if c.childToAgent["child-B1"] != "agent-b" { + t.Fatal("child-B1 agent mapping should still exist") + } + if _, ok := c.subagentInstances["parent-B"]; !ok { + t.Fatal("parent-B subagentInstances should still exist") + } + }) + + t.Run("destroy_session_clears_children_via_callback", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent-1", nil, nil) + c.sessions["parent-1"] = parent + c.childToParent["child-1"] = "parent-1" + c.childToAgent["child-1"] = "test-agent" + c.subagentInstances["parent-1"] = map[string]*subagentInstance{ + "tc-1": {agentName: "test-agent"}, + } + + // Set up onDestroy callback (mirrors Client's real onDestroy which only clears child mappings) + parent.onDestroy = func() { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + c.removeChildMappingsForParentLocked("parent-1") + } + + // Call onDestroy + parent.onDestroy() + + if len(c.childToParent) != 0 { + t.Fatal("childToParent should be cleared by onDestroy") + } + if len(c.childToAgent) != 0 { + t.Fatal("childToAgent should be cleared by onDestroy") + } + if len(c.subagentInstances) != 0 { + t.Fatal("subagentInstances should be cleared by onDestroy") + } + // Session itself is NOT removed by onDestroy (that's Destroy()'s job via RPC) + if _, ok := c.sessions["parent-1"]; !ok { + t.Fatal("session should still exist after onDestroy (only child mappings cleared)") + } + }) +} + +// --------------------------------------------------------------------------- +// TestSessionIsolation +// --------------------------------------------------------------------------- + +func TestSessionIsolation(t *testing.T) { + t.Run("child_cannot_reach_other_parent", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + parentB := newSubagentTestSession("parent-B", nil, nil) + c.sessions["parent-A"] = parentA + c.sessions["parent-B"] = parentB + c.childToParent["child-A"] = "parent-A" + + session, isChild, err := c.resolveSession("child-A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !isChild { + t.Fatal("expected isChild=true") + } + if session != parentA { + t.Fatal("child-A should resolve to parent-A, not parent-B") + } + if session == parentB { + t.Fatal("child-A must not resolve to parent-B") + } + }) + + t.Run("child_session_id_immutable_mapping", func(t *testing.T) { + c := newTestClient() + parentA := newSubagentTestSession("parent-A", nil, nil) + c.sessions["parent-A"] = parentA + c.childToParent["child-1"] = "parent-A" + + // Resolve multiple times — always gets parent-A + for i := 0; i < 5; i++ { + session, isChild, err := c.resolveSession("child-1") + if err != nil { + t.Fatalf("iteration %d: unexpected error: %v", i, err) + } + if !isChild { + t.Fatalf("iteration %d: expected isChild=true", i) + } + if session != parentA { + t.Fatalf("iteration %d: mapping should consistently resolve to parent-A", i) + } + } + }) +} + +// --------------------------------------------------------------------------- +// TestConcurrency +// --------------------------------------------------------------------------- + +func TestConcurrency(t *testing.T) { + t.Run("concurrent_resolve_session_safe", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent", nil, nil) + c.sessions["parent"] = parent + c.childToParent["child-1"] = "parent" + c.childToAgent["child-1"] = "agent" + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c.resolveSession("parent") + c.resolveSession("child-1") + c.resolveSession("nonexistent") + }() + } + wg.Wait() + }) + + t.Run("concurrent_subagent_events_safe", func(t *testing.T) { + c := newTestClient() + parent := newSubagentTestSession("parent", nil, nil) + c.sessions["parent"] = parent + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + tcID := "tc-" + strings.Repeat("x", idx%10) + childID := "child-" + strings.Repeat("x", idx%10) + event := SessionEvent{ + Type: SessionEventTypeSubagentStarted, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(tcID), + AgentName: strPtr("agent"), + RemoteSessionID: strPtr(childID), + }, + } + c.handleSubagentEvent("parent", event) + + // Also try resolving concurrently + c.resolveSession(childID) + c.isToolAllowedForChild(childID, "some_tool") + + endEvent := SessionEvent{ + Type: SessionEventTypeSubagentCompleted, + Timestamp: time.Now(), + Data: Data{ + ToolCallID: strPtr(tcID), + }, + } + c.handleSubagentEvent("parent", endEvent) + }(i) + } + wg.Wait() + }) +} diff --git a/go/internal/e2e/subagent_tool_test.go b/go/internal/e2e/subagent_tool_test.go new file mode 100644 index 000000000..08efa5746 --- /dev/null +++ b/go/internal/e2e/subagent_tool_test.go @@ -0,0 +1,140 @@ +//go:build integration + +package e2e + +import ( + "strings" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// TestSubagentCustomTools requires a real CLI to test the full round-trip of +// subagent child sessions invoking custom tools registered on parent sessions. +// +// Run with: +// +// cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +// +// Prerequisites: +// - Copilot CLI installed (or COPILOT_CLI_PATH set) +// - Valid GitHub authentication configured +func TestSubagentCustomTools(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("subagent invokes parent custom tool", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // Track tool invocations + toolInvoked := make(chan string, 1) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Tools: []copilot.Tool{ + copilot.DefineTool("save_result", "Saves a result string", + func(params struct { + Result string `json:"result" jsonschema:"The result to save"` + }, inv copilot.ToolInvocation) (string, error) { + select { + case toolInvoked <- params.Result: + default: + } + return "saved: " + params.Result, nil + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "helper-agent", + DisplayName: "Helper Agent", + Description: "A helper agent that can save results using the save_result tool", + Tools: []string{"save_result"}, + Prompt: "You are a helper agent. When asked to save something, use the save_result tool.", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Send a message that should trigger the subagent which invokes the custom tool + _, err = session.Send(t.Context(), copilot.MessageOptions{ + Prompt: "Use the helper-agent to save the result 'hello world'", + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + // Wait for the tool to be invoked (with timeout) + select { + case result := <-toolInvoked: + if !strings.Contains(strings.ToLower(result), "hello world") { + t.Errorf("Expected tool to receive 'hello world', got %q", result) + } + case <-time.After(30 * time.Second): + t.Fatal("Timeout waiting for save_result tool invocation from subagent") + } + + // Get the final response + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if answer.Data.Content == nil { + t.Fatal("Expected non-nil content in response") + } + t.Logf("Response: %s", *answer.Data.Content) + }) + + t.Run("subagent denied unlisted tool returns unsupported", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Tools: []copilot.Tool{ + copilot.DefineTool("allowed_tool", "An allowed tool", + func(params struct{}, inv copilot.ToolInvocation) (string, error) { + return "allowed", nil + }), + copilot.DefineTool("restricted_tool", "A restricted tool", + func(params struct{}, inv copilot.ToolInvocation) (string, error) { + t.Error("restricted_tool should not be invoked by subagent") + return "should not reach here", nil + }), + }, + CustomAgents: []copilot.CustomAgentConfig{ + { + Name: "restricted-agent", + DisplayName: "Restricted Agent", + Description: "An agent with limited tool access", + Tools: []string{"allowed_tool"}, // restricted_tool NOT listed + Prompt: "You are a restricted agent. Try to use both allowed_tool and restricted_tool.", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, err = session.Send(t.Context(), copilot.MessageOptions{ + Prompt: "Use the restricted-agent to invoke restricted_tool", + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if answer.Data.Content == nil { + t.Fatal("Expected assistant message content but got nil") + } + t.Logf("Response: %s", *answer.Data.Content) + // The primary assertion is the t.Error inside restricted_tool's handler above. + // We don't assert on response text because LLM output is non-deterministic. + }) +} diff --git a/go/session.go b/go/session.go index 5be626b52..28015e17c 100644 --- a/go/session.go +++ b/go/session.go @@ -66,6 +66,8 @@ type Session struct { hooksMux sync.RWMutex transformCallbacks map[string]SectionTransformFn transformMu sync.Mutex + onDestroy func() // set by Client when session is created; called by Disconnect() (Destroy() delegates to Disconnect()) + customAgents []CustomAgentConfig // agent configs from SessionConfig // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. @@ -83,6 +85,16 @@ func (s *Session) WorkspacePath() string { return s.workspacePath } +// getAgentConfig returns the CustomAgentConfig for the given agent name, or nil if not found. +func (s *Session) getAgentConfig(agentName string) *CustomAgentConfig { + for i := range s.customAgents { + if s.customAgents[i].Name == agentName { + return &s.customAgents[i] + } + } + return nil +} + // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ @@ -748,6 +760,10 @@ func (s *Session) Disconnect() error { s.permissionHandler = nil s.permissionMux.Unlock() + if s.onDestroy != nil { + s.onDestroy() + } + return nil } diff --git a/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml b/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml new file mode 100644 index 000000000..b963a2a1b --- /dev/null +++ b/test/snapshots/subagent_tool/subagent_denied_unlisted_tool_returns_unsupported.yaml @@ -0,0 +1,18 @@ +# Placeholder snapshot for subagent denied tool E2E test. +# This snapshot needs to be captured from a real CLI session. +# +# To capture: +# 1. Ensure COPILOT_CLI_PATH is set or CLI is installed +# 2. Run: cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +# 3. The proxy will capture the exchanges and write this file +# +# Expected flow: +# 1. Parent creates session with allowed_tool and restricted_tool, plus restricted-agent +# 2. restricted-agent only has allowed_tool in its Tools allowlist +# 3. User asks restricted-agent to invoke restricted_tool +# 4. CLI sends tool.call with child session ID for restricted_tool +# 5. SDK resolves child→parent, checks allowlist, returns "unsupported" error +# 6. restricted_tool handler is never invoked +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml b/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml new file mode 100644 index 000000000..c544f9ac1 --- /dev/null +++ b/test/snapshots/subagent_tool/subagent_invokes_parent_custom_tool.yaml @@ -0,0 +1,19 @@ +# Placeholder snapshot for subagent custom tool E2E test. +# This snapshot needs to be captured from a real CLI session. +# +# To capture: +# 1. Ensure COPILOT_CLI_PATH is set or CLI is installed +# 2. Run: cd go && go test -tags integration -v ./internal/e2e -run TestSubagentCustomTools +# 3. The proxy will capture the exchanges and write this file +# +# Expected flow: +# 1. Parent creates session with save_result tool and helper-agent subagent +# 2. User sends message triggering helper-agent +# 3. CLI creates child session, emits subagent.started with RemoteSessionID +# 4. Child LLM invokes save_result tool +# 5. CLI sends tool.call with child session ID +# 6. SDK resolves child→parent, dispatches to save_result handler +# 7. Tool result returned, subagent completes +models: + - claude-sonnet-4.5 +conversations: []