Skip to content

Commit d741333

Browse files
committed
Reset ChatCompletionsStreamAPI to return AsyncThrowingStream
1 parent 6d6fc49 commit d741333

5 files changed

Lines changed: 46 additions & 15 deletions

File tree

Pro

Submodule Pro updated from 5f1f1dd to 322e945

Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,28 @@ typealias ChatCompletionsStreamAPIBuilder = (
128128
) -> any ChatCompletionsStreamAPI
129129

130130
protocol ChatCompletionsStreamAPI {
131-
associatedtype CompletionSequence: AsyncSequence
132-
where CompletionSequence.Element == ChatCompletionsStreamDataChunk
133-
func callAsFunction() async throws -> CompletionSequence
131+
func callAsFunction() async throws -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
132+
}
133+
134+
extension AsyncSequence {
135+
func toStream() -> AsyncThrowingStream<Element, Error> {
136+
AsyncThrowingStream { continuation in
137+
let task = Task {
138+
do {
139+
for try await element in self {
140+
continuation.yield(element)
141+
}
142+
continuation.finish()
143+
} catch {
144+
continuation.finish(throwing: error)
145+
}
146+
}
147+
148+
continuation.onTermination = { _ in
149+
task.cancel()
150+
}
151+
}
152+
}
134153
}
135154

136155
struct ChatCompletionsStreamDataChunk: Codable {
@@ -159,7 +178,13 @@ struct ChatCompletionsStreamDataChunk: Codable {
159178

160179
// MARK: - Non Stream API
161180

162-
typealias ChatCompletionsAPIBuilder = (String, ChatModel, URL, ChatCompletionsRequestBody, ChatGPTPrompt)
181+
typealias ChatCompletionsAPIBuilder = (
182+
String,
183+
ChatModel,
184+
URL,
185+
ChatCompletionsRequestBody,
186+
ChatGPTPrompt
187+
)
163188
-> any ChatCompletionsAPI
164189

165190
protocol ChatCompletionsAPI {

Tool/Sources/OpenAIService/APIs/OlamaService.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ extension OllamaService: ChatCompletionsAPI {
3333
}
3434

3535
extension OllamaService: ChatCompletionsStreamAPI {
36-
typealias CompletionSequence = AsyncMapSequence<ResponseStream<OllamaService.ChatCompletionResponseChunk>, ChatCompletionsStreamDataChunk>
37-
38-
func callAsFunction() async throws -> CompletionSequence {
36+
func callAsFunction() async throws
37+
-> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Swift.Error>
38+
{
3939
let requestBody = ChatCompletionRequestBody(
4040
model: model.info.modelName,
4141
messages: requestBody.messages.map { message in
@@ -115,7 +115,7 @@ extension OllamaService: ChatCompletionsStreamAPI {
115115
)
116116
}
117117

118-
return sequence
118+
return sequence.toStream()
119119
}
120120
}
121121

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ public class ChatGPTService: ChatGPTServiceType {
7373
apiKey, model, endpoint, requestBody, prompt in
7474
switch model.format {
7575
case .googleAI:
76-
return GoogleCompletionStreamAPI(
76+
return GoogleAIService(
7777
apiKey: apiKey,
7878
model: model,
7979
requestBody: requestBody,
8080
prompt: prompt
8181
)
8282
case .openAI, .openAICompatible, .azureOpenAI:
83-
return OpenAICompletionStreamAPI(
83+
return OpenAIService(
8484
apiKey: apiKey,
8585
model: model,
8686
endpoint: endpoint,
@@ -93,14 +93,14 @@ public class ChatGPTService: ChatGPTServiceType {
9393
apiKey, model, endpoint, requestBody, prompt in
9494
switch model.format {
9595
case .googleAI:
96-
return GoogleCompletionAPI(
96+
return GoogleAIService(
9797
apiKey: apiKey,
9898
model: model,
9999
requestBody: requestBody,
100100
prompt: prompt
101101
)
102102
case .openAI, .openAICompatible, .azureOpenAI:
103-
return OpenAICompletionAPI(
103+
return OpenAIService(
104104
apiKey: apiKey,
105105
model: model,
106106
endpoint: endpoint,

Tool/Sources/OpenAIService/EmbeddingService.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ public struct EmbeddingService {
7373
}
7474

7575
guard response.statusCode == 200 else {
76-
let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
76+
let error = try? JSONDecoder().decode(
77+
OpenAIService.CompletionAPIError.self,
78+
from: result
79+
)
7780
throw error ?? ChatGPTServiceError
7881
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
7982
}
@@ -124,7 +127,10 @@ public struct EmbeddingService {
124127
}
125128

126129
guard response.statusCode == 200 else {
127-
let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
130+
let error = try? JSONDecoder().decode(
131+
OpenAIService.CompletionAPIError.self,
132+
from: result
133+
)
128134
throw error ?? ChatGPTServiceError
129135
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
130136
}

0 commit comments

Comments
 (0)