forked from intitni/CopilotForXcode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path CompletionStreamAPI.swift
More file actions
115 lines (102 loc) · 3.71 KB
/
CompletionStreamAPI.swift
File metadata and controls
115 lines (102 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import AsyncAlgorithms
import Foundation
/// Factory that produces a `CompletionStreamAPI` for a single request.
///
/// The parameter order matches `OpenAICompletionStreamAPI.init(apiKey:endpoint:requestBody:)`,
/// so that initializer can be passed directly wherever a builder is expected.
typealias CompletionStreamAPIBuilder = (
    _ apiKey: String,
    _ endpoint: URL,
    _ requestBody: CompletionRequestBody
) -> CompletionStreamAPI
/// Something that can be invoked like a function to run one streaming completion request.
protocol CompletionStreamAPI {
    /// Starts the request.
    ///
    /// - Returns: `trunkStream`, the stream of decoded chunks, and `cancel`, a `Cancellable`
    ///   handle. The conformer in this file uses `cancel` to cancel the underlying URL session task.
    /// - Throws: Whatever the conformer reports when the request cannot be started.
    func callAsFunction() async throws
        -> (trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>, cancel: Cancellable)
}
/// JSON body for a chat completion request.
///
/// Reference: https://platform.openai.com/docs/api-reference/chat/create
///
/// Property names are deliberately snake_case. The synthesized `Codable` conformance
/// uses them verbatim as JSON keys, so renaming a property changes the wire format.
struct CompletionRequestBody: Codable, Equatable {
    /// Model identifier sent as `model`.
    var model: String
    /// Conversation sent as `messages`.
    var messages: [Message]
    var temperature: Double?
    var top_p: Double?
    // NOTE(review): the upstream API describes `n` as an integer count, but it is typed
    // `Double?` here. Changing the type would break existing callers, so confirm before changing.
    var n: Double?
    /// Sent as `stream`. The streaming client in this file consumes a `data:`-line response.
    var stream: Bool?
    var stop: [String]?
    var max_tokens: Int?
    var presence_penalty: Double?
    var frequency_penalty: Double?
    var logit_bias: [String: Double]?
    var user: String?
}

extension CompletionRequestBody {
    /// A single chat message: the speaker's role and the text content.
    struct Message: Codable, Equatable {
        var role: ChatMessage.Role
        var content: String
    }
}
/// One decoded chunk ("trunk") of a streamed chat completion.
///
/// `id`, `object`, `created`, `model` and `choices` are non-optional,
/// so a JSON payload missing any of them fails to decode.
struct CompletionStreamDataTrunk: Codable {
    var id: String
    var object: String
    var created: Int
    var model: String
    var choices: [Choice]
}

extension CompletionStreamDataTrunk {
    /// One choice inside a chunk, identified by `index`.
    struct Choice: Codable {
        var delta: Delta
        var index: Int
        /// Optional; decodes as `nil` when absent or null.
        var finish_reason: String?
    }
}

extension CompletionStreamDataTrunk.Choice {
    /// Incremental payload of a choice. Both fields are optional,
    /// so either may be missing from a given chunk.
    struct Delta: Codable {
        var role: ChatMessage.Role?
        var content: String?
    }
}
/// Streams chat completions from an OpenAI-style HTTP endpoint.
///
/// POSTs `requestBody` as JSON to `endpoint` with a bearer-token `Authorization` header,
/// then exposes the line-based event-stream response as decoded `CompletionStreamDataTrunk`s.
struct OpenAICompletionStreamAPI: CompletionStreamAPI {
    /// Secret sent as `Authorization: Bearer <apiKey>`.
    var apiKey: String
    /// URL the request is POSTed to.
    var endpoint: URL
    /// Payload encoded with `JSONEncoder` as the HTTP body.
    var requestBody: CompletionRequestBody

    init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
        self.apiKey = apiKey
        self.endpoint = endpoint
        self.requestBody = requestBody
    }

    /// Sends the request and returns the chunk stream plus a cancellation handle.
    ///
    /// - Returns: `trunkStream` yields one value for each `data:` line whose payload decodes as a
    ///   `CompletionStreamDataTrunk`. Any other line is skipped, including a non-JSON payload such
    ///   as `[DONE]`. `cancel` cancels the underlying URL session data task.
    /// - Throws: `ChatGPTServiceError.responseInvalid` when the response is not an HTTP response.
    ///   For a non-200 status, it throws the decoded `ChatGPTError` body, or
    ///   `ChatGPTServiceError.responseInvalid` if that body does not decode. Encoding and
    ///   transport errors are rethrown.
    func callAsFunction() async throws -> (
        trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
        cancel: Cancellable
    ) {
        var request = URLRequest(url: endpoint)
        request.httpMethod = "POST"
        request.httpBody = try JSONEncoder().encode(requestBody)
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")

        let (result, response) = try await URLSession.shared.bytes(for: request)
        guard let response = response as? HTTPURLResponse else {
            throw ChatGPTServiceError.responseInvalid
        }
        guard response.statusCode == 200 else {
            // Decode the error body from the raw bytes. The previous approach joined `lines`
            // with no separator and round-tripped through a String, which is needlessly lossy.
            let data = try await result.reduce(into: Data()) { buffer, byte in
                buffer.append(byte)
            }
            let error = try? JSONDecoder().decode(ChatGPTError.self, from: data)
            throw error ?? ChatGPTServiceError.responseInvalid
        }

        let dataTask = result.task
        let stream = AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
            let reader = Task {
                // One decoder for the whole stream instead of one per line.
                let decoder = JSONDecoder()
                do {
                    for try await line in result.lines {
                        guard let trunk = Self.decodeEvent(line, using: decoder) else { continue }
                        continuation.yield(trunk)
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            // Without this handler, a consumer that stops iterating early would leave the reader
            // task and the HTTP connection running until the server closes the stream.
            continuation.onTermination = { @Sendable _ in
                reader.cancel()
                dataTask.cancel()
            }
        }
        return (stream, Cancellable { dataTask.cancel() })
    }

    /// Parses one event-stream line of the form `data:<payload>` into a chunk.
    /// Returns `nil` for non-`data:` lines and for payloads that do not decode.
    private static func decodeEvent(
        _ line: String,
        using decoder: JSONDecoder
    ) -> CompletionStreamDataTrunk? {
        let field = "data:"
        guard line.hasPrefix(field) else { return nil }
        var payload = line.dropFirst(field.count)
        // In the event-stream format, a single space after the colon is optional
        // and is not part of the value.
        if payload.first == " " {
            payload = payload.dropFirst()
        }
        return try? decoder.decode(CompletionStreamDataTrunk.self, from: Data(payload.utf8))
    }
}