Skip to content

Commit 1ebe409

Browse files
committed
Merge branch 'feature/openai-service' into develop
2 parents 8ea4976 + 35fc87f commit 1ebe409

7 files changed

Lines changed: 426 additions & 0 deletions

File tree

Core/Package.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,13 @@ let package = Package(
123123
),
124124
.target(name: "AXExtension"),
125125
.target(name: "Logger"),
126+
.target(
127+
name: "OpenAIService",
128+
dependencies: [.product(name: "AsyncAlgorithms", package: "swift-async-algorithms")]
129+
),
130+
.testTarget(
131+
name: "OpenAIServiceTests",
132+
dependencies: ["OpenAIService"]
133+
),
126134
]
127135
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import AsyncAlgorithms
2+
import Foundation
3+
4+
/// Abstraction over a ChatGPT-style streaming chat service.
public protocol ChatGPTServiceType {
    /// Sends a user message and streams the assistant's reply incrementally.
    /// - Parameters:
    ///   - content: The user message text.
    ///   - summary: Optional short summary stored alongside the message.
    /// - Returns: A stream of response text fragments as they arrive.
    func send(content: String, summary: String?) async throws -> AsyncThrowingStream<String, Error>
    /// Cancels any in-flight response.
    func stopReceivingMessage() async
    /// Clears the conversation history.
    func restart() async
    /// Replaces the system prompt used for subsequent requests.
    func mutateSystemPrompt(_ newPrompt: String) async
}
10+
11+
/// Failures raised by the service layer itself (as opposed to API-reported
/// errors, which are surfaced as `ChatGPTError`).
public enum ChatGPTServiceError: Error, LocalizedError {
    /// The configured endpoint string could not be turned into a URL.
    case endpointIncorrect
    /// The server response could not be interpreted.
    case responseInvalid

    public var errorDescription: String? {
        switch self {
        case .endpointIncorrect:
            return "ChatGPT endpoint is incorrect"
        case .responseInvalid:
            return "Response is invalid"
        }
    }
}
24+
25+
/// Error payload decoded from a non-200 API response body.
public struct ChatGPTError: Error, Codable, LocalizedError {
    public var error: ErrorContent

    public init(error: ErrorContent) {
        self.error = error
    }

    /// The inner `error` object of the API's error JSON.
    public struct ErrorContent: Codable {
        public var message: String
        public var type: String
        public var param: String?
        public var code: String?

        public init(message: String, type: String, param: String? = nil, code: String? = nil) {
            self.message = message
            self.type = type
            self.param = param
            self.code = code
        }
    }

