Skip to content

Commit 4085fdb

Browse files
committed
Support create ChatGPTServiceType from RAGChatAgentConfiguration
1 parent 94d2a01 commit 4085fdb

2 files changed

Lines changed: 63 additions & 6 deletions

File tree

Tool/Sources/RAGChatAgent/RAGChatAgent.swift

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ public class RAGChatAgent: ChatAgent {
1111
}
1212

1313
public func send(_ request: Request) async -> AsyncThrowingStream<Response, any Error> {
14-
let service = getService()
1514
let stream = AsyncThrowingStream<Response, any Error> { continuation in
1615
let task = Task(priority: .userInitiated) {
1716
do {
17+
let service = try await createService(for: request)
1818
let response = try await service.send(content: request.text, summary: nil)
1919
for try await item in response {
2020
if Task.isCancelled {
@@ -28,21 +28,40 @@ public class RAGChatAgent: ChatAgent {
2828
continuation.finish(throwing: error)
2929
}
3030
}
31-
31+
3232
continuation.onTermination = { _ in
3333
task.cancel()
3434
}
3535
}
36-
36+
3737
return stream
3838
}
3939
}
4040

4141
extension RAGChatAgent {
42-
func getService() -> ChatGPTServiceType {
43-
fatalError()
42+
func createService(for request: Request) async throws -> ChatGPTServiceType {
43+
guard let chatGPTConfiguration = configuration.chatGPTConfiguration
44+
else { throw CancellationError() }
45+
let functionProvider = ChatFunctionProvider()
46+
let memory = AutoManagedChatGPTMemory(
47+
systemPrompt: configuration.modelConfiguration.systemPrompt,
48+
configuration: chatGPTConfiguration,
49+
functionProvider: functionProvider
50+
)
51+
52+
await memory.mutateHistory { messages in
53+
for history in request.history {
54+
messages.append(history)
55+
}
56+
}
57+
58+
return ChatGPTService(
59+
memory: memory,
60+
configuration: chatGPTConfiguration,
61+
functionProvider: functionProvider
62+
)
4463
}
45-
64+
4665
var allCapabilities: [String: any RAGChatAgentCapability] {
4766
RAGChatAgentCapabilityContainer.capabilities
4867
}

Tool/Sources/RAGChatAgent/RAGChatAgentConfiguration.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import AIModel
2+
import ChatBasic
13
import CodableWrappers
24
import Foundation
5+
import OpenAIService
6+
import Preferences
7+
import Keychain
38

49
public struct RAGChatAgentConfiguration: Codable {
510
public struct ModelConfiguration: Codable {
@@ -77,5 +82,38 @@ public struct RAGChatAgentConfiguration: Codable {
7782
) throws {
7883
_otherConfigurations = try JSONEncoder().encode(otherConfigurations)
7984
}
85+
86+
var chatGPTConfiguration: ChatGPTConfiguration? {
87+
guard case let .chatModel(id) = serviceProvider else { return nil }
88+
return .init(
89+
model: {
90+
let models = UserDefaults.shared.value(for: \.chatModels)
91+
let id = UserDefaults.shared.value(for: \.defaultChatFeatureChatModelId)
92+
return models.first { $0.id == id }
93+
?? models.first
94+
}(),
95+
temperature: modelConfiguration.temperature,
96+
stop: [],
97+
maxTokens: modelConfiguration.maxTokens,
98+
minimumReplyTokens: modelConfiguration.minimumReplyTokens,
99+
runFunctionsAutomatically: false,
100+
shouldEndTextWindow: { _ in false }
101+
)
102+
}
103+
104+
struct ChatGPTConfiguration: OpenAIService.ChatGPTConfiguration {
105+
var model: ChatModel?
106+
var temperature: Double
107+
var stop: [String]
108+
var maxTokens: Int
109+
var minimumReplyTokens: Int
110+
var runFunctionsAutomatically: Bool
111+
var shouldEndTextWindow: (String) -> Bool
112+
113+
var apiKey: String {
114+
guard let name = model?.info.apiKeyName else { return "" }
115+
return (try? Keychain.apiKey.get(name)) ?? ""
116+
}
117+
}
80118
}
81119

0 commit comments

Comments
 (0)