diff --git a/Package.resolved b/Package.resolved index 837d776..7ed371b 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,13 +1,22 @@ { - "originHash" : "f7b86b800200fa069a2b288e06bafe53bc937a1851b6effeebba326a62be227e", + "originHash" : "f2f0ba1d1b9625bd5147b2fbd7b82236dac35ee1baa399fcf5c76b22fd428bb8", "pins" : [ + { + "identity" : "async-http-client", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/async-http-client.git", + "state" : { + "revision" : "2fc4652fb4689eb24af10e55cabaa61d8ba774fd", + "version" : "1.32.0" + } + }, { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -20,39 +29,57 @@ } }, { - "identity" : "llama.swift", + "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", + "location" : "https://github.com/mattt/PartialJSONDecoder.git", "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" + "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", + "version" : "1.0.0" } }, { - "identity" : "mlx-swift", + "identity" : "swift-algorithms", "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", + "location" : "https://github.com/apple/swift-algorithms.git", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023", + "version" : "1.2.1" } }, { - "identity" : "mlx-swift-lm", + "identity" : "swift-asn1", "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift-lm", + "location" : "https://github.com/apple/swift-asn1.git", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" } }, { - "identity" : "partialjsondecoder", + "identity" : "swift-async-algorithms", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/PartialJSONDecoder.git", + "location" : "https://github.com/apple/swift-async-algorithms.git", "state" : { - "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", - "version" : "1.0.0" + "revision" : "9d349bcc328ac3c31ce40e746b5882742a0d1272", + "version" : "1.1.3" + } + }, + { + "identity" : "swift-atomics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-certificates", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-certificates.git", + "state" : { + "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", + "version" : "1.18.0" } }, { @@ -65,23 +92,131 @@ } }, { - "identity" : "swift-jinja", + "identity" : "swift-configuration", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-configuration.git", + "state" : { + "revision" : "be76c4ad929eb6c4bcaf3351799f2adf9e6848a9", + "version" : "1.2.0" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, + { + "identity" : "swift-distributed-tracing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-distributed-tracing.git", + "state" : { + "revision" : "e109d8b5308d0e05201d9a1dd1c475446a946a11", + "version" : "1.4.0" + } + }, + { + "identity" : "swift-http-structured-headers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-structured-headers.git", + "state" : { + "revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b", + "version" : "1.6.0" + } + }, + { + "identity" : "swift-http-types", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-types.git", + "state" : { + "revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "bbd81b6725ae874c69e9b8c8804d462356b55523", + "version" : "1.10.1" + } + }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "e932d3c4d8f77433c8f7093b5ebcbf91463948a0", + "version" : "2.95.0" + } + }, + { + "identity" : "swift-nio-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-extras.git", + "state" : { + "revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b", + "version" : "1.32.1" + } + }, + { + "identity" : "swift-nio-http2", "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-jinja.git", + "location" : "https://github.com/apple/swift-nio-http2.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "b6571f3db40799df5a7fc0e92c399aa71c883edd", + "version" : "1.40.0" + } + }, + { + "identity" : "swift-nio-ssl", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-ssl.git", + "state" : { + "revision" : "173cc69a058623525a58ae6710e2f5727c663793", + "version" : "2.36.0" + } + }, + { + "identity" : "swift-nio-transport-services", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-transport-services.git", + "state" : { + "revision" : "60c3e187154421171721c1a38e800b390680fb5d", + "version" : "1.26.0" } }, { "identity" : "swift-numerics", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", + "location" : "https://github.com/apple/swift-numerics.git", "state" : { "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", "version" : "1.1.1" } }, + { + "identity" : "swift-service-context", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-service-context.git", + "state" : { + "revision" : "d0997351b0c7779017f88e7a93bc30a1878d7f29", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-service-lifecycle", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/swift-service-lifecycle", + "state" : { + "revision" : "89888196dd79c61c50bca9a103d8114f32e1e598", + "version" : "2.10.1" + } + }, { "identity" : "swift-syntax", "kind" : "remoteSourceControl", @@ -92,12 +227,12 @@ } }, { - "identity" : "swift-transformers", + "identity" : "swift-system", "kind" : "remoteSourceControl", - "location" : "https://github.com/huggingface/swift-transformers", + "location" : "https://github.com/apple/swift-system", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df", + "version" : "1.6.4" } } ], diff --git a/Package.swift b/Package.swift index 3916bf0..b8d6962 100644 --- a/Package.swift +++ b/Package.swift @@ -25,17 +25,22 @@ let package = Package( .trait(name: "CoreML"), .trait(name: "MLX"), .trait(name: "Llama"), + .trait(name: "AsyncHTTPClient"), .default(enabledTraits: []), ], dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"), - .package(url: "https://github.com/mattt/EventSource", from: "1.3.0"), + .package(url: "https://github.com/mattt/EventSource", from: "1.3.0", traits: [ + .defaults, + .trait(name: "AsyncHTTPClient", condition: .when(traits: ["AsyncHTTPClient"])) + ]), .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), + .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.24.0"), ], targets: [ .target( @@ -70,6 +75,11 @@ let package = Package( package: "llama.swift", condition: .when(traits: ["Llama"]) ), + .product( + name: "AsyncHTTPClient", + package: "async-http-client", + condition: .when(traits: ["AsyncHTTPClient"]) + ), ] ), .macro( @@ -83,7 +93,14 @@ let package = Package( ), .testTarget( name: "AnyLanguageModelTests", - dependencies: ["AnyLanguageModel"] + dependencies: [ + "AnyLanguageModel", + .product( + name: "AsyncHTTPClient", + package: "async-http-client", + condition: .when(traits: ["AsyncHTTPClient"]) + ), + ], ), ] ) diff --git a/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift new file mode 100644 index 0000000..1089608 --- /dev/null +++ b/Sources/AnyLanguageModel/Extensions/HTTPClient+Extensions.swift @@ -0,0 +1,213 @@ +#if canImport(AsyncHTTPClient) +// AsyncHTTPClient.HTTPHandler introduces a Task type that clashes +typealias SwiftTask = Task + +import AsyncHTTPClient +import EventSource +import Foundation +#if canImport(FoundationNetworking) +import FoundationNetworking +#endif +import JSONSchema +import NIOCore +import NIOHTTP1 +import NIOFoundationCompat + +extension HTTPClient { + func fetch( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) async throws -> T { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(180)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy + + do { + return try decoder.decode(T.self, from: bodyData) + } catch { + throw HTTPClientError.decodingError(detail: error.localizedDescription) + } + } + + func fetchStream( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil, + dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .deferredToDate + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = SwiftTask { @Sendable in + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = dateDecodingStrategy + + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "application/json") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(60)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + var buffer = Data() + + for try await chunk in response.body { + buffer.append(contentsOf: chunk.readableBytesView) + + while let newlineIndex = buffer.firstIndex(of: UInt8(ascii: "\n")) { + let line = buffer[..( + _ method: HTTP.Method, + url: URL, + headers: [String: String] = [:], + body: Data? = nil + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = SwiftTask { @Sendable in + do { + var request = HTTPClientRequest(url: url.absoluteString) + request.method = HTTPMethod(rawValue: method.rawValue) + request.headers.add(name: "Accept", value: "text/event-stream") + + for (key, value) in headers { + request.headers.add(name: key, value: value) + } + + if let body { + request.body = .bytes(ByteBuffer(data: body)) + request.headers.add(name: "Content-Type", value: "application/json") + } + + let response = try await self.execute(request, timeout: .seconds(60)) + + guard (200 ..< 300).contains(response.status.code) else { + let bodyData = try await Data(buffer: response.body.collect(upTo: 1024 * 1024)) + if let errorString = String(data: bodyData, encoding: .utf8) { + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: errorString) + } + throw HTTPClientError.httpError(statusCode: Int(response.status.code), detail: "Invalid response") + } + + let asyncBytes = AsyncStream { byteContinuation in + SwiftTask { + do { + for try await buffer in response.body { + for byte in buffer.readableBytesView { + byteContinuation.yield(byte) + } + } + byteContinuation.finish() + } catch { + byteContinuation.finish() + } + } + } + + try await self.decodeAndYieldEventStream(asyncBytes, to: continuation) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + + continuation.onTermination = { _ in + task.cancel() + } + } + } + + private func decodeAndYieldEventStream( + _ asyncBytes: Bytes, + to continuation: AsyncThrowingStream.Continuation + ) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 { + let decoder = JSONDecoder() + for try await event in asyncBytes.events { + guard let data = event.data.data(using: .utf8) else { continue } + if let decoded = try? decoder.decode(T.self, from: data) { + continuation.yield(decoded) + } + } + } +} + +enum HTTPClientError: Error, CustomStringConvertible { + case invalidResponse + case httpError(statusCode: Int, detail: String) + case decodingError(detail: String) + + var description: String { + switch self { + case .invalidResponse: + return "Invalid response" + case .httpError(let statusCode, let detail): + return "HTTP error (Status \(statusCode)): \(detail)" + case .decodingError(let detail): + return "Decoding error: \(detail)" + } + } +} +#endif diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index bd21a1f..851c555 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -278,7 +278,7 @@ public struct AnthropicLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Anthropic language model. /// @@ -295,7 +295,7 @@ public struct AnthropicLanguageModel: LanguageModel { apiVersion: String = defaultAPIVersion, betas: [String]? = nil, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index 2da15a4..e2230ab 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -186,7 +186,7 @@ public struct GeminiLanguageModel: LanguageModel { /// Internal storage for the deprecated serverTools property. internal var _serverTools: [CustomGenerationOptions.ServerTool] - private let urlSession: URLSession + private let urlSession: SessionType /// Creates a new Gemini language model. /// @@ -201,7 +201,7 @@ public struct GeminiLanguageModel: LanguageModel { apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, apiVersion: String = defaultAPIVersion, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { @@ -243,7 +243,7 @@ public struct GeminiLanguageModel: LanguageModel { model: String, thinking: CustomGenerationOptions.Thinking = .disabled, serverTools: [CustomGenerationOptions.ServerTool] = [], - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 6be5d02..1e49856 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -46,7 +46,7 @@ public struct OllamaLanguageModel: LanguageModel { /// The model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Ollama language model. /// @@ -57,7 +57,7 @@ public struct OllamaLanguageModel: LanguageModel { public init( baseURL: URL = defaultBaseURL, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index db4eab3..a1b6b0e 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -393,7 +393,7 @@ public struct OpenAILanguageModel: LanguageModel { /// The API variant to use. public let apiVariant: APIVariant - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an OpenAI language model. /// @@ -408,7 +408,7 @@ public struct OpenAILanguageModel: LanguageModel { apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, apiVariant: APIVariant = .chatCompletions, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift index c4ba51e..c123e80 100644 --- a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift @@ -365,7 +365,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { /// Model identifier to use for generation. public let model: String - private let urlSession: URLSession + private let urlSession: SessionType /// Creates an Open Responses language model. /// @@ -378,7 +378,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { baseURL: URL, apiKey tokenProvider: @escaping @autoclosure @Sendable () -> String, model: String, - session: URLSession = URLSession(configuration: .default) + session: SessionType = makeDefaultSession(), ) { var baseURL = baseURL if !baseURL.path.hasSuffix("/") { diff --git a/Sources/AnyLanguageModel/Transport.swift b/Sources/AnyLanguageModel/Transport.swift new file mode 100644 index 0000000..a8d8186 --- /dev/null +++ b/Sources/AnyLanguageModel/Transport.swift @@ -0,0 +1,20 @@ +#if canImport(AsyncHTTPClient) + import AsyncHTTPClient + + public typealias SessionType = HTTPClient + + public func makeDefaultSession() -> SessionType { + return HTTPClient.shared + } +#else + import Foundation + #if canImport(FoundationNetworking) + import FoundationNetworking + #endif + + public typealias SessionType = URLSession + + public func makeDefaultSession() -> SessionType { + return URLSession(configuration: .default) + } +#endif