|
1 | 1 | package scalaoauth2.provider |
2 | 2 |
|
3 | | -import org.apache.commons.codec.binary.Base64 |
| 3 | +import java.util.Base64 |
| 4 | + |
| 5 | +import scala.util.Try |
4 | 6 |
|
5 | 7 | case class ClientCredential(clientId: String, clientSecret: Option[String]) |
6 | 8 |
|
7 | 9 | class AuthorizationRequest(headers: Map[String, Seq[String]], params: Map[String, Seq[String]]) extends RequestBase(headers, params) { |
8 | 10 |
|
9 | | - /** |
10 | | - * Returns grant_type. |
11 | | - * |
12 | | - * OAuth defines four grant types: |
13 | | - * authorization code |
14 | | - * implicit |
15 | | - * resource owner password credentials, and client credentials. |
16 | | - * |
17 | | - * @return grant_type |
18 | | - */ |
19 | | - def grantType: String = requireParam("grant_type") |
20 | | - |
21 | | - /** |
22 | | - * Returns scope. |
23 | | - * |
24 | | - * @return scope |
25 | | - */ |
26 | 11 | def scope: Option[String] = param("scope") |
27 | 12 |
|
28 | | - lazy val clientCredential: Option[ClientCredential] = { |
29 | | - header("Authorization").flatMap { |
30 | | - """^\s*Basic\s+(.+?)\s*$""".r.findFirstMatchIn |
31 | | - } match { |
32 | | - case Some(matcher) => |
33 | | - val authorization = matcher.group(1) |
34 | | - val decoded = new String(Base64.decodeBase64(authorization.getBytes), "UTF-8") |
35 | | - if (decoded.indexOf(':') > 0) { |
36 | | - decoded.split(":", 2) match { |
37 | | - case Array(clientId, clientSecret) => Some(ClientCredential(clientId, if (clientSecret == "") None else Some(clientSecret))) |
38 | | - case Array(clientId) => Some(ClientCredential(clientId, None)) |
39 | | - } |
40 | | - } else { |
| 13 | + def grantType: String = requireParam("grant_type") |
| 14 | + |
| 15 | + lazy val clientCredential: Option[ClientCredential] = |
| 16 | + findAuthorization |
| 17 | + .flatMap(clientCredentialByAuthorization) |
| 18 | + .orElse(clientCredentialByParam) |
| 19 | + |
| 20 | + private def findAuthorization = for { |
| 21 | + authorization <- header("Authorization") |
| 22 | + matcher <- """^\s*Basic\s+(.+?)\s*$""".r.findFirstMatchIn(authorization) |
| 23 | + } yield matcher.group(1) |
| 24 | + |
| 25 | + private def clientCredentialByAuthorization(s: String) = |
| 26 | + Try(new String(Base64.getDecoder.decode(s), "UTF-8")) |
| 27 | + .map(_.split(":", 2)) |
| 28 | + .getOrElse(Array.empty) match { |
| 29 | + case Array(clientId, clientSecret) => |
| 30 | + Some(ClientCredential(clientId, if (clientSecret.isEmpty) None else Some(clientSecret))) |
| 31 | + case _ => |
41 | 32 | None |
42 | | - } |
43 | | - case _ => param("client_id").map { clientId => |
44 | | - ClientCredential(clientId, param("client_secret")) |
45 | 33 | } |
46 | | - } |
47 | | - } |
| 34 | + |
| 35 | + private def clientCredentialByParam = param("client_id").map(ClientCredential(_, param("client_secret"))) |
| 36 | + |
48 | 37 | } |
49 | 38 |
|
50 | 39 | case class RefreshTokenRequest(request: AuthorizationRequest) extends AuthorizationRequest(request.headers, request.params) { |
|
0 commit comments