Skip to content

Commit a707645

Browse files
committed
Support Azure OpenAI as chat provider
1 parent 182d426 commit a707645

8 files changed

Lines changed: 186 additions & 22 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import AppKit
2+
import Client
3+
import OpenAIService
4+
import Preferences
5+
import SuggestionModel
6+
import SwiftUI
7+
8+
final class AzureViewSettings: ObservableObject {
9+
@AppStorage(\.azureOpenAIAPIKey) var azureOpenAIAPIKey: String
10+
@AppStorage(\.azureOpenAIBaseURL) var azureOpenAIBaseURL: String
11+
@AppStorage(\.azureChatGPTDeployment) var azureChatGPTDeployment: String
12+
init() {}
13+
}
14+
15+
struct AzureView: View {
16+
@Environment(\.toast) var toast
17+
@State var isTesting = false
18+
@StateObject var settings = AzureViewSettings()
19+
20+
var body: some View {
21+
Form {
22+
SecureField(text: $settings.azureOpenAIAPIKey, prompt: Text("")) {
23+
Text("OpenAI Service API Key")
24+
}
25+
.textFieldStyle(.roundedBorder)
26+
27+
TextField(
28+
text: $settings.azureOpenAIBaseURL,
29+
prompt: Text("https://XXXXXX.openai.azure.com")
30+
) {
31+
Text("OpenAI Service Base URL")
32+
}.textFieldStyle(.roundedBorder)
33+
34+
HStack {
35+
TextField(
36+
text: $settings.azureChatGPTDeployment,
37+
prompt: Text("")
38+
) {
39+
Text("Chat Model Deployment Name")
40+
}.textFieldStyle(.roundedBorder)
41+
42+
Button("Test") {
43+
Task { @MainActor in
44+
isTesting = true
45+
defer { isTesting = false }
46+
do {
47+
let reply = try await ChatGPTService(designatedProvider: .azureOpenAI)
48+
.sendAndWait(content: "Hello", summary: nil)
49+
toast(Text("ChatGPT replied: \(reply ?? "N/A")"), .info)
50+
} catch {
51+
toast(Text(error.localizedDescription), .error)
52+
}
53+
}
54+
}
55+
.disabled(isTesting)
56+
}
57+
}
58+
}
59+
}
60+
61+
struct AzureView_Previews: PreviewProvider {
62+
static var previews: some View {
63+
VStack(alignment: .leading, spacing: 8) {
64+
AzureView()
65+
}
66+
.frame(height: 800)
67+
.padding(.all, 8)
68+
}
69+
}
70+

