|
1 | 1 | package scalaoauth2.provider |
2 | 2 |
|
| 3 | +import scala.concurrent.Future |
| 4 | +import scala.concurrent.ExecutionContext.Implicits.global |
3 | 5 |
|
4 | 6 | case class GrantHandlerResult(tokenType: String, accessToken: String, expiresIn: Option[Long], refreshToken: Option[String], scope: Option[String]) |
5 | 7 |
|
6 | 8 | trait GrantHandler { |
7 | 9 |
|
8 | | - def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): GrantHandlerResult |
| 10 | + def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] |
9 | 11 |
|
10 | 12 |
|
11 | 13 | /** |
12 | 14 | * Returns valid access token. |
13 | | - * |
| 15 | + * |
14 | 16 | * @param dataHandler |
15 | 17 | * @param authInfo |
16 | | - * @return |
| 18 | + * @return |
17 | 19 | */ |
18 | | - def issueAccessToken[U](dataHandler: DataHandler[U], authInfo: AuthInfo[U]): GrantHandlerResult = { |
19 | | - val accessToken = dataHandler.getStoredAccessToken(authInfo) match { |
20 | | - case Some(token) if dataHandler.isAccessTokenExpired(token) => |
21 | | - token.refreshToken.map(dataHandler.refreshAccessToken(authInfo, _)).getOrElse(dataHandler.createAccessToken(authInfo)) |
22 | | - case Some(token) => token |
23 | | - case None => dataHandler.createAccessToken(authInfo) |
| 20 | + def issueAccessToken[U](dataHandler: DataHandler[U], authInfo: AuthInfo[U]): Future[GrantHandlerResult] = { |
| 21 | + dataHandler.getStoredAccessToken(authInfo).flatMap { optionalAccessToken => |
| 22 | + (optionalAccessToken match { |
| 23 | + case Some(token) if dataHandler.isAccessTokenExpired(token) => { |
| 24 | + token.refreshToken.map(dataHandler.refreshAccessToken(authInfo, _)).getOrElse(dataHandler.createAccessToken(authInfo)) |
| 25 | + } |
| 26 | + case Some(token) => Future.successful(token) |
| 27 | + case None => dataHandler.createAccessToken(authInfo) |
| 28 | + }).map { accessToken => |
| 29 | + GrantHandlerResult( |
| 30 | + "Bearer", |
| 31 | + accessToken.token, |
| 32 | + accessToken.expiresIn, |
| 33 | + accessToken.refreshToken, |
| 34 | + accessToken.scope |
| 35 | + ) |
| 36 | + } |
24 | 37 | } |
25 | | - |
26 | | - GrantHandlerResult( |
27 | | - "Bearer", |
28 | | - accessToken.token, |
29 | | - accessToken.expiresIn, |
30 | | - accessToken.refreshToken, |
31 | | - accessToken.scope |
32 | | - ) |
33 | 38 | } |
34 | 39 | } |
35 | 40 |
|
36 | 41 | class RefreshToken(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler { |
37 | 42 |
|
38 | | - override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): GrantHandlerResult = { |
| 43 | + override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = { |
39 | 44 | val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("BadRequest")) |
40 | 45 | val refreshToken = request.requireRefreshToken |
41 | | - val authInfo = dataHandler.findAuthInfoByRefreshToken(refreshToken).getOrElse(throw new InvalidGrant("NotFound")) |
42 | | - if (authInfo.clientId != clientCredential.clientId) { |
43 | | - throw new InvalidClient |
| 46 | + |
| 47 | + dataHandler.findAuthInfoByRefreshToken(refreshToken).flatMap { authInfoOption => |
| 48 | + val authInfo = authInfoOption.getOrElse(throw new InvalidGrant("NotFound")) |
| 49 | + if (authInfo.clientId != clientCredential.clientId) { |
| 50 | + throw new InvalidClient |
| 51 | + } |
| 52 | + |
| 53 | + dataHandler.refreshAccessToken(authInfo, refreshToken).map { accessToken => |
| 54 | + GrantHandlerResult( |
| 55 | + "Bearer", |
| 56 | + accessToken.token, |
| 57 | + accessToken.expiresIn, |
| 58 | + accessToken.refreshToken, |
| 59 | + accessToken.scope |
| 60 | + ) |
| 61 | + } |
44 | 62 | } |
45 | | - |
46 | | - val accessToken = dataHandler.refreshAccessToken(authInfo, refreshToken) |
47 | | - GrantHandlerResult( |
48 | | - "Bearer", |
49 | | - accessToken.token, |
50 | | - accessToken.expiresIn, |
51 | | - accessToken.refreshToken, |
52 | | - accessToken.scope |
53 | | - ) |
54 | 63 | } |
55 | 64 | } |
56 | 65 |
|
57 | 66 | class Password(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler { |
58 | 67 |
|
59 | | - override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): GrantHandlerResult = { |
| 68 | + override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = { |
60 | 69 | val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("BadRequest")) |
61 | 70 | val username = request.requireUsername |
62 | 71 | val password = request.requirePassword |
63 | | - val user = dataHandler.findUser(username, password).getOrElse(throw new InvalidGrant()) |
64 | | - val scope = request.scope |
65 | | - val clientId = clientCredential.clientId |
66 | | - val authInfo = AuthInfo(user, clientId, scope, None) |
67 | 72 |
|
68 | | - issueAccessToken(dataHandler, authInfo) |
| 73 | + dataHandler.findUser(username, password).flatMap { userOption => |
| 74 | + val user = userOption.getOrElse(throw new InvalidGrant("username or password is incorrect")) |
| 75 | + val scope = request.scope |
| 76 | + val clientId = clientCredential.clientId |
| 77 | + val authInfo = AuthInfo(user, clientId, scope, None) |
| 78 | + |
| 79 | + issueAccessToken(dataHandler, authInfo) |
| 80 | + } |
69 | 81 | } |
70 | 82 | } |
71 | 83 |
|
72 | 84 | class ClientCredentials(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler { |
73 | 85 |
|
74 | | - override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): GrantHandlerResult = { |
| 86 | + override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = { |
75 | 87 | val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("BadRequest")) |
76 | 88 | val clientSecret = clientCredential.clientSecret |
77 | 89 | val clientId = clientCredential.clientId |
78 | 90 | val scope = request.scope |
79 | | - val user = dataHandler.findClientUser(clientId, clientSecret, scope).getOrElse(throw new InvalidGrant()) |
80 | | - val authInfo = AuthInfo(user, clientId, scope, None) |
81 | | - |
82 | | - issueAccessToken(dataHandler, authInfo) |
| 91 | + |
| 92 | + dataHandler.findClientUser(clientId, clientSecret, scope).flatMap { userOption => |
| 93 | + val user = userOption.getOrElse(throw new InvalidGrant()) |
| 94 | + val authInfo = AuthInfo(user, clientId, scope, None) |
| 95 | + |
| 96 | + issueAccessToken(dataHandler, authInfo) |
| 97 | + } |
83 | 98 | } |
84 | 99 |
|
85 | 100 | } |
86 | 101 |
|
87 | 102 | class AuthorizationCode(clientCredentialFetcher: ClientCredentialFetcher) extends GrantHandler { |
88 | 103 |
|
89 | | - override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): GrantHandlerResult = { |
| 104 | + override def handleRequest[U](request: AuthorizationRequest, dataHandler: DataHandler[U]): Future[GrantHandlerResult] = { |
90 | 105 | val clientCredential = clientCredentialFetcher.fetch(request).getOrElse(throw new InvalidRequest("BadRequest")) |
91 | 106 | val clientId = clientCredential.clientId |
92 | 107 | val code = request.requireCode |
93 | 108 | val redirectUri = request.redirectUri |
94 | | - val authInfo = dataHandler.findAuthInfoByCode(code).getOrElse(throw new InvalidGrant()) |
95 | | - if (authInfo.clientId != clientId) { |
96 | | - throw new InvalidClient |
97 | | - } |
98 | 109 |
|
99 | | - if (authInfo.redirectUri.isDefined && authInfo.redirectUri != redirectUri) { |
100 | | - throw new RedirectUriMismatch |
101 | | - } |
| 110 | + dataHandler.findAuthInfoByCode(code).flatMap { authInfoOption => |
| 111 | + val authInfo = authInfoOption.getOrElse(throw new InvalidGrant()) |
| 112 | + if (authInfo.clientId != clientId) { |
| 113 | + throw new InvalidClient |
| 114 | + } |
| 115 | + |
| 116 | + if (authInfo.redirectUri.isDefined && authInfo.redirectUri != redirectUri) { |
| 117 | + throw new RedirectUriMismatch |
| 118 | + } |
102 | 119 |
|
103 | | - issueAccessToken(dataHandler, authInfo) |
| 120 | + issueAccessToken(dataHandler, authInfo) |
| 121 | + } |
104 | 122 | } |
105 | 123 |
|
106 | 124 | } |
0 commit comments