Skip to content

Commit 204b4f3

Browse files
committed
Fix importing large packages like langchain
1 parent 878cda6 commit 204b4f3

File tree

7 files changed

+223
-70
lines changed

7 files changed

+223
-70
lines changed
Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import Foundation
22
import LangChain
33
import PythonKit
4+
import PythonHelper
45

56
func solveMathProblem(_ problem: String) async throws -> String {
67
#if DEBUG
@@ -15,20 +16,16 @@ func solveMathProblem(_ problem: String) async throws -> String {
1516
}
1617
}
1718

18-
let task = Task {
19-
try withReadableThrowingPython {
20-
let langchain = try Python.attemptImport("langchain")
21-
let LLMMathChain = langchain.LLMMathChain
22-
let llm = try LangChainChatModel.DynamicChatOpenAI(temperature: 0)
23-
let llmMath = LLMMathChain.from_llm(llm, verbose: verbose)
24-
let result = try llmMath.run.throwing.dynamicallyCall(withArguments: problem)
25-
let answer = String(result)
26-
if let answer { return answer }
19+
return try await runPython {
20+
let langchain = try Python.attemptImportOnPythonThread("langchain")
21+
let LLMMathChain = langchain.LLMMathChain
22+
let llm = try LangChainChatModel.DynamicChatOpenAI(temperature: 0)
23+
let llmMath = LLMMathChain.from_llm(llm, verbose: verbose)
24+
let result = try llmMath.run.throwing.dynamicallyCall(withArguments: problem)
25+
let answer = String(result)
26+
if let answer { return answer }
2727

28-
throw E()
29-
}
28+
throw E()
3029
}
31-
32-
return try await task.value
3330
}
3431

ExtensionService/InitializePython.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import Foundation
22
import Python
33
import PythonKit
4+
import PythonHelper
45

56
func initializePython() {
67
guard let sitePackagePath = Bundle.main.path(forResource: "site-packages", ofType: nil),
@@ -10,8 +11,32 @@ func initializePython() {
1011
ofType: nil
1112
)
1213
else { return }
14+
1315
setenv("PYTHONHOME", stdLibPath, 1)
1416
setenv("PYTHONPATH", "\(stdLibPath):\(libDynloadPath):\(sitePackagePath)", 1)
17+
18+
// Initialize python
1519
Py_Initialize()
20+
21+
// Immediately release the thread, so that we can ensure the GIL state later.
22+
// We may not recover the thread because all future tasks will be done in the Python Thread.
23+
let _ = PyEval_SaveThread()
24+
25+
// Setup GIL state guard.
26+
PythonHelper.PyGILState_Guard = { closure in
27+
let gilState = PyGILState_Ensure()
28+
try closure()
29+
PyGILState_Release(gilState)
30+
}
31+
32+
Task {
33+
// All future task should run inside runPython.
34+
try await runPython {
35+
let sys = Python.import("sys")
36+
print("Python Version: \(sys.version_info.major).\(sys.version_info.minor)")
37+
}
38+
}
1639
}
1740

41+
let queue = DispatchQueue(label: "")
42+

