Skip to content

Commit 061fc16

Browse files
committed
Add ChatGPTService
1 parent 8ea4976 commit 061fc16

4 files changed

Lines changed: 299 additions & 0 deletions

File tree

Core/Package.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,13 @@ let package = Package(
123123
),
124124
.target(name: "AXExtension"),
125125
.target(name: "Logger"),
126+
.target(
127+
name: "OpenAIService",
128+
dependencies: [.product(name: "AsyncAlgorithms", package: "swift-async-algorithms")]
129+
),
130+
.testTarget(
131+
name: "OpenAIServiceTests",
132+
dependencies: ["OpenAIService"]
133+
),
126134
]
127135
)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import AsyncAlgorithms
2+
import Foundation
3+
4+
public protocol ChatGPTServiceType {
5+
func send(content: String, summary: String?) async throws -> AsyncThrowingStream<String, Error>
6+
func stopReceivingMessage() async
7+
func restart() async
8+
func mutateSystemPrompt(_ newPrompt: String) async
9+
}
10+
11+
public enum ChatGPTServiceError: Error {
12+
case endpointIncorrect
13+
case responseInvalid
14+
}
15+
16+
public struct ChatGPTError: Error, Codable, LocalizedError {
17+
public var error: ErrorContent
18+
public init(error: ErrorContent) {
19+
self.error = error
20+
}
21+
22+
public struct ErrorContent: Codable {
23+
public var message: String
24+
public var type: String
25+
public var param: String?
26+
public var code: String?
27+
28+
public init(message: String, type: String, param: String? = nil, code: String? = nil) {
29+
self.message = message
30+
self.type = type
31+
self.param = param
32+
self.code = code
33+
}
34+
}
35+
36+
public var errorDescription: String? {
37+
error.message
38+
}
39+
}
40+
41+
public actor ChatGPTService: ChatGPTServiceType, ObservableObject {
42+
public var temperature: Double
43+
public var model: ChatGPTModel
44+
public var endpoint: String
45+
public var apiKey: String
46+
public var systemPrompt: String
47+
public var maxToken: Int
48+
public var history: [ChatGPTMessage] = [] {
49+
didSet { objectWillChange.send() }
50+
}
51+
52+
public internal(set) var isReceivingMessage = false
53+
var ongoingTask: URLSessionDataTask?
54+
55+
public init(
56+
systemPrompt: String,
57+
apiKey: String,
58+
endpoint: String = "https://api.openai.com/v1/chat/completions",
59+
model: ChatGPTModel = .gpt_3_5_turbo,
60+
temperature: Double = 1,
61+
maxToken: Int = 2048
62+
) {
63+
self.systemPrompt = systemPrompt
64+
self.apiKey = apiKey
65+
self.model = model
66+
self.temperature = temperature
67+
self.maxToken = maxToken
68+
self.endpoint = endpoint
69+
}
70+
71+
public func send(
72+
content: String,
73+
summary: String? = nil
74+
) async throws -> AsyncThrowingStream<String, Error> {
75+
guard !isReceivingMessage else { throw CancellationError() }
76+
guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
77+
let newMessage = ChatGPTMessage(role: .user, content: content, summary: summary)
78+
history.append(newMessage)
79+
var request = URLRequest(url: url)
80+
request.httpMethod = "POST"
81+
82+
let requestBody = ChatGPTRequest(
83+
model: model.rawValue,
84+
messages: combineHistoryWithSystemPrompt(),
85+
temperature: temperature,
86+
stream: true,
87+
max_tokens: maxToken
88+
)
89+
90+
let encoder = JSONEncoder()
91+
request.httpBody = try encoder.encode(requestBody)
92+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
93+
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
94+
95+
isReceivingMessage = true
96+
97+
do {
98+
let (result, response) = try await URLSession.shared.bytes(for: request)
99+
ongoingTask = result.task
100+
101+
guard let response = response as? HTTPURLResponse else {
102+
throw ChatGPTServiceError.responseInvalid
103+
}
104+
guard response.statusCode == 200 else {
105+
let text = try await result.lines.reduce(into: "") { partialResult, current in
106+
partialResult += current
107+
}
108+
guard let data = text.data(using: .utf8)
109+
else { throw ChatGPTServiceError.responseInvalid }
110+
let decoder = JSONDecoder()
111+
let error = try? decoder.decode(ChatGPTError.self, from: data)
112+
throw error ?? ChatGPTServiceError.responseInvalid
113+
}
114+
115+
return AsyncThrowingStream<String, Error> { continuation in
116+
Task {
117+
do {
118+
for try await line in result.lines {
119+
let prefix = "data: "
120+
guard line.hasPrefix(prefix),
121+
let content = line.dropFirst(prefix.count).data(using: .utf8),
122+
let trunk = try? JSONDecoder()
123+
.decode(ChatGPTDataTrunk.self, from: content),
124+
let delta = trunk.choices.first?.delta
125+
else { continue }
126+
127+
if history.last?.id == trunk.id {
128+
if let role = delta.role {
129+
history[history.endIndex - 1].role = role
130+
}
131+
if let content = delta.content {
132+
history[history.endIndex - 1].content.append(content)
133+
}
134+
} else {
135+
history.append(.init(
136+
role: delta.role ?? .assistant,
137+
content: delta.content ?? "",
138+
id: trunk.id
139+
))
140+
}
141+
142+
if let content = delta.content {
143+
continuation.yield(content)
144+
}
145+
}
146+
147+
continuation.finish()
148+
isReceivingMessage = false
149+
} catch {
150+
continuation.finish(throwing: error)
151+
}
152+
}
153+
}
154+
} catch {
155+
isReceivingMessage = false
156+
throw error
157+
}
158+
}
159+
160+
public func stopReceivingMessage() {
161+
ongoingTask?.cancel()
162+
ongoingTask = nil
163+
isReceivingMessage = false
164+
}
165+
166+
public func restart() {
167+
history = []
168+
}
169+
170+
public func mutateSystemPrompt(_ newPrompt: String) {
171+
systemPrompt = newPrompt
172+
}
173+
}
174+
175+
extension ChatGPTService {
176+
func combineHistoryWithSystemPrompt() -> [ChatGPTMessage] {
177+
if history.count > 4 {
178+
return [.init(role: .system, content: systemPrompt)] +
179+
history[history.endIndex - 4..<history.endIndex]
180+
}
181+
return [.init(role: .system, content: systemPrompt)] + history
182+
}
183+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import Foundation
2+
3+
public struct ChatGPTMessage: Equatable, Codable {
4+
public enum Role: String, Codable, Equatable {
5+
case system
6+
case user
7+
case assistant
8+
}
9+
10+
public var role: Role
11+
public var content: String
12+
public var summary: String?
13+
public var id: String?
14+
15+
public init(role: Role, content: String, summary: String? = nil, id: String? = nil) {
16+
self.role = role
17+
self.content = content
18+
self.summary = summary
19+
self.id = id
20+
}
21+
}
22+
23+
public enum ChatGPTModel: String {
24+
case gpt_3_5_turbo = "gpt-3.5-turbo"
25+
case gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301"
26+
case gpt_4_0314 = "gpt-4-0314"
27+
case gpt_4_32k = "gpt-4-32k"
28+
case gpt_4_32k_0314 = "gpt-4-32k-0314"
29+
}
30+
31+
/// https://platform.openai.com/docs/api-reference/chat/create
32+
struct ChatGPTRequest: Codable {
33+
var model: String
34+
var messages: [ChatGPTMessage]
35+
var temperature: Double?
36+
var top_p: Double?
37+
var n: Double?
38+
var stream: Bool?
39+
var stop: [String]?
40+
var max_tokens: Int?
41+
var presence_penalty: Double?
42+
var frequency_penalty: Double?
43+
var logit_bias: [String: Double]?
44+
var user: String?
45+
}
46+
47+
/// https://platform.openai.com/docs/api-reference/chat/create
48+
struct ChatGPTResponse: Codable {
49+
var id: String
50+
var object: String
51+
var created: Int
52+
var choices: [Choice]
53+
var usage: Usage
54+
55+
struct Usage: Codable {
56+
var prompt_tokens: Int
57+
var completion_tokens: Int
58+
var total_tokens: Int
59+
}
60+
61+
struct Choice: Codable {
62+
var index: Int
63+
var message: ChatGPTMessage
64+
var finish_reason: String
65+
}
66+
}
67+
68+
struct ChatGPTDataTrunk: Codable {
69+
var id: String
70+
var object: String
71+
var created: Int
72+
var model: String
73+
var choices: [Choice]
74+
75+
struct Choice: Codable {
76+
var delta: Delta
77+
var index: Int
78+
var finish_reason: String?
79+
80+
struct Delta: Codable {
81+
var role: ChatGPTMessage.Role?
82+
var content: String?
83+
}
84+
}
85+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import XCTest
2+
@testable import OpenAIService
3+
4+
final class ChatGPTServiceTests: XCTestCase {
5+
func test_calling_the_api() async throws {
6+
let service = ChatGPTService(systemPrompt: "", apiKey: "Key")
7+
8+
if (await service.apiKey) == "Key" {
9+
return
10+
}
11+
12+
do {
13+
let stream = try await service.send(content: "Hello")
14+
for try await text in stream {
15+
print(text)
16+
}
17+
} catch {
18+
print("🔴", error.localizedDescription)
19+
}
20+
21+
XCTFail("🔴 Please reset the key to `Key` after the field tests.")
22+
}
23+
}

0 commit comments

Comments
 (0)