Skip to content

Commit 42a8efb

Browse files
committed
Support tool call
1 parent ecf740e commit 42a8efb

File tree

18 files changed

+666
-366
lines changed

18 files changed

+666
-366
lines changed

Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ extension WebChatContextCollector {
3232
static func detectLinks(from messages: [ChatMessage]) -> [String] {
3333
return messages.lazy
3434
.compactMap {
35-
$0.content ?? $0.functionCall?.arguments
35+
$0.content ?? $0.toolCalls?.map(\.function.arguments).joined(separator: " ") ?? ""
3636
}
3737
.map(detectLinks(from:))
3838
.flatMap { $0 }

Core/Sources/ChatGPTChatTab/Chat.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ struct Chat: ReducerProtocol {
313313
}
314314
return .ignored
315315
case .function: return .function
316+
case .tool: return .function
316317
}
317318
}(),
318319
text: message.summary ?? message.content ?? "",

Core/Sources/ChatService/ChatService.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ public final class ChatService: ObservableObject {
124124
await chatGPTService.stopReceivingMessage()
125125
isReceivingMessage = false
126126

127-
// if it's stopped before the function finishes, remove the function call.
127+
// if it's stopped before the tool calls finish, remove the message.
128128
await memory.mutateHistory { history in
129-
if history.last?.role == .assistant, history.last?.functionCall != nil {
129+
if history.last?.role == .assistant, history.last?.toolCalls != nil {
130130
history.removeLast()
131131
}
132132
}

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ extension ChatModelEdit.State {
170170
maxTokens: model.info.maxTokens,
171171
supportsFunctionCalling: model.info.supportsFunctionCalling,
172172
modelName: model.info.modelName,
173-
ollamaKeepAlive: model.info.ollamaKeepAlive,
173+
ollamaKeepAlive: model.info.ollamaInfo.keepAlive,
174174
apiKeySelection: .init(
175175
apiKeyName: model.info.apiKeyName,
176176
apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName])
@@ -198,7 +198,7 @@ extension ChatModel {
198198
return state.supportsFunctionCalling
199199
}(),
200200
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
201-
ollamaKeepAlive: state.ollamaKeepAlive
201+
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive)
202202
)
203203
)
204204
}

Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ extension EmbeddingModelEdit.State {
155155
format: model.format,
156156
maxTokens: model.info.maxTokens,
157157
modelName: model.info.modelName,
158-
ollamaKeepAlive: model.info.ollamaKeepAlive,
158+
ollamaKeepAlive: model.info.ollamaInfo.keepAlive,
159159
apiKeySelection: .init(
160160
apiKeyName: model.info.apiKeyName,
161161
apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName])
@@ -177,7 +177,7 @@ extension EmbeddingModel {
177177
isFullURL: state.isFullURL,
178178
maxTokens: state.maxTokens,
179179
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
180-
ollamaKeepAlive: state.ollamaKeepAlive
180+
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive)
181181
)
182182
)
183183
}

Pro

Submodule Pro updated from fbb89b8 to c6cace8

Tool/Sources/LangChain/Chains/LLMChain.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ public class ChatModelChain<Input>: Chain {
3333
public func parseOutput(_ output: Output) -> String {
3434
if let content = output.content {
3535
return content
36-
} else if let functionCall = output.functionCall {
37-
return "\(functionCall.name): \(functionCall.arguments)"
36+
} else if let toolCalls = output.toolCalls {
37+
return toolCalls.map { "[\($0.id)] \($0.function.name): \($0.function.arguments)" }
38+
.joined(separator: "\n")
3839
}
39-
40+
4041
return ""
4142
}
4243
}

Tool/Sources/LangChain/Chains/RefineDocumentChain.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public final class RefineDocumentChain: Chain {
4242
}
4343

4444
class FunctionProvider: ChatGPTFunctionProvider {
45-
var functionCallStrategy: FunctionCallStrategy? = .name("respond")
45+
var functionCallStrategy: FunctionCallStrategy? = .function(name: "respond")
4646
var functions: [any ChatGPTFunction] = [RespondFunction()]
4747
}
4848

@@ -153,7 +153,7 @@ public final class RefineDocumentChain: Chain {
153153
}
154154

155155
func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer {
156-
if let functionCall = chatMessage.functionCall {
156+
for functionCall in chatMessage.toolCalls?.map(\.function) ?? [] {
157157
do {
158158
let intermediateAnswer = try JSONDecoder().decode(
159159
IntermediateAnswer.self,

Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public final class RelevantInformationExtractionChain: Chain {
1515
public typealias Output = String
1616

1717
class FunctionProvider: ChatGPTFunctionProvider {
18-
var functionCallStrategy: FunctionCallStrategy? = .name("saveFinalAnswer")
18+
var functionCallStrategy: FunctionCallStrategy? = .function(name: "saveFinalAnswer")
1919
var functions: [any ChatGPTFunction] = [FinalAnswer()]
2020
}
2121

@@ -103,8 +103,10 @@ public final class RelevantInformationExtractionChain: Chain {
103103
taskInput,
104104
callbackManagers: callbackManagers
105105
)
106-
107-
if let functionCall = output.functionCall {
106+
107+
if let functionCall = output.toolCalls?
108+
.first(where: { $0.function.name == FinalAnswer().name })?.function
109+
{
108110
do {
109111
let arguments = try JSONDecoder().decode(
110112
FinalAnswer.Arguments.self,
@@ -118,7 +120,7 @@ public final class RelevantInformationExtractionChain: Chain {
118120
return output.content ?? ""
119121
}
120122
}
121-
123+
122124
return output.content ?? ""
123125
}
124126

Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
5353
}
5454

5555
var functionCallStrategy: FunctionCallStrategy? {
56-
.name(endFunction.name)
56+
.function(name: endFunction.name)
5757
}
5858
}
5959

@@ -108,7 +108,7 @@ public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
108108
}
109109

110110
public func parseOutput(_ message: ChatMessage) async -> Output? {
111-
if let functionCall = message.functionCall {
111+
if let functionCall = message.toolCalls?.first?.function {
112112
do {
113113
let result = try JSONDecoder().decode(
114114
EndFunction.Arguments.self,

0 commit comments

Comments
 (0)