diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550..5ac3337 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -69,11 +69,12 @@ public final class LanguageModelSession: @unchecked Sendable { ) { self.model = model self.tools = tools - self.instructions = instructions + let resolvedInstructions = instructions ?? Self.instructions(from: transcript) + self.instructions = resolvedInstructions // Build transcript with instructions if provided and not already in transcript var finalTranscript = transcript - if let instructions = instructions { + if let instructions = resolvedInstructions { // Only add instructions if transcript doesn't already start with instructions let hasInstructions = finalTranscript.first.map { entry in @@ -97,6 +98,16 @@ public final class LanguageModelSession: @unchecked Sendable { self.state = .init(.init(finalTranscript)) } + private static func instructions(from transcript: Transcript) -> Instructions? { + guard case .instructions(let instructions)? = transcript.first else { return nil } + guard instructions.segments.count == 1, + case .text(let text) = instructions.segments[0] + else { + return nil + } + return Instructions(content: text.content) + } + public func prewarm(promptPrefix: Prompt? = nil) { model.prewarm(for: self, promptPrefix: promptPrefix) } diff --git a/Tests/AnyLanguageModelTests/TranscriptTests.swift b/Tests/AnyLanguageModelTests/TranscriptTests.swift index cb3ac94..60e9122 100644 --- a/Tests/AnyLanguageModelTests/TranscriptTests.swift +++ b/Tests/AnyLanguageModelTests/TranscriptTests.swift @@ -55,6 +55,24 @@ struct TranscriptTests { #expect(Transcript.Segment.image(image).id == "image-id") } + @Test func sessionRestoresInstructionsFromTranscript() throws { + let instructions = "First\n\nSecond trailing spaces " + let transcript = Transcript(entries: [ + .instructions(.init( + id: "instructions-id", + segments: [.text(.init(content: instructions))], + toolDefinitions: [] + )), + .prompt(.init(segments: [.text(.init(content: "Hello"))])) + ]) + + let session = LanguageModelSession(model: MockLanguageModel(), transcript: transcript) + + #expect(session.instructions?.description == instructions) + #expect(session.transcript.count == transcript.count) + #expect(session.transcript.first?.id == "instructions-id") + } + @Test func imageSourceRoundTripsForDataAndURL() throws { let encoder = JSONEncoder() let decoder = JSONDecoder()