Skip to content

Commit 48acf04

Browse files
committed
Adjust API of ChatGPTFunction
1 parent cfd11d8 commit 48acf04

13 files changed

Lines changed: 157 additions & 109 deletions

File tree

Core/Sources/ChatContextCollectors/ActiveDocumentChatContextCollector/Functions/ExpandFocusRangeFunction.swift

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ struct ExpandFocusRangeFunction: ChatGPTFunction {
1313
"Editing Document Context is updated to display code at \(range)."
1414
}
1515
}
16-
16+
1717
struct E: Error, LocalizedError {
1818
var errorDescription: String?
1919
}
2020

21-
var reportProgress: (String) async -> Void = { _ in }
22-
2321
var name: String {
2422
"expandFocusRange"
2523
}
@@ -32,18 +30,21 @@ struct ExpandFocusRangeFunction: ChatGPTFunction {
3230
.type: "object",
3331
.properties: [:],
3432
] }
35-
33+
3634
weak var contextCollector: ActiveDocumentChatContextCollector?
37-
35+
3836
init(contextCollector: ActiveDocumentChatContextCollector) {
3937
self.contextCollector = contextCollector
4038
}
4139

42-
func prepare() async {
40+
func prepare(reportProgress: @escaping (String) async -> Void) async {
4341
await reportProgress("Finding the focused code..")
4442
}
4543

46-
func call(arguments: Arguments) async throws -> Result {
44+
func call(
45+
arguments: Arguments,
46+
reportProgress: @escaping (String) async -> Void
47+
) async throws -> Result {
4748
await reportProgress("Finding the focused code..")
4849
contextCollector?.activeDocumentContext?.expandFocusedRangeToContextRange()
4950
guard let newContext = contextCollector?.activeDocumentContext?.focusedContext else {
@@ -56,3 +57,4 @@ struct ExpandFocusRangeFunction: ChatGPTFunction {
5657
return .init(range: newContext.codeRange)
5758
}
5859
}
60+

Core/Sources/ChatContextCollectors/ActiveDocumentChatContextCollector/Functions/MoveToCodeAroundLineFunction.swift

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@ struct MoveToCodeAroundLineFunction: ChatGPTFunction {
1515
"Editing Document Context is updated to display code at \(range)."
1616
}
1717
}
18-
18+
1919
struct E: Error, LocalizedError {
2020
var errorDescription: String?
2121
}
2222

23-
var reportProgress: (String) async -> Void = { _ in }
24-
2523
var name: String {
2624
"getCodeAtLine"
2725
}
@@ -36,7 +34,7 @@ struct MoveToCodeAroundLineFunction: ChatGPTFunction {
3634
"line": [
3735
.type: "number",
3836
.description: "The line number in the file",
39-
]
37+
],
4038
],
4139
.required: ["line"],
4240
] }
@@ -47,11 +45,14 @@ struct MoveToCodeAroundLineFunction: ChatGPTFunction {
4745
self.contextCollector = contextCollector
4846
}
4947