    /// Surfaces the API-provided message as the localized description.
    public var errorDescription: String? {
        error.message
    }
}
49+
50+
/// A ChatGPT conversation session backed by a streaming chat-completion API.
/// Actor-isolated so `history` and the busy flag are mutated safely.
public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
    public var temperature: Double
    public var model: ChatGPTModel
    public var endpoint: String
    public var apiKey: String
    public var systemPrompt: String
    public var maxToken: Int
    /// Full conversation so far; publishes a change on every mutation.
    public var history: [ChatMessage] = [] {
        didSet { objectWillChange.send() }
    }

    /// True while a response stream is being consumed; guards reentrant sends.
    public internal(set) var isReceivingMessage = false
    var cancelTask: Cancellable?
    /// Factory for the stream API; swappable so tests can stub networking.
    var buildCompletionStreamAPI: CompletionStreamAPIBuilder = OpenAICompletionStreamAPI.init

    public init(
        systemPrompt: String,
        apiKey: String,
        endpoint: String = "https://api.openai.com/v1/chat/completions",
        model: ChatGPTModel = .gpt_3_5_turbo,
        temperature: Double = 1,
        maxToken: Int = 2048
    ) {
        self.systemPrompt = systemPrompt
        self.apiKey = apiKey
        self.model = model
        self.temperature = temperature
        self.maxToken = maxToken
        self.endpoint = endpoint
    }

    /// Appends `content` as a user message, performs a streaming request, and
    /// returns a stream of the assistant's reply fragments. The reply is also
    /// accumulated into `history` as it arrives.
    /// - Throws: `CancellationError` if a message is already being received;
    ///   `ChatGPTServiceError.endpointIncorrect` for a malformed endpoint;
    ///   any error from building/starting the underlying request.
    public func send(
        content: String,
        summary: String? = nil
    ) async throws -> AsyncThrowingStream<String, Error> {
        guard !isReceivingMessage else { throw CancellationError() }
        guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
        let newMessage = ChatMessage(role: .user, content: content, summary: summary)
        history.append(newMessage)

        let requestBody = CompletionRequestBody(
            model: model.rawValue,
            messages: combineHistoryWithSystemPrompt(),
            temperature: temperature,
            stream: true,
            max_tokens: maxToken
        )

        isReceivingMessage = true

        do {
            let api = buildCompletionStreamAPI(apiKey, url, requestBody)
            let (trunks, cancel) = try await api()
            cancelTask = cancel

            return AsyncThrowingStream<String, Error> { continuation in
                Task {
                    do {
                        for try await trunk in trunks {
                            guard let delta = trunk.choices.first?.delta else { continue }

                            // Deltas belonging to the same server-side message
                            // id are merged into the last history entry;
                            // otherwise a fresh assistant entry is started.
                            if history.last?.id == trunk.id {
                                if let role = delta.role {
                                    history[history.endIndex - 1].role = role
                                }
                                if let content = delta.content {
                                    history[history.endIndex - 1].content.append(content)
                                }
                            } else {
                                history.append(.init(
                                    role: delta.role ?? .assistant,
                                    content: delta.content ?? "",
                                    id: trunk.id
                                ))
                            }

                            if let content = delta.content {
                                continuation.yield(content)
                            }
                        }

                        isReceivingMessage = false
                        continuation.finish()
                    } catch {
                        // Bug fix: the flag was previously left `true` when the
                        // stream failed, so every subsequent `send` threw
                        // CancellationError forever. Reset it on failure too.
                        isReceivingMessage = false
                        continuation.finish(throwing: error)
                    }
                }
            }
        } catch {
            isReceivingMessage = false
            throw error
        }
    }

    /// Cancels the in-flight request, if any, and clears the busy flag.
    public func stopReceivingMessage() {
        cancelTask?()
        cancelTask = nil
        isReceivingMessage = false
    }

    /// Drops the entire conversation history.
    public func restart() {
        history = []
    }

    /// Replaces the system prompt used for subsequent requests.
    public func mutateSystemPrompt(_ newPrompt: String) {
        systemPrompt = newPrompt
    }
}
158+
159+
extension ChatGPTService {
    /// Swaps in a different stream-API factory (e.g. a stub in tests).
    func changeBuildCompletionStreamAPI(_ builder: @escaping CompletionStreamAPIBuilder) {
        buildCompletionStreamAPI = builder
    }

    /// Builds the message list for a request: the system prompt followed by
    /// at most the last 4 history entries (the whole history when shorter).
    func combineHistoryWithSystemPrompt() -> [ChatMessage] {
        [.init(role: .system, content: systemPrompt)] + history.suffix(4)
    }
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import AsyncAlgorithms
2+
import Foundation
3+
4+
/// Factory producing a stream API from an API key, endpoint URL, and request body.
typealias CompletionStreamAPIBuilder = (String, URL, CompletionRequestBody) -> CompletionStreamAPI

/// A callable performing one streaming chat-completion request.
protocol CompletionStreamAPI {
    /// Starts the request.
    /// - Returns: the stream of decoded data chunks, plus a handle that
    ///   cancels the underlying transfer.
    func callAsFunction() async throws -> (
        trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
        cancel: Cancellable
    )
}
12+
13+
/// https://platform.openai.com/docs/api-reference/chat/create
14+
/// Request payload for the chat-completion endpoint.
/// Property names are intentionally snake_case so the synthesized Codable
/// keys match the API's JSON field names exactly.
/// https://platform.openai.com/docs/api-reference/chat/create
struct CompletionRequestBody: Codable {
    var model: String
    var messages: [ChatMessage]
    var temperature: Double?
    var top_p: Double?
    var n: Double?
    var stream: Bool?
    var stop: [String]?
    var max_tokens: Int?
    var presence_penalty: Double?
    var frequency_penalty: Double?
    var logit_bias: [String: Double]?
    var user: String?
}
28+
29+
/// One decoded server-sent-events chunk of a streaming completion response.
/// Field names mirror the API's JSON so Codable synthesis works unchanged.
struct CompletionStreamDataTrunk: Codable {
    var id: String
    var object: String
    var created: Int
    var model: String
    var choices: [Choice]

    struct Choice: Codable {
        var delta: Delta
        var index: Int
        var finish_reason: String?

