Skip to content

Commit ad8f6ec

Browse files
committed
Fix modification agent
1 parent e19dfc1 commit ad8f6ec

5 files changed

Lines changed: 587 additions & 22 deletions

File tree

Core/Sources/PromptToCodeService/OpenAIPromptToCodeService.swift

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ public final class SimpleModificationAgent: ModificationAgent {
2626
generateDescriptionRequirement: false
2727
)
2828

29-
for try await (code, description) in stream {
30-
continuation.yield(.code(code))
29+
for try await response in stream {
30+
continuation.yield(response)
3131
}
3232

3333
continuation.finish()
@@ -51,7 +51,7 @@ public final class SimpleModificationAgent: ModificationAgent {
5151
isDetached: Bool,
5252
extraSystemPrompt: String?,
5353
generateDescriptionRequirement: Bool?
54-
) async throws -> AsyncThrowingStream<(code: String, description: String), Error> {
54+
) async throws -> AsyncThrowingStream<Response, Error> {
5555
let userPreferredLanguage = UserDefaults.shared.value(for: \.chatGPTLanguage)
5656
let textLanguage = {
5757
if !UserDefaults.shared
@@ -226,32 +226,38 @@ public final class SimpleModificationAgent: ModificationAgent {
226226
history.append(.init(role: .user, content: requirement))
227227
}
228228
}
229-
let stream = chatGPTService.send(memory).compactMap { response in
230-
switch response {
231-
case let .partialText(token): return token
232-
default: return nil
233-
}
234-
}.eraseToThrowingStream()
235-
229+
let stream = chatGPTService.send(memory)
230+
236231
return .init { continuation in
237-
Task {
238-
var content = ""
239-
var extracted = extractCodeAndDescription(from: content)
232+
let task = Task {
233+
let parser = ExplanationThenCodeStreamParser()
240234
do {
241-
for try await fragment in stream {
242-
content.append(fragment)
243-
extracted = extractCodeAndDescription(from: content)
244-
if !content.isEmpty, extracted.code.isEmpty {
245-
continuation.yield((code: content, description: ""))
246-
} else {
247-
continuation.yield(extracted)
235+
func yield(fragments: [ExplanationThenCodeStreamParser.Fragment]) {
236+
for fragment in fragments {
237+
switch fragment {
238+
case let .code(code):
239+
continuation.yield(.code(code))
240+
case let .explanation(explanation):
241+
continuation.yield(.explanation(explanation))
242+
}
248243
}
249244
}
245+
246+
for try await response in stream {
247+
guard case let .partialText(fragment) = response else { continue }
248+
try Task.checkCancellation()
249+
await yield(fragments: parser.yield(fragment))
250+
}
251+
await yield(fragments: parser.finish())
250252
continuation.finish()
251253
} catch {
252254
continuation.finish(throwing: error)
253255
}
254256
}
257+
258+
continuation.onTermination = { _ in
259+
task.cancel()
260+
}
255261
}
256262
}
257263
}

Core/Sources/SuggestionWidget/FeatureReducers/PromptToCodePanel.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,8 @@ public struct PromptToCodeSnippetPanel {
324324
return .none
325325

326326
case let .modifyCodeChunkReceived(code, description):
327-
state.snippet.modifiedCode = code
328-
state.snippet.description = description
327+
state.snippet.modifiedCode += code
328+
state.snippet.description += description
329329
return .none
330330

331331
case let .modifyCodeFailed(error):

Tool/Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ let package = Package(
231231
),
232232
]
233233
),
234+
.testTarget(name: "ModificationBasicTests", dependencies: ["ModificationBasic"]),
234235

235236
.target(
236237
name: "PromptToCodeCustomization",
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
import Foundation
2+
3+
/// Parse a stream that contains explanation followed by a code block.
4+
public actor ExplanationThenCodeStreamParser {
5+
enum State {
6+
case explanation
7+
case code
8+
case codeOpening
9+
case codeClosing
10+
}
11+
12+
public enum Fragment: Sendable {
13+
case explanation(String)
14+
case code(String)
15+
}
16+
17+
struct Buffer {
18+
var content: String = ""
19+
}
20+
21+
var _buffer: Buffer = .init()
22+
var isAtBeginning = true
23+
var buffer: String { _buffer.content }
24+
var state: State = .explanation
25+
let fullCodeDelimiter = "```"
26+
27+
public init() {}
28+
29+
private func appendBuffer(_ character: Character) {
30+
_buffer.content.append(character)
31+
}
32+
33+
private func appendBuffer(_ content: String) {
34+
_buffer.content += content
35+
}
36+
37+
private func resetBuffer() {
38+
_buffer.content = ""
39+
}
40+
41+
func flushBuffer() -> String? {
42+
if buffer.isEmpty { return nil }
43+
guard let targetIndex = _buffer.content.lastIndex(where: { $0 != "`" && !$0.isNewline })
44+
else { return nil }
45+
let prefix = _buffer.content[...targetIndex]
46+
if prefix.isEmpty { return nil }
47+
let nextIndex = _buffer.content.index(
48+
targetIndex,
49+
offsetBy: 1,
50+
limitedBy: _buffer.content.endIndex
51+
) ?? _buffer.content.endIndex
52+
53+
if nextIndex == _buffer.content.endIndex {
54+
_buffer.content = ""
55+
} else {
56+
_buffer.content = String(
57+
_buffer.content[nextIndex...]
58+
)
59+
}
60+
61+
// If we flushed something, we are no longer at the beginning
62+
isAtBeginning = false
63+
return String(prefix)
64+
}
65+
66+
func flushBufferIfNeeded(into results: inout [Fragment]) {
67+
switch state {
68+
case .explanation:
69+
if let flushed = flushBuffer() {
70+
results.append(.explanation(flushed))
71+
}
72+
case .code:
73+
if let flushed = flushBuffer() {
74+
results.append(.code(flushed))
75+
}
76+
case .codeOpening, .codeClosing:
77+
break
78+
}
79+
}
80+
81+
public func yield(_ fragment: String) -> [Fragment] {
82+
var results: [Fragment] = []
83+
84+
func flushBuffer() {
85+
flushBufferIfNeeded(into: &results)
86+
}
87+
88+
for character in fragment {
89+
switch state {
90+
case .explanation:
91+
func forceFlush() {
92+
if !buffer.isEmpty {
93+
isAtBeginning = false
94+
results.append(.explanation(buffer))
95+
resetBuffer()
96+
}
97+
}
98+
99+
switch character {
100+
case "`":
101+
if let last = buffer.last, last == "`" || last.isNewline {
102+
flushBuffer()
103+
// if we are seeing the pattern of "\n`" or "``"
104+
// that mean we may be hitting a code delimiter
105+
appendBuffer(character)
106+
let shouldOpenCodeBlock: Bool = {
107+
guard buffer.hasSuffix(fullCodeDelimiter)
108+
else { return false }
109+
if isAtBeginning { return true }
110+
let temp = String(buffer.dropLast(fullCodeDelimiter.count))
111+
if let last = temp.last, last.isNewline {
112+
return true
113+
}
114+
return false
115+
}()
116+
// if we meet a code delimiter while in explanation state,
117+
// it means we are opening a code block
118+
if shouldOpenCodeBlock {
119+
results.append(.explanation(
120+
String(buffer.dropLast(fullCodeDelimiter.count))
121+
.trimmingTrailingCharacters(in: .whitespacesAndNewlines)
122+
))
123+
resetBuffer()
124+
state = .codeOpening
125+
}
126+
} else {
127+
// Otherwise, the backtick is probably part of the explanation.
128+
forceFlush()
129+
appendBuffer(character)
130+
}
131+
case let char where char.isNewline:
132+
// we keep the trailing new lines in case they are right
133+
// ahead of the code block that should be ignored.
134+
if let last = buffer.last, last.isNewline {
135+
flushBuffer()
136+
appendBuffer(character)
137+
} else {
138+
forceFlush()
139+
appendBuffer(character)
140+
}
141+
default:
142+
appendBuffer(character)
143+
}
144+
case .code:
145+
func forceFlush() {
146+
if !buffer.isEmpty {
147+
isAtBeginning = false
148+
results.append(.code(buffer))
149+
resetBuffer()
150+
}
151+
}
152+
153+
switch character {
154+
case "`":
155+
if let last = buffer.last, last == "`" || last.isNewline {
156+
flushBuffer()
157+
// if we are seeing the pattern of "\n`" or "``"
158+
// that mean we may be hitting a code delimiter
159+
appendBuffer(character)
160+
let possibleClosingDelimiter: String? = {
161+
guard buffer.hasSuffix(fullCodeDelimiter) else { return nil }
162+
let temp = String(buffer.dropLast(fullCodeDelimiter.count))
163+
if let last = temp.last, last.isNewline {
164+
return "\(last)\(fullCodeDelimiter)"
165+
}
166+
return nil
167+
}()
168+
// if we meet a code delimiter while in code state,
169+
// // it means we are closing the code block
170+
if let possibleClosingDelimiter {
171+
results.append(.code(
172+
String(buffer.dropLast(possibleClosingDelimiter.count))
173+
))
174+
resetBuffer()
175+
appendBuffer(possibleClosingDelimiter)
176+
state = .codeClosing
177+
}
178+
} else {
179+
// Otherwise, the backtick is probably part of the code.
180+
forceFlush()
181+
appendBuffer(character)
182+
}
183+
184+
case let char where char.isNewline:
185+
if let last = buffer.last, last.isNewline {
186+
flushBuffer()
187+
appendBuffer(character)
188+
} else {
189+
forceFlush()
190+
appendBuffer(character)
191+
}
192+
default:
193+
appendBuffer(character)
194+
}
195+
case .codeOpening:
196+
// skip the code block fence
197+
if character.isNewline {
198+
state = .code
199+
}
200+
case .codeClosing:
201+
appendBuffer(character)
202+
switch character {
203+
case "`":
204+
let possibleClosingDelimiter: String? = {
205+
guard buffer.hasSuffix(fullCodeDelimiter) else { return nil }
206+
let temp = String(buffer.dropLast(fullCodeDelimiter.count))
207+
if let last = temp.last, last.isNewline {
208+
return "\(last)\(fullCodeDelimiter)"
209+
}
210+
return nil
211+
}()
212+
// if we meet another code delimiter while in codeClosing state,
213+
// it means the previous code delimiter was part of the code
214+
if let possibleClosingDelimiter {
215+
results.append(.code(
216+
String(buffer.dropLast(possibleClosingDelimiter.count))
217+
))
218+
resetBuffer()
219+
appendBuffer(possibleClosingDelimiter)
220+
}
221+
default:
222+
break
223+
}
224+
}
225+
}
226+
227+
flushBuffer()
228+
229+
return results
230+
}
231+
232+
public func finish() -> [Fragment] {
233+
guard !buffer.isEmpty else { return [] }
234+
235+
var results: [Fragment] = []
236+
switch state {
237+
case .explanation:
238+
results.append(
239+
.explanation(buffer.trimmingTrailingCharacters(in: .whitespacesAndNewlines))
240+
)
241+
case .code:
242+
results.append(.code(buffer))
243+
case .codeClosing:
244+
break
245+
case .codeOpening:
246+
break
247+
}
248+
resetBuffer()
249+
250+
return results
251+
}
252+
}
253+
254+
extension String {
255+
func trimmingTrailingCharacters(in characterSet: CharacterSet) -> String {
256+
guard !isEmpty else {
257+
return ""
258+
}
259+
var unicodeScalars = unicodeScalars
260+
while let scalar = unicodeScalars.last {
261+
if !characterSet.contains(scalar) {
262+
return String(unicodeScalars)
263+
}
264+
unicodeScalars.removeLast()
265+
}
266+
return ""
267+
}
268+
}
269+

0 commit comments

Comments
 (0)