Skip to content

Commit 35fc87f

Browse files
committed
Extract some logics to CompletionStreamAPI
1 parent 061fc16 commit 35fc87f

File tree

6 files changed

+238
-111
lines changed

6 files changed

+238
-111
lines changed

Core/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,18 @@ public protocol ChatGPTServiceType {
88
func mutateSystemPrompt(_ newPrompt: String) async
99
}
1010

11-
public enum ChatGPTServiceError: Error {
11+
public enum ChatGPTServiceError: Error, LocalizedError {
1212
case endpointIncorrect
1313
case responseInvalid
14+
15+
public var errorDescription: String? {
16+
switch self {
17+
case .endpointIncorrect:
18+
return "ChatGPT endpoint is incorrect"
19+
case .responseInvalid:
20+
return "Response is invalid"
21+
}
22+
}
1423
}
1524

1625
public struct ChatGPTError: Error, Codable, LocalizedError {
@@ -45,12 +54,13 @@ public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
4554
public var apiKey: String
4655
public var systemPrompt: String
4756
public var maxToken: Int
48-
public var history: [ChatGPTMessage] = [] {
57+
public var history: [ChatMessage] = [] {
4958
didSet { objectWillChange.send() }
5059
}
5160

5261
public internal(set) var isReceivingMessage = false
53-
var ongoingTask: URLSessionDataTask?
62+
var cancelTask: Cancellable?
63+
var buildCompletionStreamAPI: CompletionStreamAPIBuilder = OpenAICompletionStreamAPI.init
5464

5565
public init(
5666
systemPrompt: String,
@@ -74,55 +84,29 @@ public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
7484
) async throws -> AsyncThrowingStream<String, Error> {
7585
guard !isReceivingMessage else { throw CancellationError() }
7686
guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
77-
let newMessage = ChatGPTMessage(role: .user, content: content, summary: summary)
87+
let newMessage = ChatMessage(role: .user, content: content, summary: summary)
7888
history.append(newMessage)
79-
var request = URLRequest(url: url)
80-
request.httpMethod = "POST"
8189

82-
let requestBody = ChatGPTRequest(
90+
let requestBody = CompletionRequestBody(
8391
model: model.rawValue,
8492
messages: combineHistoryWithSystemPrompt(),
8593
temperature: temperature,
8694
stream: true,
8795
max_tokens: maxToken
8896
)
89-
90-
let encoder = JSONEncoder()
91-
request.httpBody = try encoder.encode(requestBody)
92-
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
93-
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
9497

9598
isReceivingMessage = true
9699

97100
do {
98-
let (result, response) = try await URLSession.shared.bytes(for: request)
99-
ongoingTask = result.task
100-
101-
guard let response = response as? HTTPURLResponse else {
102-
throw ChatGPTServiceError.responseInvalid
103-
}
104-
guard response.statusCode == 200 else {
105-
let text = try await result.lines.reduce(into: "") { partialResult, current in
106-
partialResult += current
107-
}
108-
guard let data = text.data(using: .utf8)
109-
else { throw ChatGPTServiceError.responseInvalid }
110-
let decoder = JSONDecoder()
111-
let error = try? decoder.decode(ChatGPTError.self, from: data)
112-
throw error ?? ChatGPTServiceError.responseInvalid
113-
}
101+
let api = buildCompletionStreamAPI(apiKey, url, requestBody)
102+
let (trunks, cancel) = try await api()
103+
cancelTask = cancel
114104

115105
return AsyncThrowingStream<String, Error> { continuation in
116106
Task {
117107
do {
118-
for try await line in result.lines {
119-
let prefix = "data: "
120-
guard line.hasPrefix(prefix),
121-
let content = line.dropFirst(prefix.count).data(using: .utf8),
122-
let trunk = try? JSONDecoder()
123-
.decode(ChatGPTDataTrunk.self, from: content),
124-
let delta = trunk.choices.first?.delta
125-
else { continue }
108+
for try await trunk in trunks {
109+
guard let delta = trunk.choices.first?.delta else { continue }
126110

127111
if history.last?.id == trunk.id {
128112
if let role = delta.role {
@@ -158,8 +142,8 @@ public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
158142
}
159143

160144
public func stopReceivingMessage() {
161-
ongoingTask?.cancel()
162-
ongoingTask = nil
145+
cancelTask?()
146+
cancelTask = nil
163147
isReceivingMessage = false
164148
}
165149

@@ -173,7 +157,11 @@ public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
173157
}
174158

175159
extension ChatGPTService {
176-
func combineHistoryWithSystemPrompt() -> [ChatGPTMessage] {
160+
func changeBuildCompletionStreamAPI(_ builder: @escaping CompletionStreamAPIBuilder) {
161+
buildCompletionStreamAPI = builder
162+
}
163+
164+
func combineHistoryWithSystemPrompt() -> [ChatMessage] {
177165
if history.count > 4 {
178166
return [.init(role: .system, content: systemPrompt)] +
179167
history[history.endIndex - 4..<history.endIndex]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import AsyncAlgorithms
2+
import Foundation
3+
4+
typealias CompletionStreamAPIBuilder = (String, URL, CompletionRequestBody) -> CompletionStreamAPI
5+
6+
protocol CompletionStreamAPI {
7+
func callAsFunction() async throws -> (
8+
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
9+
cancel: Cancellable
10+
)
11+
}
12+
13+
/// https://platform.openai.com/docs/api-reference/chat/create
14+
struct CompletionRequestBody: Codable {
15+
var model: String
16+
var messages: [ChatMessage]
17+
var temperature: Double?
18+
var top_p: Double?
19+
var n: Double?
20+
var stream: Bool?
21+
var stop: [String]?
22+
var max_tokens: Int?
23+
var presence_penalty: Double?
24+
var frequency_penalty: Double?
25+
var logit_bias: [String: Double]?
26+
var user: String?
27+
}
28+
29+
struct CompletionStreamDataTrunk: Codable {
30+
var id: String
31+
var object: String
32+
var created: Int
33+
var model: String
34+
var choices: [Choice]
35+
36+
struct Choice: Codable {
37+
var delta: Delta
38+
var index: Int
39+
var finish_reason: String?
40+
41+
struct Delta: Codable {
42+
var role: ChatMessage.Role?
43+
var content: String?
44+
}
45+
}
46+
}
47+
48+
struct OpenAICompletionStreamAPI: CompletionStreamAPI {
49+
var apiKey: String
50+
var endpoint: URL
51+
var requestBody: CompletionRequestBody
52+
53+
init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
54+
self.apiKey = apiKey
55+
self.endpoint = endpoint
56+
self.requestBody = requestBody
57+
}
58+
59+
func callAsFunction() async throws -> (
60+
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
61+
cancel: Cancellable
62+
) {
63+
var request = URLRequest(url: endpoint)
64+
request.httpMethod = "POST"
65+
let encoder = JSONEncoder()
66+
request.httpBody = try encoder.encode(requestBody)
67+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
68+
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
69+
70+
let (result, response) = try await URLSession.shared.bytes(for: request)
71+
guard let response = response as? HTTPURLResponse else {
72+
throw ChatGPTServiceError.responseInvalid
73+
}
74+
75+
guard response.statusCode == 200 else {
76+
let text = try await result.lines.reduce(into: "") { partialResult, current in
77+
partialResult += current
78+
}
79+
guard let data = text.data(using: .utf8)
80+
else { throw ChatGPTServiceError.responseInvalid }
81+
let decoder = JSONDecoder()
82+
let error = try? decoder.decode(ChatGPTError.self, from: data)
83+
throw error ?? ChatGPTServiceError.responseInvalid
84+
}
85+
86+
return (
87+
AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
88+
Task {
89+
do {
90+
for try await line in result.lines {
91+
let prefix = "data: "
92+
guard line.hasPrefix(prefix),
93+
let content = line.dropFirst(prefix.count).data(using: .utf8),
94+
let trunk = try? JSONDecoder()
95+
.decode(CompletionStreamDataTrunk.self, from: content)
96+
else { continue }
97+
continuation.yield(trunk)
98+
}
99+
continuation.finish()
100+
} catch {
101+
continuation.finish(throwing: error)
102+
}
103+
}
104+
},
105+
Cancellable {
106+
result.task.cancel()
107+
}
108+
)
109+
}
110+
}
111+
112+
Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import Foundation
22

3-
public struct ChatGPTMessage: Equatable, Codable {
3+
struct Cancellable {
4+
let cancel: () -> Void
5+
func callAsFunction() {
6+
cancel()
7+
}
8+
}
9+
10+
public struct ChatMessage: Equatable, Codable {
411
public enum Role: String, Codable, Equatable {
512
case system
613
case user
@@ -27,59 +34,3 @@ public enum ChatGPTModel: String {
2734
case gpt_4_32k = "gpt-4-32k"
2835
case gpt_4_32k_0314 = "gpt-4-32k-0314"
2936
}
30-
31-
/// https://platform.openai.com/docs/api-reference/chat/create
32-
struct ChatGPTRequest: Codable {
33-
var model: String
34-
var messages: [ChatGPTMessage]
35-
var temperature: Double?
36-
var top_p: Double?
37-
var n: Double?
38-
var stream: Bool?
39-
var stop: [String]?
40-
var max_tokens: Int?
41-
var presence_penalty: Double?
42-
var frequency_penalty: Double?
43-
var logit_bias: [String: Double]?
44-
var user: String?
45-
}
46-
47-
/// https://platform.openai.com/docs/api-reference/chat/create
48-
struct ChatGPTResponse: Codable {
49-
var id: String
50-
var object: String
51-
var created: Int
52-
var choices: [Choice]
53-
var usage: Usage
54-
55-
struct Usage: Codable {
56-
var prompt_tokens: Int
57-
var completion_tokens: Int
58-
var total_tokens: Int
59-
}
60-
61-
struct Choice: Codable {
62-
var index: Int
63-
var message: ChatGPTMessage
64-
var finish_reason: String
65-
}
66-
}
67-
68-
struct ChatGPTDataTrunk: Codable {
69-
var id: String
70-
var object: String
71-
var created: Int
72-
var model: String
73-
var choices: [Choice]
74-
75-
struct Choice: Codable {
76-
var delta: Delta
77-
var index: Int
78-
var finish_reason: String?
79-
80-
struct Delta: Codable {
81-
var role: ChatGPTMessage.Role?
82-
var content: String?
83-
}
84-
}
85-
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import XCTest
2+
@testable import OpenAIService
3+
4+
final class ChatGPTServiceFieldTests: XCTestCase {
5+
func test_calling_the_api() async throws {
6+
let service = ChatGPTService(systemPrompt: "", apiKey: "Key")
7+
8+
if (await service.apiKey) == "Key" {
9+
return
10+
}
11+
12+
do {
13+
let stream = try await service.send(content: "Hello")
14+
for try await text in stream {
15+
print(text)
16+
}
17+
} catch {
18+
print("🔴", error.localizedDescription)
19+
}
20+
21+
XCTFail("🔴 Please reset the key to `Key` after the field tests.")
22+
}
23+
}

0 commit comments

Comments
 (0)