        /// Incremental update: either/both of role and content may be absent.
        struct Delta: Codable {
            var role: ChatMessage.Role?
            var content: String?
        }
    }
}
47+
48+
/// `CompletionStreamAPI` implementation that performs the real HTTP request
/// and decodes the server-sent-events stream.
struct OpenAICompletionStreamAPI: CompletionStreamAPI {
    var apiKey: String
    var endpoint: URL
    var requestBody: CompletionRequestBody

    init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
        self.apiKey = apiKey
        self.endpoint = endpoint
        self.requestBody = requestBody
    }

    /// Sends the request and returns the chunk stream plus a cancel handle.
    /// - Throws: encoding errors, transport errors, and — for non-200
    ///   responses — the decoded `ChatGPTError` or
    ///   `ChatGPTServiceError.responseInvalid`.
    func callAsFunction() async throws -> (
        trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
        cancel: Cancellable
    ) {
        var request = URLRequest(url: endpoint)
        request.httpMethod = "POST"
        let encoder = JSONEncoder()
        request.httpBody = try encoder.encode(requestBody)
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")

        let (result, response) = try await URLSession.shared.bytes(for: request)
        guard let response = response as? HTTPURLResponse else {
            throw ChatGPTServiceError.responseInvalid
        }

        guard response.statusCode == 200 else {
            // On failure the body is a single JSON error object; collect all
            // lines, then try to decode the API's error payload from it.
            let text = try await result.lines.reduce(into: "") { partialResult, current in
                partialResult += current
            }
            guard let data = text.data(using: .utf8)
            else { throw ChatGPTServiceError.responseInvalid }
            let decoder = JSONDecoder()
            let error = try? decoder.decode(ChatGPTError.self, from: data)
            throw error ?? ChatGPTServiceError.responseInvalid
        }

        return (
            AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
                Task {
                    do {
                        // Fix: reuse a single decoder instead of allocating a
                        // fresh JSONDecoder for every SSE line (loop-invariant).
                        let decoder = JSONDecoder()
                        let prefix = "data: "
                        for try await line in result.lines {
                            // Non-`data:` lines and undecodable payloads (such
                            // as the terminating "data: [DONE]") are skipped.
                            guard line.hasPrefix(prefix),
                                  let content = line.dropFirst(prefix.count).data(using: .utf8),
                                  let trunk = try? decoder
                                  .decode(CompletionStreamDataTrunk.self, from: content)
                            else { continue }
                            continuation.yield(trunk)
                        }
                        continuation.finish()
                    } catch {
                        continuation.finish(throwing: error)
                    }
                }
            },
            Cancellable {
                result.task.cancel()
            }
        )
    }
}
111+
112+
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import Foundation
2+
3+
/// A token wrapping a single cancellation closure; invoke it like a function.
struct Cancellable {
    let cancel: () -> Void

    func callAsFunction() {
        cancel()
    }
}
9+
10+
/// One entry in a conversation, mirroring the API's message object.
public struct ChatMessage: Equatable, Codable {
    /// Who authored the message.
    public enum Role: String, Codable, Equatable {
        case system
        case user
        case assistant
    }

    public var role: Role
    public var content: String
    /// Optional condensed form of `content`.
    public var summary: String?
    /// Identifier taken from the response stream, when the message came
    /// from the assistant.
    public var id: String?

    public init(role: Role, content: String, summary: String? = nil, id: String? = nil) {
        self.role = role
        self.content = content
        self.summary = summary
        self.id = id
    }
}
29+
30+
/// Chat models accepted by the completion endpoint; each raw value is the
/// exact model identifier sent in the request's `model` field.
public enum ChatGPTModel: String {
    case gpt_3_5_turbo = "gpt-3.5-turbo"
    case gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301"
    case gpt_4_0314 = "gpt-4-0314"
    case gpt_4_32k = "gpt-4-32k"
    case gpt_4_32k_0314 = "gpt-4-32k-0314"
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import XCTest
2+
@testable import OpenAIService
3+
4+
/// Manual "field" test: it only talks to the live API when a real key is
/// substituted for the `"Key"` placeholder, and then always fails as a
/// reminder to put the placeholder back before committing.
final class ChatGPTServiceFieldTests: XCTestCase {
    func test_calling_the_api() async throws {
        let service = ChatGPTService(systemPrompt: "", apiKey: "Key")

        // Placeholder key means this is not a field run; pass silently.
        guard (await service.apiKey) != "Key" else { return }

        do {
            let stream = try await service.send(content: "Hello")
            for try await text in stream {
                print(text)
            }
        } catch {
            print("🔴", error.localizedDescription)
        }

        XCTFail("🔴 Please reset the key to `Key` after the field tests.")
    }
}

0 commit comments

Comments
 (0)