50-
func prepare() async {
48+
func prepare(reportProgress: @escaping (String) async -> Void) async {
5149
await reportProgress("Finding code around..")
5250
}
5351

54-
func call(arguments: Arguments) async throws -> Result {
52+
func call(
53+
arguments: Arguments,
54+
reportProgress: @escaping (String) async -> Void
55+
) async throws -> Result {
5556
await reportProgress("Finding code around line \(arguments.line)..")
5657
contextCollector?.activeDocumentContext?.moveToCodeAroundLine(arguments.line)
5758
guard let newContext = contextCollector?.activeDocumentContext?.focusedContext else {

Core/Sources/ChatContextCollectors/ActiveDocumentChatContextCollector/Functions/MoveToFocusedCodeFunction.swift

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import OpenAIService
44
import SuggestionModel
55

66
struct MoveToFocusedCodeFunction: ChatGPTFunction {
7-
struct Arguments: Codable {}
7+
typealias Arguments = NoArguments
88

99
struct Result: ChatGPTFunctionResult {
1010
var range: CursorRange
@@ -13,13 +13,11 @@ struct MoveToFocusedCodeFunction: ChatGPTFunction {
1313
"Editing Document Context is updated to display code at \(range)."
1414
}
1515
}
16-
16+
1717
struct E: Error, LocalizedError {
1818
var errorDescription: String?
1919
}
2020

21-
var reportProgress: (String) async -> Void = { _ in }
22-
2321
var name: String {
2422
"moveToFocusedCode"
2523
}
@@ -28,22 +26,20 @@ struct MoveToFocusedCodeFunction: ChatGPTFunction {
2826
"Move editing document context to the selected or focused code"
2927
}
3028

31-
var argumentSchema: JSONSchemaValue { [
32-
.type: "object",
33-
.properties: [:],
34-
] }
35-
3629
weak var contextCollector: ActiveDocumentChatContextCollector?
37-
30+
3831
init(contextCollector: ActiveDocumentChatContextCollector) {
3932
self.contextCollector = contextCollector
4033
}
4134

42-
func prepare() async {
35+
func prepare(reportProgress: @escaping (String) async -> Void) async {
4336
await reportProgress("Finding the focused code..")
4437
}
4538

46-
func call(arguments: Arguments) async throws -> Result {
39+
func call(
40+
arguments: Arguments,
41+
reportProgress: @escaping (String) async -> Void
42+
) async throws -> Result {
4743
await reportProgress("Finding the focused code..")
4844
contextCollector?.activeDocumentContext?.moveToFocusedCode()
4945
guard let newContext = contextCollector?.activeDocumentContext?.focusedContext else {
@@ -56,3 +52,4 @@ struct MoveToFocusedCodeFunction: ChatGPTFunction {
5652
return .init(range: newContext.codeRange)
5753
}
5854
}
55+

Core/Sources/ChatContextCollectors/WebChatContextCollector/QueryWebsiteFunction.swift

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ struct QueryWebsiteFunction: ChatGPTFunction {
1717
}
1818
}
1919

20-
var reportProgress: (String) async -> Void = { _ in }
21-
2220
var name: String {
2321
"queryWebsite"
2422
}
@@ -47,11 +45,14 @@ struct QueryWebsiteFunction: ChatGPTFunction {
4745
]
4846
}
4947

50-
func prepare() async {
48+
func prepare(reportProgress: @escaping (String) async -> Void) async {
5149
await reportProgress("Reading..")
5250
}
5351

54-
func call(arguments: Arguments) async throws -> Result {
52+
func call(
53+
arguments: Arguments,
54+
reportProgress: @escaping (String) async -> Void
55+
) async throws -> Result {
5556
do {
5657
let embedding = OpenAIEmbedding(configuration: UserPreferenceEmbeddingConfiguration())
5758

@@ -61,10 +62,13 @@ struct QueryWebsiteFunction: ChatGPTFunction {
6162
group.addTask {
6263
// 1. grab the website content
6364
await reportProgress("Loading \(url)..")
64-
65+
6566
if let database = await TemporaryUSearch.view(identifier: urlString) {
6667
await reportProgress("Getting relevant information..")
67-
let qa = QAInformationRetrievalChain(vectorStore: database, embedding: embedding)
68+
let qa = QAInformationRetrievalChain(
69+
vectorStore: database,
70+
embedding: embedding
71+
)
6872
return try await qa.call(.init(arguments.query)).information
6973
}
7074
let loader = WebLoader(urls: [url])
@@ -83,7 +87,10 @@ struct QueryWebsiteFunction: ChatGPTFunction {
8387
try await database.set(embeddedDocuments)
8488
// 4. generate answer
8589
await reportProgress("Getting relevant information..")
86-
let qa = QAInformationRetrievalChain(vectorStore: database, embedding: embedding)
90+
let qa = QAInformationRetrievalChain(
91+
vectorStore: database,
92+
embedding: embedding
93+
)
8794
let result = try await qa.call(.init(arguments.query))
8895
return result.information
8996
}
@@ -101,7 +108,7 @@ struct QueryWebsiteFunction: ChatGPTFunction {
101108
.joined(separator: "\n")
102109
)
103110
""")
104-
111+
105112
return all
106113
}
107114

Core/Sources/ChatContextCollectors/WebChatContextCollector/SearchFunction.swift

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ struct SearchFunction: ChatGPTFunction {
2828
}.joined(separator: "\n")
2929
}
3030
}
31-
32-
let maxTokens: Int
3331

34-
var reportProgress: (String) async -> Void = { _ in }
32+
let maxTokens: Int
3533

3634
var name: String {
3735
"searchWeb"
@@ -62,19 +60,22 @@ struct SearchFunction: ChatGPTFunction {
6260
]
6361
}
6462

65-
func prepare() async {
63+
func prepare(reportProgress: @escaping ReportProgress) async {
6664
await reportProgress("Searching..")
6765
}
6866

69-
func call(arguments: Arguments) async throws -> Result {
67+
func call(
68+
arguments: Arguments,
69+
reportProgress: @escaping ReportProgress
70+
) async throws -> Result {
7071
await reportProgress("Searching \(arguments.query)")
7172

7273
do {
7374
let bingSearch = BingSearchService(
7475
subscriptionKey: UserDefaults.shared.value(for: \.bingSearchSubscriptionKey),
7576
searchURL: UserDefaults.shared.value(for: \.bingSearchEndpoint)
7677
)
77-
78+
7879
let result = try await bingSearch.search(
7980
query: arguments.query,
8081
numberOfResult: maxTokens > 5000 ? 5 : 3,

Pro

Submodule Pro updated from f17934c to d02bbce

Tool/Sources/LangChain/Agent.swift

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ public extension CallbackEvents {
4444
var agentActionDidEnd: AgentActionDidEnd.Type {
4545
AgentActionDidEnd.self
4646
}
47+
48+
struct AgentFunctionCallingToolReportProgress: CallbackEvent {
49+
public struct Info {
50+
let functionName: String
51+
let progress: String
52+
}
53+
54+
public let info: Info
55+
}
56+
57+
var agentFunctionCallingToolReportProgress: AgentFunctionCallingToolReportProgress.Type {
58+
AgentFunctionCallingToolReportProgress.self
59+
}
4760
}
4861

4962
public struct AgentFinish<Output: AgentOutputParsable> {
@@ -68,7 +81,7 @@ public enum AgentNextStep<Output: AgentOutputParsable> {
6881

6982
public struct AgentScratchPad<Content: Equatable>: Equatable {
7083
public var content: Content
71-
84+
7285
public init(content: Content) {
7386
self.content = content
7487
}
@@ -99,7 +112,8 @@ public protocol Agent {
99112

100113
func validateTools(tools: [AgentTool]) throws
101114
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent>
102-
func constructFinalScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent>
115+
func constructFinalScratchpad(intermediateSteps: [AgentAction])
116+
-> AgentScratchPad<ScratchPadContent>
103117
func extraPlan(input: AgentInput<Input, ScratchPadContent>)
104118
func parseOutput(_ output: ChatModelChain<AgentInput<Input, ScratchPadContent>>.Output) async
105119
-> AgentNextStep<Output>
@@ -146,8 +160,12 @@ public extension Agent {
146160
case let .finish(finish):
147161
return finish
148162
case .actions:
149-
return .init(returnValue: .unstructured(output.content ?? ""), log: output.content ?? "")
163+
return .init(
164+
returnValue: .unstructured(output.content ?? ""),
165+
log: output.content ?? ""
166+
)
150167
}
151168
}
152169
}
153170
}
171+

Tool/Sources/LangChain/AgentTool.swift

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,17 @@ public struct SimpleAgentTool: AgentTool {
3131
}
3232
}
3333

34-
public struct FunctionCallingAgentTool<F: ChatGPTFunction>: AgentTool, ChatGPTFunction {
34+
public class FunctionCallingAgentTool<F: ChatGPTFunction>: AgentTool {
3535
public func call(arguments: F.Arguments) async throws -> F.Result {
36-
try await function.call(arguments: arguments)
36+
try await function.call(arguments: arguments, reportProgress: reportProgress)
3737
}
3838

3939
public var argumentSchema: OpenAIService.JSONSchemaValue { function.argumentSchema }
4040

41-
public func prepare() async { await function.prepare() }
42-
43-
public var reportProgress: (String) async -> Void {
44-
get { function.reportProgress }
45-
set { function.reportProgress = newValue }
41+
public func prepare() async {
42+
await function.prepare(reportProgress: { [weak self] p in
43+
self?.reportProgress(p)
44+
})
4645
}
4746

4847
public typealias Arguments = F.Arguments
@@ -53,15 +52,37 @@ public struct FunctionCallingAgentTool<F: ChatGPTFunction>: AgentTool, ChatGPTFu
5352
public var description: String
5453
public var returnDirectly: Bool
5554

56-
public init(function: F, returnDirectly: Bool = false) {
55+
let callbackManagers: [CallbackManager]
56+
57+
public init(
58+
function: F,
59+
returnDirectly: Bool = false,
60+
callbackManagers: [CallbackManager] = []
61+
) {
5762
self.function = function
63+
self.callbackManagers = callbackManagers
5864
name = function.name
59-
description = "Run an action: \(function.description)"
65+
description = function.description
6066
self.returnDirectly = returnDirectly
6167
}
6268

69+
func reportProgress(_ progress: String) {
70+
callbackManagers.send(
71+
CallbackEvents.AgentFunctionCallingToolReportProgress(info: .init(
72+
functionName: name,
73+
progress: progress
74+
))
75+
)
76+
}
77+
6378
public func run(input: String) async throws -> String {
64-
try await function.call(argumentsJsonString: input).botReadableContent
79+
try await function.call(
80+
argumentsJsonString: input,
81+
reportProgress: { [weak self] p in
82+
self?.reportProgress(p)
83+
}
84+
)
85+
.botReadableContent
6586
}
6687
}
6788

Tool/Sources/LangChain/Chains/RefineDocumentChain.swift

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,8 @@ public final class RefineDocumentChain: Chain {
4646
var functions: [any ChatGPTFunction] = [RespondFunction()]
4747
}
4848

49-
struct RespondFunction: ChatGPTFunction {
49+
struct RespondFunction: ChatGPTArgumentsCollectingFunction {
5050
typealias Arguments = IntermediateAnswer
51-
52-
struct Result: ChatGPTFunctionResult {
53-
var botReadableContent: String { "" }
54-
}
55-
56-
var reportProgress: (String) async -> Void = { _ in }
57-
5851
var name: String = "respond"
5952
var description: String = "Respond with the refined answer"
6053
var argumentSchema: JSONSchemaValue {
@@ -77,12 +70,6 @@ public final class RefineDocumentChain: Chain {
7770
.required: ["answer", "more", "usefulness"],
7871
]
7972
}
80-
81-
func prepare() async {}
82-
83-
func call(arguments: Arguments) async throws -> Result {
84-
return Result()
85-
}
8673
}
8774

8875
func buildChatModel() -> ChatModelChain<RefinementInput> {

0 commit comments

Comments
 (0)