Tool/Package.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ let package = Package(
88
platforms: [.macOS(.v12)],
99
products: [
1010
.library(name: "Terminal", targets: ["Terminal"]),
11-
.library(name: "LangChain", targets: ["LangChain"]),
11+
.library(name: "LangChain", targets: ["LangChain", "PythonHelper"]),
1212
.library(name: "Preferences", targets: ["Preferences", "Configs"]),
1313
],
1414
dependencies: [
@@ -27,6 +27,14 @@ let package = Package(
2727

2828
.target(
2929
name: "LangChain",
30+
dependencies: [
31+
"PythonHelper",
32+
.product(name: "PythonKit", package: "PythonKit"),
33+
]
34+
),
35+
36+
.target(
37+
name: "PythonHelper",
3038
dependencies: [
3139
.product(name: "PythonKit", package: "PythonKit"),
3240
]

Tool/Sources/LangChain/LangChainChatModel.swift

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Foundation
22
import Preferences
3+
import PythonHelper
34
import PythonKit
45

56
public enum LangChainChatModel {
@@ -11,32 +12,28 @@ public enum LangChainChatModel {
1112
let model = UserDefaults.shared.value(for: \.chatGPTModel)
1213
let apiBaseURL = UserDefaults.shared.value(for: \.openAIBaseURL)
1314
let apiKey = UserDefaults.shared.value(for: \.openAIAPIKey)
14-
return try withReadableThrowingPython {
15-
let chatModels = try Python.attemptImport("langchain.chat_models")
16-
let ChatOpenAI = chatModels.ChatOpenAI
17-
return ChatOpenAI(
18-
temperature: temperature,
19-
model: model,
20-
openai_api_base: "\(apiBaseURL)/v1",
21-
openai_api_key: apiKey
22-
)
23-
}
15+
let chatModels = try Python.attemptImportOnPythonThread("langchain.chat_models")
16+
let ChatOpenAI = chatModels.ChatOpenAI
17+
return ChatOpenAI(
18+
temperature: temperature,
19+
model: model,
20+
openai_api_base: "\(apiBaseURL)/v1",
21+
openai_api_key: apiKey
22+
)
2423
case .azureOpenAI:
2524
let apiBaseURL = UserDefaults.shared.value(for: \.azureOpenAIBaseURL)
2625
let apiKey = UserDefaults.shared.value(for: \.azureOpenAIAPIKey)
2726
let deployment = UserDefaults.shared.value(for: \.azureChatGPTDeployment)
28-
return try withReadableThrowingPython {
29-
let chatModels = try Python.attemptImport("langchain.chat_models")
30-
let ChatOpenAI = chatModels.AzureChatOpenAI
31-
return ChatOpenAI(
32-
temperature: temperature,
33-
openai_api_type: "azure",
34-
openai_api_version: "2023-03-15-preview",
35-
deployment_name: deployment,
36-
openai_api_base: apiBaseURL,
37-
openai_api_key: apiKey
38-
)
39-
}
27+
let chatModels = try Python.attemptImportOnPythonThread("langchain.chat_models")
28+
let ChatOpenAI = chatModels.AzureChatOpenAI
29+
return ChatOpenAI(
30+
temperature: temperature,
31+
openai_api_type: "azure",
32+
openai_api_version: "2023-03-15-preview",
33+
deployment_name: deployment,
34+
openai_api_base: apiBaseURL,
35+
openai_api_key: apiKey
36+
)
4037
}
4138
}
4239
}

Tool/Sources/LangChain/ReadablePythonError.swift

Lines changed: 0 additions & 34 deletions
This file was deleted.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import Foundation
2+
import PythonKit
3+
4+
final class PythonThread: Thread {
5+
static let shared = {
6+
let thread = PythonThread(
7+
target: PythonThread.self,
8+
selector: #selector(PythonThread.setup),
9+
object: nil
10+
)
11+
thread.name = "Python Thread"
12+
thread.stackSize = 1_048_576 // so that langchain can be correctly imported.
13+
return thread
14+
}()
15+
16+
@objc static func setup() {
17+
CFRunLoopRun()
18+
}
19+
20+
@objc static func runPythonJob(_ job: PythonJob) {
21+
job.run()
22+
}
23+
24+
func runPython(_ closure: @escaping () -> Void) {
25+
if !isExecuting {
26+
start()
27+
}
28+
29+
if Thread.current.isEqual(self) {
30+
closure()
31+
} else {
32+
PythonThread.perform(
33+
#selector(PythonThread.runPythonJob),
34+
on: self,
35+
with: PythonJob(closure: closure),
36+
waitUntilDone: false
37+
)
38+
}
39+
}
40+
41+
func runPythonAndWait<T>(_ closure: @escaping () throws -> T) throws -> T {
42+
if !isExecuting {
43+
start()
44+
}
45+
46+
if Thread.current.isEqual(self) {
47+
return try closure()
48+
} else {
49+
let job = PythonJob(closure: closure)
50+
PythonThread.perform(
51+
#selector(PythonThread.runPythonJob),
52+
on: self,
53+
with: job,
54+
waitUntilDone: true
55+
)
56+
guard let result = job.result else {
57+
throw FailedToGetPythonJobResultError()
58+
}
59+
switch result {
60+
case let .success(value):
61+
if let value = value as? T {
62+
return value
63+
} else {
64+
throw FailedToGetPythonJobResultError()
65+
}
66+
case let .failure(error):
67+
throw error
68+
}
69+
}
70+
}
71+
}
72+
73+
struct FailedToGetPythonJobResultError: Error, LocalizedError {
74+
var errorDescription: String? {
75+
"Failed to get PythonJob result."
76+
}
77+
}
78+
79+
final class PythonJob: NSObject {
80+
let closure: () throws -> Any
81+
var result: Result<Any, Error>?
82+
init(closure: @escaping () throws -> Any) {
83+
self.closure = closure
84+
}
85+
86+
func run() {
87+
result = Result(catching: closure)
88+
}
89+
}
90+
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import Foundation
2+
import PythonKit
3+
4+
public var PyGILState_Guard: ((() throws -> Void) throws -> Void)! = nil
5+
6+
let pythonQueue = DispatchQueue(label: "Python Queue")
7+
8+
public func runPython<T>(
9+
usePythonThread: Bool = false,
10+
_ closure: @escaping () throws -> T
11+
) async throws -> T {
12+
return try await withUnsafeThrowingContinuation { con in
13+
if usePythonThread {
14+
PythonThread.shared.runPython {
15+
do {
16+
try PyGILState_Guard {
17+
con.resume(returning: try closure())
18+
}
19+
} catch let error as PythonError {
20+
con.resume(throwing: ReadablePythonError(error))
21+
} catch {
22+
con.resume(throwing: error)
23+
}
24+
}
25+
} else {
26+
pythonQueue.async {
27+
do {
28+
try PyGILState_Guard {
29+
con.resume(returning: try closure())
30+
}
31+
} catch let error as PythonError {
32+
con.resume(throwing: ReadablePythonError(error))
33+
} catch {
34+
con.resume(throwing: error)
35+
}
36+
}
37+
}
38+
}
39+
}
40+
41+
public extension PythonInterface {
42+
func attemptImportOnPythonThread(_ name: String) throws -> PythonObject {
43+
try PythonThread.shared.runPythonAndWait {
44+
let module = try Python.attemptImport(name)
45+
return module
46+
}
47+
}
48+
}
49+
50+
public struct ReadablePythonError: Error, LocalizedError {
51+
public var error: PythonError
52+
53+
public init(_ error: PythonError) {
54+
self.error = error
55+
}
56+
57+
public var errorDescription: String? {
58+
switch error {
59+
case let .exception(object, _):
60+
return "\(object)"
61+
case let .invalidCall(object):
62+
return "Invalid call: \(object)"
63+
case let .invalidModule(module):
64+
return "Invalid module: \(module)"
65+
}
66+
}
67+
}
68+
69+
70+

0 commit comments

Comments
 (0)