Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ public final class LanguageModelSession: @unchecked Sendable {
) {
self.model = model
self.tools = tools
self.instructions = instructions
let resolvedInstructions = instructions ?? Self.instructions(from: transcript)
self.instructions = resolvedInstructions

// Build transcript with instructions if provided and not already in transcript
var finalTranscript = transcript
if let instructions = instructions {
if let instructions = resolvedInstructions {
// Only add instructions if transcript doesn't already start with instructions
let hasInstructions =
finalTranscript.first.map { entry in
Expand All @@ -97,6 +98,16 @@ public final class LanguageModelSession: @unchecked Sendable {
self.state = .init(.init(finalTranscript))
}

private static func instructions(from transcript: Transcript) -> Instructions? {
guard case .instructions(let instructions)? = transcript.first else { return nil }
guard instructions.segments.count == 1,
case .text(let text) = instructions.segments[0]
else {
return nil
}
return Instructions(content: text.content)
}

public func prewarm(promptPrefix: Prompt? = nil) {
model.prewarm(for: self, promptPrefix: promptPrefix)
}
Expand Down
18 changes: 18 additions & 0 deletions Tests/AnyLanguageModelTests/TranscriptTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ struct TranscriptTests {
#expect(Transcript.Segment.image(image).id == "image-id")
}

@Test func sessionRestoresInstructionsFromTranscript() throws {
let instructions = "First\n\nSecond trailing spaces "
let transcript = Transcript(entries: [
.instructions(.init(
id: "instructions-id",
segments: [.text(.init(content: instructions))],
toolDefinitions: []
)),
.prompt(.init(segments: [.text(.init(content: "Hello"))]))
])

let session = LanguageModelSession(model: MockLanguageModel(), transcript: transcript)

#expect(session.instructions?.description == instructions)
#expect(session.transcript.count == transcript.count)
#expect(session.transcript.first?.id == "instructions-id")
}

@Test func imageSourceRoundTripsForDataAndURL() throws {
let encoder = JSONEncoder()
let decoder = JSONDecoder()
Expand Down