Core/Sources/HostApp/AccountSettings/OpenAIView.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct OpenAIView: View {
5252
Button("Test") {
5353
Task {
5454
do {
55-
let reply = try await ChatGPTService()
55+
let reply = try await ChatGPTService(designatedProvider: .openAI)
5656
.sendAndWait(content: "Hello", summary: nil)
5757
toast(Text("ChatGPT replied: \(reply ?? "N/A")"), .info)
5858
} catch {

Core/Sources/HostApp/ServiceView.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ struct ServiceView: View {
3030
subtitle: "Chat, Prompt to Code",
3131
image: "globe"
3232
)
33+
34+
ScrollView {
35+
AzureView().padding()
36+
}.sidebarItem(
37+
tag: 3,
38+
title: "Azure",
39+
subtitle: "Chat, Prompt to Code",
40+
image: "globe"
41+
)
3342
}
3443
}
3544
}

Core/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@ public protocol ChatGPTServiceType: ObservableObject {
1717
public enum ChatGPTServiceError: Error, LocalizedError {
1818
case endpointIncorrect
1919
case responseInvalid
20+
case otherError(String)
2021

2122
public var errorDescription: String? {
2223
switch self {
2324
case .endpointIncorrect:
2425
return "ChatGPT endpoint is incorrect"
2526
case .responseInvalid:
2627
return "Response is invalid"
28+
case let .otherError(content):
29+
return content
2730
}
2831
}
2932
}
@@ -59,7 +62,7 @@ public actor ChatGPTService: ChatGPTServiceType {
5962
public var defaultTemperature: Double {
6063
min(max(0, UserDefaults.shared.value(for: \.chatGPTTemperature)), 2)
6164
}
62-
65+
6366
var temperature: Double?
6467

6568
public var model: String {
@@ -68,14 +71,30 @@ public actor ChatGPTService: ChatGPTServiceType {
6871
return value
6972
}
7073

74+
var designatedProvider: ChatFeatureProvider?
75+
7176
public var endpoint: String {
72-
var baseURL = UserDefaults.shared.value(for: \.openAIBaseURL)
73-
if baseURL.isEmpty { return "https://api.openai.com/v1/chat/completions" }
74-
return "\(baseURL)/v1/chat/completions"
77+
switch designatedProvider ?? UserDefaults.shared.value(for: \.chatFeatureProvider) {
78+
case .openAI:
79+
let baseURL = UserDefaults.shared.value(for: \.openAIBaseURL)
80+
if baseURL.isEmpty { return "https://api.openai.com/v1/chat/completions" }
81+
return "\(baseURL)/v1/chat/completions"
82+
case .azureOpenAI:
83+
let baseURL = UserDefaults.shared.value(for: \.azureOpenAIBaseURL)
84+
let deployment = UserDefaults.shared.value(for: \.azureChatGPTDeployment)
85+
let version = "2023-05-15"
86+
if baseURL.isEmpty { return "" }
87+
return "\(baseURL)/openai/deployments/\(deployment)/chat/completions?api-version=\(version)"
88+
}
7589
}
7690

7791
public var apiKey: String {
78-
UserDefaults.shared.value(for: \.openAIAPIKey)
92+
switch designatedProvider ?? UserDefaults.shared.value(for: \.chatFeatureProvider) {
93+
case .openAI:
94+
return UserDefaults.shared.value(for: \.openAIAPIKey)
95+
case .azureOpenAI:
96+
return UserDefaults.shared.value(for: \.azureOpenAIAPIKey)
97+
}
7998
}
8099

81100
public var maxToken: Int {
@@ -97,10 +116,12 @@ public actor ChatGPTService: ChatGPTServiceType {
97116

98117
public init(
99118
systemPrompt: String = "",
100-
temperature: Double? = nil
119+
temperature: Double? = nil,
120+
designatedProvider: ChatFeatureProvider? = nil
101121
) {
102122
self.systemPrompt = systemPrompt
103123
self.temperature = temperature
124+
self.designatedProvider = designatedProvider
104125
}
105126

106127
public func send(
@@ -129,7 +150,12 @@ public actor ChatGPTService: ChatGPTServiceType {
129150

130151
isReceivingMessage = true
131152

132-
let api = buildCompletionStreamAPI(apiKey, url, requestBody)
153+
let api = buildCompletionStreamAPI(
154+
apiKey,
155+
designatedProvider ?? UserDefaults.shared.value(for: \.chatFeatureProvider),
156+
url,
157+
requestBody
158+
)
133159

134160
return AsyncThrowingStream<String, Error> { continuation in
135161
Task {
@@ -210,7 +236,12 @@ public actor ChatGPTService: ChatGPTServiceType {
210236
isReceivingMessage = true
211237
defer { isReceivingMessage = false }
212238

213-
let api = buildCompletionAPI(apiKey, url, requestBody)
239+
let api = buildCompletionAPI(
240+
apiKey,
241+
designatedProvider ?? UserDefaults.shared.value(for: \.chatFeatureProvider),
242+
url,
243+
requestBody
244+
)
214245
let response = try await api()
215246

216247
if let choice = response.choices.first {
@@ -296,3 +327,4 @@ func maxTokenForReply(model: String, remainingTokens: Int) -> Int {
296327
guard let model = ChatGPTModel(rawValue: model) else { return remainingTokens }
297328
return min(model.maxToken / 2, remainingTokens)
298329
}
330+

Core/Sources/OpenAIService/CompletionAPI.swift

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import Foundation
2+
import Preferences
23

3-
typealias CompletionAPIBuilder = (String, URL, CompletionRequestBody) -> CompletionAPI
4+
typealias CompletionAPIBuilder = (String, ChatFeatureProvider, URL, CompletionRequestBody)
5+
-> CompletionAPI
46

57
protocol CompletionAPI {
68
func callAsFunction() async throws -> CompletionResponseBody
@@ -12,13 +14,13 @@ struct CompletionResponseBody: Codable, Equatable {
1214
var role: ChatMessage.Role
1315
var content: String
1416
}
15-
17+
1618
struct Choice: Codable, Equatable {
1719
var message: Message
1820
var index: Int
1921
var finish_reason: String
2022
}
21-
23+
2224
struct Usage: Codable, Equatable {
2325
var prompt_tokens: Int
2426
var completion_tokens: Int
@@ -40,21 +42,29 @@ struct CompletionAPIError: Error, Codable, LocalizedError {
4042
var param: String
4143
var code: String
4244
}
45+
4346
var error: E
44-
47+
4548
var errorDescription: String? { error.message }
4649
}
4750

4851
struct OpenAICompletionAPI: CompletionAPI {
4952
var apiKey: String
5053
var endpoint: URL
5154
var requestBody: CompletionRequestBody
55+
var provider: ChatFeatureProvider
5256

53-
init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
57+
init(
58+
apiKey: String,
59+
provider: ChatFeatureProvider,
60+
endpoint: URL,
61+
requestBody: CompletionRequestBody
62+
) {
5463
self.apiKey = apiKey
5564
self.endpoint = endpoint
5665
self.requestBody = requestBody
5766
self.requestBody.stream = false
67+
self.provider = provider
5868
}
5969

6070
func callAsFunction() async throws -> CompletionResponseBody {
@@ -64,7 +74,11 @@ struct OpenAICompletionAPI: CompletionAPI {
6474
request.httpBody = try encoder.encode(requestBody)
6575
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
6676
if !apiKey.isEmpty {
67-
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
77+
if provider == .openAI {
78+
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
79+
} else {
80+
request.setValue(apiKey, forHTTPHeaderField: "api-key")
81+
}
6882
}
6983

7084
let (result, response) = try await URLSession.shared.data(for: request)
@@ -74,9 +88,11 @@ struct OpenAICompletionAPI: CompletionAPI {
7488

7589
guard response.statusCode == 200 else {
7690
let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
77-
throw error ?? ChatGPTServiceError.responseInvalid
91+
throw error ?? ChatGPTServiceError
92+
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
7893
}
79-
94+
8095
return try JSONDecoder().decode(CompletionResponseBody.self, from: result)
8196
}
8297
}
98+

Core/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import AsyncAlgorithms
22
import Foundation
3+
import Preferences
34

4-
typealias CompletionStreamAPIBuilder = (String, URL, CompletionRequestBody) -> CompletionStreamAPI
5+
typealias CompletionStreamAPIBuilder = (String, ChatFeatureProvider, URL, CompletionRequestBody) -> CompletionStreamAPI
56

67
protocol CompletionStreamAPI {
78
func callAsFunction() async throws -> (
@@ -54,12 +55,19 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
5455
var apiKey: String
5556
var endpoint: URL
5657
var requestBody: CompletionRequestBody
58+
var provider: ChatFeatureProvider
5759

58-
init(apiKey: String, endpoint: URL, requestBody: CompletionRequestBody) {
60+
init(
61+
apiKey: String,
62+
provider: ChatFeatureProvider,
63+
endpoint: URL,
64+
requestBody: CompletionRequestBody
65+
) {
5966
self.apiKey = apiKey
6067
self.endpoint = endpoint
6168
self.requestBody = requestBody
6269
self.requestBody.stream = true
70+
self.provider = provider
6371
}
6472

6573
func callAsFunction() async throws -> (
@@ -72,7 +80,11 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
7280
request.httpBody = try encoder.encode(requestBody)
7381
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
7482
if !apiKey.isEmpty {
75-
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
83+
if provider == .openAI {
84+
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
85+
} else {
86+
request.setValue(apiKey, forHTTPHeaderField: "api-key")
87+
}
7688
}
7789

7890
let (result, response) = try await URLSession.shared.bytes(for: request)
@@ -90,9 +102,9 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
90102
let error = try? decoder.decode(ChatGPTError.self, from: data)
91103
throw error ?? ChatGPTServiceError.responseInvalid
92104
}
93-
105+
94106
var receivingDataTask: Task<Void, Error>?
95-
107+
96108
let stream = AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
97109
receivingDataTask = Task {
98110
do {
@@ -122,3 +134,4 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
122134
)
123135
}
124136
}
137+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
public enum ChatFeatureProvider: String, CaseIterable {
2+
case openAI
3+
case azureOpenAI
4+
}

Core/Sources/Preferences/Keys.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,22 @@ public extension UserDefaultPreferenceKeys {
9999
}
100100
}
101101

102+
// MARK: - Azure OpenAI Settings
103+
104+
public extension UserDefaultPreferenceKeys {
105+
var azureOpenAIAPIKey: PreferenceKey<String> {
106+
.init(defaultValue: "", key: "AzureOpenAIAPIKey")
107+
}
108+
109+
var azureOpenAIBaseURL: PreferenceKey<String> {
110+
.init(defaultValue: "", key: "AzureOpenAIBaseURL")
111+
}
112+
113+
var azureChatGPTDeployment: PreferenceKey<String> {
114+
.init(defaultValue: "", key: "AzureChatGPTDeployment")
115+
}
116+
}
117+
102118
// MARK: - GitHub Copilot Settings
103119

104120
public extension UserDefaultPreferenceKeys {
@@ -186,6 +202,10 @@ public extension UserDefaultPreferenceKeys {
186202
// MARK: - Chat
187203

188204
public extension UserDefaultPreferenceKeys {
205+
var chatFeatureProvider: PreferenceKey<ChatFeatureProvider> {
206+
.init(defaultValue: .openAI, key: "ChatFeatureProvider")
207+
}
208+
189209
var chatFontSize: PreferenceKey<Double> {
190210
.init(defaultValue: 12, key: "ChatFontSize")
191211
}

0 commit comments

Comments
 (0)