Skip to content

Commit 4bb5c4b

Browse files
committed
Support request cancellation
1 parent f525dbe commit 4bb5c4b

File tree

3 files changed

+111
-29
lines changed

3 files changed

+111
-29
lines changed

Core/Sources/GitHubCopilotService/CopilotLocalProcessServer.swift

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ import ProcessEnv
99
/// We need it because the original one does not allow us to handle custom notifications.
1010
class CopilotLocalProcessServer {
1111
private let transport: StdioDataTransport
12+
private let customTransport: CustomDataTransport
1213
private let process: Process
1314
private var wrappedServer: CustomJSONRPCLanguageServer?
1415
var terminationHandler: (() -> Void)?
16+
@MainActor var ongoingCompletionRequestIDs: [JSONId] = []
1517

1618
public convenience init(
1719
path: String,
@@ -29,10 +31,27 @@ class CopilotLocalProcessServer {
2931

3032
init(executionParameters parameters: Process.ExecutionParameters) {
3133
transport = StdioDataTransport()
32-
wrappedServer = CustomJSONRPCLanguageServer(dataTransport: transport)
34+
let framing = SeperatedHTTPHeaderMessageFraming()
35+
let messageTransport = MessageTransport(
36+
dataTransport: transport,
37+
messageProtocol: framing
38+
)
39+
customTransport = CustomDataTransport(nextTransport: messageTransport)
40+
wrappedServer = CustomJSONRPCLanguageServer(dataTransport: customTransport)
3341

3442
process = Process()
3543

44+
// Because the implementation of LanguageClient is so closed,
45+
// we need to get the request IDs from a custom transport before the data
46+
// is written to the language server.
47+
customTransport.onWriteRequest = { [weak self] request in
48+
if request.method == "getCompletionsCycling" {
49+
Task { @MainActor [weak self] in
50+
self?.ongoingCompletionRequestIDs.append(request.id)
51+
}
52+
}
53+
}
54+
3655
process.standardInput = transport.stdinPipe
3756
process.standardOutput = transport.stdoutPipe
3857
process.standardError = transport.stderrPipe
@@ -89,6 +108,27 @@ extension CopilotLocalProcessServer: LanguageServerProtocol.Server {
89108

90109
server.sendNotification(notif, completionHandler: completionHandler)
91110
}
111+
112+
/// Cancel ongoing completion requests.
113+
public func cancelOngoingTasks() async {
114+
guard let server = wrappedServer, process.isRunning else {
115+
return
116+
}
117+
118+
let task = Task { @MainActor in
119+
for id in self.ongoingCompletionRequestIDs {
120+
switch id {
121+
case let .numericId(id):
122+
try? await server.sendNotification(.protocolCancelRequest(.init(id: id)))
123+
case let .stringId(id):
124+
try? await server.sendNotification(.protocolCancelRequest(.init(id: id)))
125+
}
126+
}
127+
self.ongoingCompletionRequestIDs = []
128+
}
129+
130+
await task.value
131+
}
92132

93133
public func sendRequest<Response: Codable>(
94134
_ request: ClientRequest,
@@ -139,13 +179,7 @@ final class CustomJSONRPCLanguageServer: Server {
139179
}
140180

141181
convenience init(dataTransport: DataTransport) {
142-
let framing = SeperatedHTTPHeaderMessageFraming()
143-
let messageTransport = MessageTransport(
144-
dataTransport: dataTransport,
145-
messageProtocol: framing
146-
)
147-
148-
self.init(protocolTransport: ProtocolTransport(dataTransport: messageTransport))
182+
self.init(protocolTransport: ProtocolTransport(dataTransport: dataTransport))
149183
}
150184

151185
deinit {
@@ -219,3 +253,4 @@ extension CustomJSONRPCLanguageServer {
219253
internalServer.sendRequest(request, completionHandler: completionHandler)
220254
}
221255
}
256+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import Foundation
2+
import JSONRPC
3+
import os.log
4+
5+
public class CustomDataTransport: DataTransport {
6+
let nextTransport: DataTransport
7+
8+
var onWriteRequest: (JSONRPCRequest<JSONValue>) -> Void = { _ in }
9+
10+
init(nextTransport: DataTransport) {
11+
self.nextTransport = nextTransport
12+
}
13+
14+
public func write(_ data: Data) {
15+
if let request = try? JSONDecoder().decode(JSONRPCRequest<JSONValue>.self, from: data) {
16+
onWriteRequest(request)
17+
}
18+
19+
nextTransport.write(data)
20+
}
21+
22+
public func setReaderHandler(_ handler: @escaping ReadHandler) {
23+
nextTransport.setReaderHandler(handler)
24+
}
25+
26+
public func close() {
27+
nextTransport.close()
28+
}
29+
}
30+

Core/Sources/GitHubCopilotService/GitHubCopilotService.swift

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ enum GitHubCopilotError: Error, LocalizedError {
5252
public class GitHubCopilotBaseService {
5353
let projectRootURL: URL
5454
var server: GitHubCopilotLSP
55+
var localProcessServer: CopilotLocalProcessServer?
5556

5657
init(designatedServer: GitHubCopilotLSP) {
5758
projectRootURL = URL(fileURLWithPath: "/")
@@ -60,7 +61,7 @@ public class GitHubCopilotBaseService {
6061

6162
init(projectRootURL: URL) throws {
6263
self.projectRootURL = projectRootURL
63-
server = try {
64+
let (server, localServer) = try {
6465
let urls = try GitHubCopilotBaseService.createFoldersIfNeeded()
6566
var userEnvPath = ProcessInfo.processInfo.userEnvironment["PATH"] ?? ""
6667
if userEnvPath.isEmpty {
@@ -124,6 +125,7 @@ public class GitHubCopilotBaseService {
124125
}()
125126
}
126127
let localServer = CopilotLocalProcessServer(executionParameters: executionParams)
128+
127129
localServer.logMessages = UserDefaults.shared.value(for: \.gitHubCopilotVerboseLog)
128130
localServer.notificationHandler = { _, respond in
129131
respond(.timeout)
@@ -156,8 +158,11 @@ public class GitHubCopilotBaseService {
156158
)
157159
}
158160

159-
return server
161+
return (server, localServer)
160162
}()
163+
164+
self.server = server
165+
self.localProcessServer = localServer
161166
}
162167

163168
public static func createFoldersIfNeeded() throws -> (
@@ -238,6 +243,8 @@ public final class GitHubCopilotAuthService: GitHubCopilotBaseService,
238243
public final class GitHubCopilotSuggestionService: GitHubCopilotBaseService,
239244
GitHubCopilotSuggestionServiceType
240245
{
246+
private var ongoingTasks = Set<Task<[CodeSuggestion], Error>>()
247+
241248
override public init(projectRootURL: URL = URL(fileURLWithPath: "/")) throws {
242249
try super.init(projectRootURL: projectRootURL)
243250
}
@@ -272,27 +279,37 @@ public final class GitHubCopilotSuggestionService: GitHubCopilotBaseService,
272279
return filePath
273280
}()
274281

275-
let completions = try await server
276-
.sendRequest(GitHubCopilotRequest.GetCompletionsCycling(doc: .init(
277-
source: content,
278-
tabSize: tabSize,
279-
indentSize: indentSize,
280-
insertSpaces: !usesTabsForIndentation,
281-
path: fileURL.path,
282-
uri: fileURL.path,
283-
relativePath: relativePath,
284-
languageId: languageId,
285-
position: cursorPosition
286-
)))
287-
.completions
288-
.filter { completion in
289-
if ignoreSpaceOnlySuggestions {
290-
return !completion.text.allSatisfy { $0.isWhitespace || $0.isNewline }
282+
ongoingTasks.forEach { $0.cancel() }
283+
ongoingTasks.removeAll()
284+
await localProcessServer?.cancelOngoingTasks()
285+
286+
let task = Task {
287+
let completions = try await server
288+
.sendRequest(GitHubCopilotRequest.GetCompletionsCycling(doc: .init(
289+
source: content,
290+
tabSize: tabSize,
291+
indentSize: indentSize,
292+
insertSpaces: !usesTabsForIndentation,
293+
path: fileURL.path,
294+
uri: fileURL.path,
295+
relativePath: relativePath,
296+
languageId: languageId,
297+
position: cursorPosition
298+
)))
299+
.completions
300+
.filter { completion in
301+
if ignoreSpaceOnlySuggestions {
302+
return !completion.text.allSatisfy { $0.isWhitespace || $0.isNewline }
303+
}
304+
return true
291305
}
292-
return true
293-
}
306+
try Task.checkCancellation()
307+
return completions
308+
}
309+
310+
ongoingTasks.insert(task)
294311

295-
return completions
312+
return try await task.value
296313
}
297314

298315
public func notifyAccepted(_ completion: CodeSuggestion) async {

0 commit comments

Comments
 (0)