Skip to content

Commit aa14c78

Browse files
committed
use laziness for session test kits to support cases where configuration needs to be loaded dynamically (e.g. if dependent on dynamic ports for containers)
1 parent 0a859c0 commit aa14c78

File tree

1 file changed

+216
-7
lines changed

1 file changed

+216
-7
lines changed

server/testkit/src/main/scala/akka/http/scaladsl/testkit/PersistenceScalatestRouteTest.scala

Lines changed: 216 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ trait PersistenceScalatestRouteTest
2424
extends ApiServer
2525
with ServerTestKit
2626
with PersistenceTestKit
27-
with RouteTest
27+
with PersistenceRouteTest
2828
with TestFrameworkInterface
2929
with ScalatestUtils
3030
with Json4sSupport { this: Suite with ApiRoutes with Schema =>
@@ -33,12 +33,6 @@ trait PersistenceScalatestRouteTest
3333
typedSystem()
3434
}
3535

36-
override implicit lazy val system: ActorSystem = createActorSystem()
37-
38-
override implicit lazy val executor: ExecutionContextExecutor = system.dispatcher
39-
40-
override implicit lazy val materializer: Materializer = SystemMaterializer(system).materializer
41-
4236
implicit lazy val timeout: RouteTestTimeout = RouteTestTimeout(Settings.DefaultTimeout)
4337

