diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 3a334d0d5..7be11d2dd 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -3,7 +3,7 @@ module Main where import Control.Logger.Simple import Simplex.Messaging.Server.CLI (getEnvPath) import Simplex.Messaging.Server.Main (smpServerCLI_) -import Simplex.Messaging.Server.Web (serveStaticFiles, attachStaticFiles) +import Simplex.Messaging.Server.Web (serveStaticFiles, attachStaticAndWS) import SMPWeb (smpGenerateSite) defaultCfgPath :: FilePath @@ -19,4 +19,4 @@ main :: IO () main = do cfgPath <- getEnvPath "SMP_SERVER_CFG_PATH" defaultCfgPath logPath <- getEnvPath "SMP_SERVER_LOG_PATH" defaultLogPath - withGlobalLogging logCfg $ smpServerCLI_ smpGenerateSite serveStaticFiles attachStaticFiles cfgPath logPath + withGlobalLogging logCfg $ smpServerCLI_ smpGenerateSite serveStaticFiles attachStaticAndWS cfgPath logPath diff --git a/simplexmq.cabal b/simplexmq.cabal index c13fe8f5f..8dabd9208 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -354,6 +354,7 @@ library , temporary ==1.3.* , wai >=3.2 && <3.3 , wai-app-static >=3.1 && <3.2 + , wai-websockets >=3.0.1 && <3.1 , warp ==3.3.30 , warp-tls ==3.4.7 , websockets ==0.12.* diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index ec75a07d4..8a8bf7a14 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -40,6 +40,7 @@ module Simplex.Messaging.Server dummyVerifyCmd, randomId, AttachHTTP, + WSHandler, MessageStats (..), ) where @@ -121,6 +122,7 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Buffer (trimCR) import Simplex.Messaging.Transport.Server +import Simplex.Messaging.Transport.WebSockets (WS (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Environment (lookupEnv) @@ -160,7 +162,8 @@ runSMPServerBlocking :: MsgStoreClass s => TMVar Bool -> ServerConfig s -> Maybe runSMPServerBlocking started cfg attachHTTP_ = newEnv cfg >>= runReaderT (smpServer started cfg attachHTTP_) type M s a = ReaderT (Env s) IO a -type AttachHTTP = Socket -> TLS.Context -> IO () +type AttachHTTP = Socket -> TLS 'TServer -> Maybe WSHandler -> IO () +type WSHandler = WS 'TServer -> IO () -- actions used in serverThread to reduce STM transaction scope data ClientSubAction @@ -211,10 +214,11 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt (Just httpCreds, Just attachHTTP) | addHTTP -> runTransportServerState_ ss started tcpPort defaultSupportedParamsHTTPS combinedCreds tCfg $ \s (sniUsed, h) -> case cast h of - Just (TLS {tlsContext} :: TLS 'TServer) | sniUsed -> labelMyThread "https client" >> attachHTTP s tlsContext + Just (tls :: TLS 'TServer) | sniUsed -> labelMyThread "https client" >> attachHTTP s tls wsHandler _ -> runClient srvCert srvSignKey t h `runReaderT` env where combinedCreds = TLSServerCredential {credential = smpCreds, sniCredential = Just httpCreds} + wsHandler = Just $ \ws -> runClient srvCert srvSignKey (TProxy :: TProxy WS 'TServer) ws `runReaderT` env _ -> runTransportServerState ss started tcpPort defaultSupportedParams smpCreds tCfg $ \h -> runClient srvCert srvSignKey t h `runReaderT` env diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 92f0b0821..22c114ddd 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -106,7 +106,7 @@ import System.Directory (renameFile) #endif smpServerCLI :: FilePath -> FilePath -> IO () -smpServerCLI = smpServerCLI_ (\_ _ _ -> pure ()) (\_ -> pure ()) (\_ -> error "attachStaticFiles not available") +smpServerCLI = smpServerCLI_ (\_ _ _ -> pure ()) (\_ -> pure ()) (\_ -> error "attachStaticAndWS not available") smpServerCLI_ :: (ServerInformation -> Maybe TransportHost -> FilePath -> IO ()) -> @@ -115,7 +115,7 @@ smpServerCLI_ :: FilePath -> FilePath -> IO () -smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = +smpServerCLI_ generateSite serveStaticFiles attachStaticAndWS cfgPath logPath = getCliCommand' (cliCommandP cfgPath logPath iniFile) serverVersion >>= \case Init opts -> doesFileExist iniFile >>= \case @@ -489,7 +489,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = case webStaticPath' of Just path | sharedHTTP -> do runWebServer path Nothing ServerInformation {config, information} - attachStaticFiles path $ \attachHTTP -> do + attachStaticAndWS path $ \attachHTTP -> do logDebug "Allocated web server resources" runSMPServer cfg (Just attachHTTP) `finally` logDebug "Releasing web server resources..." Just path -> do diff --git a/src/Simplex/Messaging/Server/Web.hs b/src/Simplex/Messaging/Server/Web.hs index 7044a7e39..c1eabedf8 100644 --- a/src/Simplex/Messaging/Server/Web.hs +++ b/src/Simplex/Messaging/Server/Web.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} @@ -8,7 +9,7 @@ module Simplex.Messaging.Server.Web WebHttpsParams (..), EmbeddedContent (..), serveStaticFiles, - attachStaticFiles, + attachStaticAndWS, serveStaticPageH2, generateSite, serverInfoSubsts, @@ -41,11 +42,14 @@ import qualified Network.Wai.Application.Static as S import qualified Network.Wai.Handler.Warp as W import qualified Network.Wai.Handler.Warp.Internal as WI import qualified Network.Wai.Handler.WarpTLS as WT +import qualified Network.Wai.Handler.WebSockets as WaiWS +import Network.WebSockets (defaultConnectionOptions, ConnectionOptions(..), SizeLimit(..), PendingConnection) import Simplex.Messaging.Encoding.String (strEncode) -import Simplex.Messaging.Server (AttachHTTP) +import Simplex.Messaging.Server (AttachHTTP, WSHandler) import Simplex.Messaging.Server.CLI (simplexmqCommit) import Simplex.Messaging.Server.Information -import Simplex.Messaging.Transport (simplexMQVersion) +import Simplex.Messaging.Transport (TLS (..), smpBlockSize, simplexMQVersion) +import Simplex.Messaging.Transport.WebSockets (WS (..), acceptWSConnection) import Simplex.Messaging.Util (tshow) import System.Directory (canonicalizePath, createDirectoryIfMissing, doesFileExist) import System.FilePath @@ -84,20 +88,23 @@ serveStaticFiles EmbeddedWebParams {webStaticPath, webHttpPort, webHttpsParams} where mkSettings port = W.setPort port warpSettings --- | Prepare context and prepare HTTP handler for TLS connections that already passed TLS.handshake and ALPN check. -attachStaticFiles :: FilePath -> (AttachHTTP -> IO ()) -> IO () -attachStaticFiles path action = do - app <- staticFiles path - -- Initialize global internal state for http server. +attachStaticAndWS :: FilePath -> (AttachHTTP -> IO a) -> IO a +attachStaticAndWS path action = WI.withII warpSettings $ \ii -> do - action $ \socket cxt -> do - -- Initialize internal per-connection resources. + action $ \socket tls wsHandler_ -> do + app <- case wsHandler_ of + Just wsHandler -> + WaiWS.websocketsOr wsOpts (acceptWSConnection tls >=> wsHandler) <$> staticFiles path + Nothing -> staticFiles path addr <- getPeerName socket - withConnection addr cxt $ \(conn, transport) -> + withConnection addr (tlsContext tls) $ \(conn, transport) -> withTimeout ii conn $ \th -> - -- Run Warp connection handler to process HTTP requests for static files. WI.serveConnection conn ii th addr transport warpSettings app where + wsOpts = defaultConnectionOptions + { connectionFramePayloadSizeLimit = SizeLimit $ fromIntegral smpBlockSize, + connectionMessageDataSizeLimit = SizeLimit 65536 + } -- from warp-tls withConnection socket cxt = bracket (WT.attachConn socket cxt) (terminate . fst) -- from warp @@ -105,7 +112,6 @@ attachStaticFiles path action = do bracket (WI.registerKillThread (WI.timeoutManager ii) (WI.connClose conn)) WI.cancel - -- shared clean up terminate conn = WI.connClose conn `finally` (readIORef (WI.connWriteBuffer conn) >>= WI.bufFree) warpSettings :: W.Settings diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index 3ab213dcd..e1b6c1711 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -7,7 +7,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -module Simplex.Messaging.Transport.WebSockets (WS (..)) where +module Simplex.Messaging.Transport.WebSockets (WS (..), acceptWSConnection) where import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) @@ -20,6 +20,7 @@ import Network.WebSockets.Stream (Stream) import qualified Network.WebSockets.Stream as S import Simplex.Messaging.Transport ( ALPN, + TLS (TLS, tlsContext, tlsPeerCert, tlsTransportConfig), Transport (..), TransportConfig (..), TransportError (..), @@ -101,6 +102,15 @@ getWS cfg wsCertSent wsPeerCert cxt = withTlsUnique @WS @p cxt connectWS acceptClientRequest s = makePendingConnectionFromStream s websocketsOpts >>= acceptRequest sendClientRequest s = newClientConnection s "" "/" websocketsOpts [] +acceptWSConnection :: TLS 'TServer -> PendingConnection -> IO (WS 'TServer) +acceptWSConnection tls pending = withTlsUnique @WS @'TServer cxt $ \wsUniq -> do + wsStream <- makeTLSContextStream cxt + wsConnection <- acceptRequest pending + wsALPN <- T.getNegotiatedProtocol cxt + pure WS {tlsUniq = wsUniq, wsALPN, wsStream, wsConnection, wsTransportConfig = tlsTransportConfig tls, wsCertSent = False, wsPeerCert = tlsPeerCert tls} + where + cxt = tlsContext tls + makeTLSContextStream :: T.Context -> IO Stream makeTLSContextStream cxt = S.makeStream readStream writeStream diff --git a/tests/CLITests.hs b/tests/CLITests.hs index 66af74ab8..5489877ad 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -31,7 +31,7 @@ import qualified Simplex.Messaging.Transport.HTTP2.Client as HC import Simplex.Messaging.Transport.Server (loadFileFingerprint) import Simplex.Messaging.Util (catchAll_) import qualified SMPWeb -import Simplex.Messaging.Server.Web (serveStaticFiles, attachStaticFiles) +import Simplex.Messaging.Server.Web (serveStaticFiles, attachStaticAndWS) import System.Directory (doesFileExist) import System.Environment (withArgs) import System.FilePath (()) @@ -152,7 +152,7 @@ smpServerTestStatic = do Right ini_ <- readIniFile iniFile lookupValue "WEB" "https" ini_ `shouldBe` Right "5223" - let smpServerCLI' = smpServerCLI_ SMPWeb.smpGenerateSite serveStaticFiles attachStaticFiles + let smpServerCLI' = smpServerCLI_ SMPWeb.smpGenerateSite serveStaticFiles attachStaticAndWS let server = capture_ (withArgs ["start"] $ smpServerCLI' cfgPath logPath `catchAny` print) bracket (async server) cancel $ \_t -> do threadDelay 1000000 diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index c51079d5e..0a65944d2 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -26,13 +26,15 @@ import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClie import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Protocol -import Simplex.Messaging.Server (runSMPServerBlocking) +import Simplex.Messaging.Server (runSMPServerBlocking, AttachHTTP) import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), SMSType (..), SQSType (..)) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) +import Data.X509.Validation (Fingerprint (..)) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client -import Simplex.Messaging.Transport.Server +import Simplex.Messaging.Transport.Server (ServerCredentials (..), TransportServerConfig (..), loadFileFingerprint, loadFingerprint, loadServerCredential, mkTransportServerConfig) +import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (ifM) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal @@ -155,7 +157,8 @@ testSMPClientVR vr client = do testSMPClient_ :: Transport c => TransportHost -> ServiceName -> VersionRangeSMP -> (THandleSMP c 'TClient -> IO a) -> IO a testSMPClient_ host port vr client = do - let tcConfig = defaultTransportClientConfig {clientALPN} :: TransportClientConfig + -- SMP clients use useSNI = False (matches defaultSMPClientConfig) + let tcConfig = defaultTransportClientConfig {clientALPN, useSNI = False} :: TransportClientConfig runTransportClient tcConfig Nothing host port (Just testKeyHash) $ \h -> runExceptT (smpClientHandshake h Nothing testKeyHash vr False Nothing) >>= \case Right th -> client th @@ -283,6 +286,16 @@ serverStoreConfig_ useDbStoreLog = \case dbStoreLogPath = if useDbStoreLog then Just testStoreLogFile else Nothing storeCfg = PostgresStoreCfg {dbOpts = testStoreDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = 86400} +cfgWebOn :: AStoreType -> ServiceName -> AServerConfig +cfgWebOn msType port' = updateCfg (cfgMS msType) $ \cfg' -> + cfg' { transports = [(port', transport @TLS, True)], + httpCredentials = Just ServerCredentials + { caCertificateFile = Nothing, + privateKeyFile = "tests/fixtures/web.key", + certificateFile = "tests/fixtures/web.crt" + } + } + cfgV7 :: AServerConfig cfgV7 = updateCfg cfg $ \cfg' -> cfg' {smpServerVRange = mkVersionRange minServerSMPRelayVersion authCmdsSMPVersion} @@ -333,9 +346,12 @@ withServerCfg :: AServerConfig -> (forall s. ServerConfig s -> a) -> a withServerCfg (ASrvCfg _ _ cfg') f = f cfg' withSmpServerConfigOn :: HasCallStack => ASrvTransport -> AServerConfig -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a -withSmpServerConfigOn t (ASrvCfg _ _ cfg') port' = +withSmpServerConfigOn t cfg' port' = withSmpServerConfig (updateCfg cfg' $ \c -> c {transports = [(port', t, False)]}) Nothing + +withSmpServerConfig :: HasCallStack => AServerConfig -> Maybe AttachHTTP -> (HasCallStack => ThreadId -> IO a) -> IO a +withSmpServerConfig (ASrvCfg _ _ cfg') attachHTTP_ = serverBracket - (\started -> runSMPServerBlocking started cfg' {transports = [(port', t, False)]} Nothing) + (\started -> runSMPServerBlocking started cfg' attachHTTP_) (threadDelay 10000) withSmpServerThreadOn :: HasCallStack => (ASrvTransport, AStoreType) -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index b2c2d997c..22f4af798 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -23,6 +23,7 @@ import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Exception (SomeException, throwIO, try) import Control.Monad +import Control.Monad.Except (runExceptT) import Control.Monad.IO.Class import CoreTests.MsgStoreTests (testJournalStoreCfg) import Data.Bifunctor (first) @@ -42,6 +43,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Protocol +import Simplex.Messaging.Client (chooseTransportHost, defaultNetworkConfig) import Simplex.Messaging.Server (exportMessages) import Simplex.Messaging.Server.Env.STM (AStoreType (..), MsgStore (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore) import Simplex.Messaging.Server.Expiration @@ -50,6 +52,11 @@ import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), QSType (..), import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..)) import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..), closeStoreLog) import Simplex.Messaging.Transport +import Simplex.Messaging.Transport.Client (TransportClientConfig (..), defaultTransportClientConfig, runTLSTransportClient) +import Simplex.Messaging.Transport.WebSockets (WS) +import Simplex.Messaging.Transport.Server (ServerCredentials (..), loadFileFingerprint) +import Simplex.Messaging.Server.Web (attachStaticAndWS) +import Data.X509.Validation (Fingerprint (..)) import Simplex.Messaging.Util (whenM) import Simplex.Messaging.Version (mkVersionRange) import System.Directory (doesDirectoryExist, doesFileExist, removeDirectoryRecursive, removeFile) @@ -101,6 +108,7 @@ serverTests = do describe "Short links" $ do testInvQueueLinkData testContactQueueLinkData + describe "WebSocket and TLS on same port" testWebSocketAndTLS pattern Resp :: CorrId -> QueueId -> BrokerMsg -> Transmission (Either ErrorType BrokerMsg) pattern Resp corrId queueId command <- (corrId, queueId, Right command) @@ -1484,3 +1492,41 @@ serverSyntaxTests (ATransport t) = do (Maybe TAuthorizations, ByteString, ByteString, BrokerMsg) -> Expectation command >#> response = withFrozenCallStack $ smpServerTest t command `shouldReturn` response + +-- | Test that both native TLS and WebSocket clients can connect to the same port. +-- Native TLS uses useSNI=False, WebSocket uses useSNI=True for routing. +testWebSocketAndTLS :: SpecWith (ASrvTransport, AStoreType) +testWebSocketAndTLS = + it "native TLS and WebSocket clients work on same port" $ \(_t, msType) -> do + Fingerprint fpHTTP <- loadFileFingerprint "tests/fixtures/web_ca.crt" + let httpKeyHash = C.KeyHash fpHTTP + attachStaticAndWS "tests/fixtures" $ \attachHTTP -> + withSmpServerConfig (cfgWebOn msType testPort) (Just attachHTTP) $ \_ -> do + g <- C.newRandom + (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (dhPub, dhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g + + -- Connect via native TLS (useSNI=False, default) and create a queue + (sId, rId, srvDh) <- testSMPClient @TLS $ \rh -> do + Resp "1" _ (Ids rId sId srvDh) <- signSendRecv rh rKey ("1", NoEntity, New rPub dhPub) + Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, KEY sPub) + pure (sId, rId, srvDh) + let dec = decryptMsgV3 $ C.dh' srvDh dhPriv + + -- Connect via WebSocket (useSNI=True) and send a message + Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost + let wsTcConfig = defaultTransportClientConfig {useSNI = True} :: TransportClientConfig + runTLSTransportClient defaultSupportedParamsHTTPS Nothing wsTcConfig Nothing useHost testPort (Just httpKeyHash) $ \(h :: WS 'TClient) -> + runExceptT (smpClientHandshake h Nothing testKeyHash supportedClientSMPRelayVRange False Nothing) >>= \case + Right sh -> do + Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, _SEND "hello from websocket") + pure () + Left e -> error $ show e + + -- Verify message received via native TLS + testSMPClient @TLS $ \rh -> do + (Resp "4" _ (SOK Nothing), Resp "" _ (Msg mId msg)) <- signSendRecv2 rh rKey ("4", rId, SUB) + dec mId msg `shouldBe` Right "hello from websocket" + Resp "5" _ OK <- signSendRecv rh rKey ("5", rId, ACK mId) + pure ()