Skip to content

Commit 0f3f3e6

Browse files
committed
Fix cancellation
1 parent 52fe8af commit 0f3f3e6

2 files changed

Lines changed: 14 additions & 3 deletions

File tree

Pro

Submodule Pro updated from ce6f163 to 22e5f0a

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ public class ChatGPTService: ChatGPTServiceType {
100100

101101
return Debugger.$id.withValue(.init()) {
102102
AsyncThrowingStream<String, Error> { continuation in
103-
Task(priority: .userInitiated) {
103+
let task = Task(priority: .userInitiated) {
104104
do {
105105
var functionCall: ChatMessage.FunctionCall?
106106
var functionCallMessageID = ""
107107
var isInitialCall = true
108108
loop: while functionCall != nil || isInitialCall {
109+
try Task.checkCancellation()
109110
isInitialCall = false
110111
if let call = functionCall {
111112
if !configuration.runFunctionsAutomatically {
@@ -121,6 +122,7 @@ public class ChatGPTService: ChatGPTServiceType {
121122
#endif
122123

123124
for try await content in stream {
125+
try Task.checkCancellation()
124126
switch content {
125127
case let .text(text):
126128
continuation.yield(text)
@@ -154,6 +156,9 @@ public class ChatGPTService: ChatGPTServiceType {
154156
continuation.finish(throwing: error)
155157
}
156158
}
159+
continuation.onTermination = { _ in
160+
task.cancel()
161+
}
157162
}
158163
}
159164
}
@@ -177,6 +182,7 @@ public class ChatGPTService: ChatGPTServiceType {
177182
var finalResult = message?.content
178183
var functionCall = message?.functionCall
179184
while let call = functionCall {
185+
try Task.checkCancellation()
180186
if !configuration.runFunctionsAutomatically {
181187
break
182188
}
@@ -270,12 +276,13 @@ extension ChatGPTService {
270276
#endif
271277

272278
return AsyncThrowingStream<StreamContent, Error> { continuation in
273-
Task {
279+
let task = Task {
274280
do {
275281
let (trunks, cancel) = try await api()
276282
cancelTask = cancel
277283
let proposedId = UUID().uuidString + String(Date().timeIntervalSince1970)
278284
for try await trunk in trunks {
285+
try Task.checkCancellation()
279286
guard let delta = trunk.choices?.first?.delta else { continue }
280287

281288
// The api will always return a function call with JSON object.
@@ -320,6 +327,10 @@ extension ChatGPTService {
320327
continuation.finish(throwing: error)
321328
}
322329
}
330+
331+
continuation.onTermination = { _ in
332+
task.cancel()
333+
}
323334
}
324335
}
325336

0 commit comments

Comments
 (0)