Skip to content

Commit 2917628

Browse files
committed
Add dimensions settings to embedding models
1 parent 0c277c1 commit 2917628

3 files changed

Lines changed: 85 additions & 2 deletions

File tree

Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct EmbeddingModelEdit {
1515
var name: String
1616
var format: EmbeddingModel.Format
1717
var maxTokens: Int = 8191
18+
var dimensions: Int = 1536
1819
var modelName: String = ""
1920
var ollamaKeepAlive: String = ""
2021
var apiKeyName: String { apiKeySelection.apiKeyName }
@@ -38,6 +39,7 @@ struct EmbeddingModelEdit {
3839
case testButtonClicked
3940
case testSucceeded(String)
4041
case testFailed(String)
42+
case fixDimensions(Int)
4143
case checkSuggestedMaxTokens
4244
case apiKeySelection(APIKeySelection.Action)
4345
case baseURLSelection(BaseURLSelection.Action)
@@ -80,6 +82,7 @@ struct EmbeddingModelEdit {
8082
case .testButtonClicked:
8183
guard !state.isTesting else { return .none }
8284
state.isTesting = true
85+
let dimensions = state.dimensions
8386
let model = EmbeddingModel(
8487
id: state.id,
8588
name: state.name,
@@ -89,18 +92,33 @@ struct EmbeddingModelEdit {
8992
baseURL: state.baseURL,
9093
isFullURL: state.isFullURL,
9194
maxTokens: state.maxTokens,
95+
dimensions: dimensions,
9296
modelName: state.modelName
9397
)
9498
)
9599
return .run { send in
96100
do {
97-
_ = try await EmbeddingService(
101+
let result = try await EmbeddingService(
98102
configuration: UserPreferenceEmbeddingConfiguration()
99103
.overriding {
100104
$0.model = model
101105
}
102106
).embed(text: "Hello")
103-
await send(.testSucceeded("Succeeded!"))
107+
if result.data.isEmpty {
108+
await send(.testFailed("No data returned"))
109+
return
110+
}
111+
let actualDimensions = result.data.first?.embedding.count ?? 0
112+
if actualDimensions != dimensions {
113+
await send(
114+
.testFailed("Invalid dimension, should be \(actualDimensions)")
115+
)
116+
await send(.fixDimensions(actualDimensions))
117+
} else {
118+
await send(
119+
.testSucceeded("Succeeded! (Dimensions: \(actualDimensions))")
120+
)
121+
}
104122
} catch {
105123
await send(.testFailed(error.localizedDescription))
106124
}
@@ -131,6 +149,11 @@ struct EmbeddingModelEdit {
131149
return .none
132150
}
133151
state.suggestedMaxTokens = knownModel.maxToken
152+
state.dimensions = knownModel.dimensions
153+
return .none
154+
155+
case let .fixDimensions(value):
156+
state.dimensions = value
134157
return .none
135158

136159
case .apiKeySelection:

Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,51 @@ struct EmbeddingModelEditView: View {
174174
}
175175
}
176176
}
177+
178+
struct DimensionsTextField: View {
179+
@Perception.Bindable var store: StoreOf<EmbeddingModelEdit>
180+
181+
var body: some View {
182+
WithPerceptionTracking {
183+
HStack {
184+
let textFieldBinding = Binding(
185+
get: { String(store.dimensions) },
186+
set: {
187+
if let selectionDimensions = Int($0) {
188+
$store.dimensions.wrappedValue = selectionDimensions
189+
} else {
190+
$store.dimensions.wrappedValue = 0
191+
}
192+
}
193+
)
194+
195+
TextField(text: textFieldBinding) {
196+
Text("Dimensions")
197+
.multilineTextAlignment(.trailing)
198+
}
199+
.overlay(alignment: .trailing) {
200+
Stepper(
201+
value: $store.dimensions,
202+
in: 0...Int.max,
203+
step: 100
204+
) {
205+
EmptyView()
206+
}
207+
}
208+
.foregroundColor({
209+
if store.dimensions <= 0 {
210+
return .red
211+
}
212+
return .primary
213+
}() as Color)
214+
}
215+
216+
Text("If you are not sure, run test to get the correct value.")
217+
.font(.caption)
218+
.dynamicHeightTextInFormWorkaround()
219+
}
220+
}
221+
}
177222

178223
struct ApiKeyNamePicker: View {
179224
let store: StoreOf<EmbeddingModelEdit>
@@ -215,6 +260,7 @@ struct EmbeddingModelEditView: View {
215260
}
216261

217262
MaxTokensTextField(store: store)
263+
DimensionsTextField(store: store)
218264

219265
VStack(alignment: .leading, spacing: 8) {
220266
Text(Image(systemName: "exclamationmark.triangle.fill")) + Text(
@@ -242,6 +288,7 @@ struct EmbeddingModelEditView: View {
242288
TextField("Deployment Name", text: $store.modelName)
243289

244290
MaxTokensTextField(store: store)
291+
DimensionsTextField(store: store)
245292
}
246293
}
247294
}
@@ -279,6 +326,7 @@ struct EmbeddingModelEditView: View {
279326
TextField("Model Name", text: $store.modelName)
280327

281328
MaxTokensTextField(store: store)
329+
DimensionsTextField(store: store)
282330

283331
Button("Custom Headers") {
284332
isEditingCustomHeader.toggle()
@@ -299,6 +347,7 @@ struct EmbeddingModelEditView: View {
299347
TextField("Model Name", text: $store.modelName)
300348

301349
MaxTokensTextField(store: store)
350+
DimensionsTextField(store: store)
302351

303352
WithPerceptionTracking {
304353
TextField(text: $store.ollamaKeepAlive, prompt: Text("Default Value")) {

Tool/Sources/Preferences/Types/OpenAIEmbeddingModel.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,16 @@ public extension OpenAIEmbeddingModel {
1515
return 8191
1616
}
1717
}
18+
19+
var dimensions: Int {
20+
switch self {
21+
case .textEmbeddingAda002:
22+
return 1536
23+
case .textEmbedding3Small:
24+
return 1536
25+
case .textEmbedding3Large:
26+
return 3072
27+
}
28+
}
1829
}
1930

0 commit comments

Comments
 (0)