4438
def failTest(msg: String) = throw new TestFailedException(msg, 11)
@@ -133,3 +127,218 @@ trait InMemoryPersistenceScalatestRouteTest
133127
with InMemoryPersistenceTestKit {
134128
_: Suite with ApiRoutes =>
135129
}
130+
131+
import akka.http.scaladsl.Http
132+
import akka.http.scaladsl.client.RequestBuilding
133+
import akka.http.scaladsl.model.HttpEntity.ChunkStreamPart
134+
import akka.http.scaladsl.model._
135+
import akka.http.scaladsl.model.headers.{ Host, Upgrade, `Sec-WebSocket-Protocol` }
136+
import akka.http.scaladsl.server._
137+
import akka.http.scaladsl.settings.ParserSettings
138+
import akka.http.scaladsl.settings.RoutingSettings
139+
import akka.http.scaladsl.settings.ServerSettings
140+
import akka.http.scaladsl.unmarshalling._
141+
import akka.http.scaladsl.util.FastFuture._
142+
import akka.stream.scaladsl.Source
143+
import akka.testkit.TestKit
144+
import akka.util.ConstantFun
145+
import com.typesafe.config.{ Config, ConfigFactory }
146+
147+
import scala.collection.immutable
148+
import scala.concurrent.duration._
149+
import scala.concurrent.{ Await, ExecutionContext, Future }
150+
import scala.reflect.ClassTag
151+
import scala.util.DynamicVariable
152+
153+
trait PersistenceRouteTest extends RequestBuilding with WSTestRequestBuilding with RouteTestResultComponent with MarshallingTestUtils {
154+
this: TestFrameworkInterface =>
155+
156+
/** Override to supply a custom ActorSystem */
157+
protected def createActorSystem(): ActorSystem =
158+
ActorSystem(actorSystemNameFrom(getClass), testConfig)
159+
160+
def actorSystemNameFrom(clazz: Class[_]) =
161+
clazz.getName
162+
.replace('.', '-')
163+
.replace('_', '-')
164+
.filter(_ != '$')
165+
166+
def testConfigSource: String = ""
167+
def testConfig: Config = {
168+
val source = testConfigSource
169+
val config = if (source.isEmpty) ConfigFactory.empty() else ConfigFactory.parseString(source)
170+
config.withFallback(ConfigFactory.load())
171+
}
172+
implicit lazy val system: ActorSystem = createActorSystem()
173+
implicit lazy val executor: ExecutionContextExecutor = system.dispatcher
174+
implicit lazy val materializer: Materializer = SystemMaterializer(system).materializer
175+
176+
def cleanUp(): Unit = TestKit.shutdownActorSystem(system)
177+
178+
private val dynRR = new DynamicVariable[RouteTestResult](null)
179+
private def result =
180+
if (dynRR.value ne null) dynRR.value
181+
else sys.error("This value is only available inside of a `check` construct!")
182+
183+
def check[T](body: => T): RouteTestResult => T = result => dynRR.withValue(result.awaitResult)(body)
184+
185+
private def responseSafe = if (dynRR.value ne null) dynRR.value.response else "<not available anymore>"
186+
187+
def handled: Boolean = result.handled
188+
def response: HttpResponse = result.response
189+
def responseEntity: HttpEntity = result.entity
190+
private def rawResponse: HttpResponse = result.rawResponse
191+
def chunks: immutable.Seq[HttpEntity.ChunkStreamPart] = result.chunks
192+
def chunksStream: Source[ChunkStreamPart, Any] = result.chunksStream
193+
def entityAs[T: FromEntityUnmarshaller: ClassTag](implicit timeout: Duration = 1.second): T = {
194+
def msg(e: Throwable) = s"Could not unmarshal entity to type '${implicitly[ClassTag[T]]}' for `entityAs` assertion: $e\n\nResponse was: $responseSafe"
195+
Await.result(Unmarshal(responseEntity).to[T].fast.recover[T] { case error => failTest(msg(error)) }, timeout)
196+
}
197+
def responseAs[T: FromResponseUnmarshaller: ClassTag](implicit timeout: Duration = 1.second): T = {
198+
def msg(e: Throwable) = s"Could not unmarshal response to type '${implicitly[ClassTag[T]]}' for `responseAs` assertion: $e\n\nResponse was: $responseSafe"
199+
Await.result(Unmarshal(response).to[T].fast.recover[T] { case error => failTest(msg(error)) }, timeout)
200+
}
201+
def contentType: ContentType = rawResponse.entity.contentType
202+
def mediaType: MediaType = contentType.mediaType
203+
def charsetOption: Option[HttpCharset] = contentType.charsetOption
204+
def charset: HttpCharset = charsetOption getOrElse sys.error("Binary entity does not have charset")
205+
def headers: immutable.Seq[HttpHeader] = rawResponse.headers
206+
def header[T >: Null <: HttpHeader: ClassTag]: Option[T] = rawResponse.header[T](implicitly[ClassTag[T]])
207+
def header(name: String): Option[HttpHeader] = rawResponse.headers.find(_.is(name.toLowerCase))
208+
def status: StatusCode = rawResponse.status
209+
210+
def closingExtension: String = chunks.lastOption match {
211+
case Some(HttpEntity.LastChunk(extension, _)) => extension
212+
case _ => ""
213+
}
214+
def trailer: immutable.Seq[HttpHeader] = chunks.lastOption match {
215+
case Some(HttpEntity.LastChunk(_, trailer)) => trailer
216+
case _ => Nil
217+
}
218+
219+
def rejections: immutable.Seq[Rejection] = result.rejections
220+
def rejection: Rejection = {
221+
val r = rejections
222+
if (r.size == 1) r.head else failTest("Expected a single rejection but got %s (%s)".format(r.size, r))
223+
}
224+
225+
def isWebSocketUpgrade: Boolean =
226+
status == StatusCodes.SwitchingProtocols && header[Upgrade].exists(_.hasWebSocket)
227+
228+
/**
229+
* Asserts that the received response is a WebSocket upgrade response and the extracts
230+
* the chosen subprotocol and passes it to the handler.
231+
*/
232+
def expectWebSocketUpgradeWithProtocol(body: String => Unit): Unit = {
233+
if (!isWebSocketUpgrade) failTest("Response was no WebSocket Upgrade response")
234+
header[`Sec-WebSocket-Protocol`] match {
235+
case Some(`Sec-WebSocket-Protocol`(Seq(protocol))) => body(protocol)
236+
case _ => failTest("No WebSocket protocol found in response.")
237+
}
238+
}
239+
240+
/**
241+
* A dummy that can be used as `~> runRoute` to run the route but without blocking for the result.
242+
* The result of the pipeline is the result that can later be checked with `check`. See the
243+
* "separate running route from checking" example from ScalatestRouteTestSpec.scala.
244+
*/
245+
def runRoute: RouteTestResult => RouteTestResult = ConstantFun.scalaIdentityFunction
246+
247+
// there is already an implicit class WithTransformation in scope (inherited from akka.http.scaladsl.testkit.TransformerPipelineSupport)
248+
// however, this one takes precedence
249+
implicit class WithTransformation2(request: HttpRequest) {
250+
/**
251+
* Apply request to given routes for further inspection in `check { }` block.
252+
*/
253+
def ~>[A, B](f: A => B)(implicit ta: TildeArrow[A, B]): ta.Out = ta(request, f)
254+
255+
/**
256+
* Evaluate request against routes run in server mode for further
257+
* inspection in `check { }` block.
258+
*
259+
* Compared to [[~>]], the given routes are run in a fully fledged
260+
* server, which allows more types of directives to be tested at the
261+
* cost of additional overhead related with server setup.
262+
*/
263+
def ~!>[A, B](f: A => B)(implicit tba: TildeBangArrow[A, B]): tba.Out = tba(request, f)
264+
}
265+
266+
abstract class TildeArrow[A, B] {
267+
type Out
268+
def apply(request: HttpRequest, f: A => B): Out
269+
}
270+
271+
case class DefaultHostInfo(host: Host, securedConnection: Boolean)
272+
object DefaultHostInfo {
273+
implicit def defaultHost: DefaultHostInfo = DefaultHostInfo(Host("example.com"), securedConnection = false)
274+
}
275+
object TildeArrow {
276+
implicit object InjectIntoRequestTransformer extends TildeArrow[HttpRequest, HttpRequest] {
277+
type Out = HttpRequest
278+
def apply(request: HttpRequest, f: HttpRequest => HttpRequest) = f(request)
279+
}
280+
implicit def injectIntoRoute(implicit timeout: RouteTestTimeout, defaultHostInfo: DefaultHostInfo): TildeArrow[RequestContext, Future[RouteResult]] { type Out = RouteTestResult } =
281+
new TildeArrow[RequestContext, Future[RouteResult]] {
282+
type Out = RouteTestResult
283+
def apply(request: HttpRequest, route: Route): Out = {
284+
if (request.method == HttpMethods.HEAD && ServerSettings(system).transparentHeadRequests)
285+
failTest("`akka.http.server.transparent-head-requests = on` not supported in PersistenceRouteTest using `~>`. Use `~!>` instead " +
286+
"for a full-stack test, e.g. `req ~!> route ~> check {...}`")
287+
288+
implicit val executionContext: ExecutionContext = system.classicSystem.dispatcher
289+
val routingSettings = RoutingSettings(system)
290+
val routingLog = RoutingLog(system.classicSystem.log)
291+
292+
val routeTestResult = new RouteTestResult(timeout.duration)
293+
val effectiveRequest =
294+
request.withEffectiveUri(
295+
securedConnection = defaultHostInfo.securedConnection,
296+
defaultHostHeader = defaultHostInfo.host)
297+
val parserSettings = ParserSettings.forServer(system)
298+
val ctx = new RequestContextImpl(effectiveRequest, routingLog.requestLog(effectiveRequest), routingSettings, parserSettings)
299+
300+
val sealedExceptionHandler = ExceptionHandler.seal(testExceptionHandler)
301+
302+
val semiSealedRoute = // sealed for exceptions but not for rejections
303+
Directives.handleExceptions(sealedExceptionHandler)(route)
304+
val deferrableRouteResult = semiSealedRoute(ctx)
305+
deferrableRouteResult.fast.foreach(routeTestResult.handleResult)(executionContext)
306+
routeTestResult
307+
}
308+
}
309+
}
310+
311+
abstract class TildeBangArrow[A, B] {
312+
type Out
313+
def apply(request: HttpRequest, f: A => B): Out
314+
}
315+
316+
object TildeBangArrow {
317+
implicit def injectIntoRoute(implicit timeout: RouteTestTimeout, serverSettings: ServerSettings): TildeBangArrow[RequestContext, Future[RouteResult]] { type Out = RouteTestResult } =
318+
new TildeBangArrow[RequestContext, Future[RouteResult]] {
319+
type Out = RouteTestResult
320+
def apply(request: HttpRequest, route: Route): Out = {
321+
val routeTestResult = new RouteTestResult(timeout.duration)
322+
val responseF = PersistenceRouteTest.runRouteClientServer(request, route, serverSettings)
323+
val response = Await.result(responseF, timeout.duration)
324+
routeTestResult.handleResponse(response)
325+
routeTestResult
326+
}
327+
}
328+
}
329+
}
330+
private[http] object PersistenceRouteTest {
331+
def runRouteClientServer(request: HttpRequest, route: Route, serverSettings: ServerSettings)(implicit system: ActorSystem): Future[HttpResponse] = {
332+
import system.dispatcher
333+
for {
334+
binding <- Http().newServerAt("127.0.0.1", 0).withSettings(settings = serverSettings).bind(route)
335+
port = binding.localAddress.getPort
336+
targetUri = request.uri.withHost("127.0.0.1").withPort(port).withScheme("http")
337+
338+
response <- Http().singleRequest(request.withUri(targetUri))
339+
} yield {
340+
binding.unbind()
341+
response
342+
}
343+
}
344+
}

0 commit comments

Comments
 (0)