Skip to content

Commit 5c1adde

Browse files
committed
Add proper support for non-streaming OpenAI API
1 parent 26c7489 commit 5c1adde

File tree

6 files changed

+172
-60
lines changed

6 files changed

+172
-60
lines changed

Core/Sources/ChatPlugins/AITerminalChatPlugin.swift

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public actor AITerminalChatPlugin: ChatPlugin {
5252
delegate?.pluginDidEndResponding(self)
5353
await chatGPTService.mutateHistory { history in
5454
history.append(.init(role: .assistant, content: """
55-
Confirm to run?
55+
Should I run this command? You can instruct me to modify it again.
5656
```
5757
\(result)
5858
```
@@ -63,7 +63,7 @@ public actor AITerminalChatPlugin: ChatPlugin {
6363
await chatGPTService.mutateHistory { history in
6464
history.append(.init(
6565
role: .assistant,
66-
content: "Should I run it? Or should I modify it?"
66+
content: "Sorry, I don't understand. Do you want me to run it?"
6767
))
6868
}
6969
}
@@ -77,7 +77,7 @@ public actor AITerminalChatPlugin: ChatPlugin {
7777
if isCancelled { return }
7878
await chatGPTService.mutateHistory { history in
7979
history.append(.init(role: .assistant, content: """
80-
Confirm to run?
80+
Should I run this command? You can instruct me to modify it.
8181
```
8282
\(result)
8383
```
@@ -118,7 +118,7 @@ public actor AITerminalChatPlugin: ChatPlugin {
118118
return extractCodeFromMarkdown(try await askChatGPT(
119119
systemPrompt: p,
120120
question: "the task is: \"\(task)\""
121-
))
121+
) ?? "")
122122
}
123123

124124
func modifyCommand(command: String, requirement: String) async throws -> String {
@@ -139,7 +139,7 @@ public actor AITerminalChatPlugin: ChatPlugin {
139139
return extractCodeFromMarkdown(try await askChatGPT(
140140
systemPrompt: p,
141141
question: "The requirement is: \"\(requirement)\""
142-
))
142+
) ?? "")
143143
}
144144

145145
func checkConfirmation(content: String) async throws -> Tone {
@@ -165,7 +165,8 @@ public actor AITerminalChatPlugin: ChatPlugin {
165165
systemPrompt: p,
166166
question: "The content is: \"\(content)\""
167167
)
168-
return Tone(rawValue: Int(result) ?? 2) ?? .cancellation
168+
let tone = result.flatMap(Int.init).flatMap(Tone.init(rawValue:)) ?? .other
169+
return tone
169170
}
170171

171172
enum Tone: Int {

Core/Sources/ChatPlugins/AskChatGPT.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import Foundation
22
import OpenAIService
33

44
/// Quickly ask a question to ChatGPT.
5-
func askChatGPT(systemPrompt: String, question: String) async throws -> String {
5+
func askChatGPT(systemPrompt: String, question: String) async throws -> String? {
66
let service = ChatGPTService(systemPrompt: systemPrompt)
77
return try await service.sendAndWait(content: question)
88
}

Core/Sources/ChatPlugins/CallAIFunction.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ func callAIFunction(
77
function: String,
88
args: [Any?],
99
description: String
10-
) async throws -> String {
10+
) async throws -> String? {
1111
let args = args.map { arg -> String in
1212
if let arg = arg {
1313
return String(describing: arg)

Core/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 80 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ public actor ChatGPTService: ChatGPTServiceType {
8888
var uuidGenerator: () -> String = { UUID().uuidString }
8989
var cancelTask: Cancellable?
9090
var buildCompletionStreamAPI: CompletionStreamAPIBuilder = OpenAICompletionStreamAPI.init
91+
var buildCompletionAPI: CompletionAPIBuilder = OpenAICompletionAPI.init
9192

9293
public init(
9394
systemPrompt: String = "",
@@ -121,72 +122,99 @@ public actor ChatGPTService: ChatGPTServiceType {
121122

122123
isReceivingMessage = true
123124

124-
do {
125-
let api = buildCompletionStreamAPI(apiKey, url, requestBody)
125+
let api = buildCompletionStreamAPI(apiKey, url, requestBody)
126126

127-
return AsyncThrowingStream<String, Error> { continuation in
128-
Task {
129-
do {
130-
let (trunks, cancel) = try await api()
131-
guard isReceivingMessage else {
132-
continuation.finish()
133-
return
134-
}
135-
cancelTask = cancel
136-
for try await trunk in trunks {
137-
guard let delta = trunk.choices.first?.delta else { continue }
138-
139-
if history.last?.id == trunk.id {
140-
if let role = delta.role {
141-
history[history.endIndex - 1].role = role
142-
}
143-
if let content = delta.content {
144-
history[history.endIndex - 1].content.append(content)
145-
}
146-
} else {
147-
history.append(.init(
148-
id: trunk.id,
149-
role: delta.role ?? .assistant,
150-
content: delta.content ?? ""
151-
))
152-
}
127+
return AsyncThrowingStream<String, Error> { continuation in
128+
Task {
129+
do {
130+
let (trunks, cancel) = try await api()
131+
guard isReceivingMessage else {
132+
continuation.finish()
133+
return
134+
}
135+
cancelTask = cancel
136+
for try await trunk in trunks {
137+
guard let delta = trunk.choices.first?.delta else { continue }
153138

139+
if history.last?.id == trunk.id {
140+
if let role = delta.role {
141+
history[history.endIndex - 1].role = role
142+
}
154143
if let content = delta.content {
155-
continuation.yield(content)
144+
history[history.endIndex - 1].content.append(content)
156145
}
146+
} else {
147+
history.append(.init(
148+
id: trunk.id,
149+
role: delta.role ?? .assistant,
150+
content: delta.content ?? ""
151+
))
157152
}
158153

159-
continuation.finish()
160-
isReceivingMessage = false
161-
} catch let error as CancellationError {
162-
isReceivingMessage = false
163-
continuation.finish(throwing: error)
164-
} catch let error as NSError where error.code == NSURLErrorCancelled {
165-
isReceivingMessage = false
166-
continuation.finish(throwing: error)
167-
} catch {
168-
history.append(.init(
169-
role: .assistant,
170-
content: error.localizedDescription
171-
))
172-
isReceivingMessage = false
173-
continuation.finish(throwing: error)
154+
if let content = delta.content {
155+
continuation.yield(content)
156+
}
174157
}
158+
159+
continuation.finish()
160+
isReceivingMessage = false
161+
} catch let error as CancellationError {
162+
isReceivingMessage = false
163+
continuation.finish(throwing: error)
164+
} catch let error as NSError where error.code == NSURLErrorCancelled {
165+
isReceivingMessage = false
166+
continuation.finish(throwing: error)
167+
} catch {
168+
history.append(.init(
169+
role: .assistant,
170+
content: error.localizedDescription
171+
))
172+
isReceivingMessage = false
173+
continuation.finish(throwing: error)
175174
}
176175
}
177176
}
178177
}
179-
178+
180179
public func sendAndWait(
181180
content: String,
182181
summary: String? = nil
183-
) async throws -> String {
184-
let stream = try await send(content: content, summary: summary)
185-
var content = ""
186-
for try await fragment in stream {
187-
content.append(fragment)
182+
) async throws -> String? {
183+
guard !isReceivingMessage else { throw CancellationError() }
184+
guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
185+
let newMessage = ChatMessage(
186+
id: uuidGenerator(),
187+
role: .user,
188+
content: content,
189+
summary: summary
190+
)
191+
history.append(newMessage)
192+
193+
let requestBody = CompletionRequestBody(
194+
model: model,
195+
messages: combineHistoryWithSystemPrompt(),
196+
temperature: temperature,
197+
stream: true,
198+
max_tokens: maxToken
199+
)
200+
201+
isReceivingMessage = true
202+
defer { isReceivingMessage = false }
203+
204+
let api = buildCompletionAPI(apiKey, url, requestBody)
205+
let response = try await api()
206+
207+
if let choice = response.choices.first {
208+
history.append(.init(
209+
id: response.id,
210+
role: choice.message.role,
211+
content: choice.message.content
212+
))
213+
214+
return choice.message.content
188215
}
189-
return content
216+
217+
return nil
190218
}
191219

192220
public func stopReceivingMessage() {
@@ -231,7 +259,7 @@ extension ChatGPTService {
231259
all.append(.init(role: message.role, content: message.content))
232260
count += 1
233261
}
234-
262+
235263
all.append(.init(role: .system, content: systemPrompt))
236264
return all.reversed()
237265
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import Foundation
2+
3+
typealias CompletionAPIBuilder = (String, URL, CompletionRequestBody) -> CompletionAPI
4+
5+
protocol CompletionAPI {
6+
func callAsFunction() async throws -> CompletionResponseBody
7+
}
8+
9+
/// https://platform.openai.com/docs/api-reference/chat/create
10+
struct CompletionResponseBody: Codable, Equatable {
11+
struct Message: Codable, Equatable {
12+
var role: ChatMessage.Role
13+
var content: String
14+
}
15+
16+
struct Choice: Codable, Equatable {
17+
var message: Message
18+
var index: Int
19+
var finish_reason: String
20+
}
21+
22+
struct Usage: Codable, Equatable {
23+
var prompt_tokens: Int
24+
var completion_tokens: Int
25+
var total_tokens: Int
26+
}
27+
28+
var id: String
29+
var object: String
30+
var created: Int
31+
var model: String
32+
var usage: Usage
33+
var choices: [Choice]
34+
}
35+
36+
struct CompletionAPIError: Error, Codable, LocalizedError {
37+
struct E: Codable {
38+
var message: String
39+
var type: String
40+
var param: String
41+
var code: String
42+
}
43+
var error: E
44+
45+
var errorDescription: String? { error.message }
46+
}
47+
48+
struct OpenAICompletionAPI: CompletionAPI {
49+
var apiKey: String
50+
var endpoint: URL
51+
var requestBody: CompletionRequestBody
52+
53+
init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
54+
self.apiKey = apiKey
55+
self.endpoint = endpoint
56+
self.requestBody = requestBody
57+
self.requestBody.stream = false
58+
}
59+
60+
func callAsFunction() async throws -> CompletionResponseBody {
61+
var request = URLRequest(url: endpoint)
62+
request.httpMethod = "POST"
63+
let encoder = JSONEncoder()
64+
request.httpBody = try encoder.encode(requestBody)
65+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
66+
if !apiKey.isEmpty {
67+
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
68+
}
69+
70+
let (result, response) = try await URLSession.shared.data(for: request)
71+
guard let response = response as? HTTPURLResponse else {
72+
throw ChatGPTServiceError.responseInvalid
73+
}
74+
75+
guard response.statusCode == 200 else {
76+
let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
77+
throw error ?? ChatGPTServiceError.responseInvalid
78+
}
79+
80+
return try JSONDecoder().decode(CompletionResponseBody.self, from: result)
81+
}
82+
}

Core/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
5959
self.apiKey = apiKey
6060
self.endpoint = endpoint
6161
self.requestBody = requestBody
62+
self.requestBody.stream = true
6263
}
6364

6465
func callAsFunction() async throws -> (

0 commit comments

Comments
 (0)