2024-05-09 14:18:18 +05:30

318 lines
13 KiB
Kotlin

package com.restapi.config
import com.fasterxml.jackson.module.kotlin.readValue
import com.restapi.config.AppConfig.Companion.appConfig
import com.restapi.domain.AuthTokenCache
import com.restapi.domain.Plant
import com.restapi.domain.RefreshHistory
import com.restapi.domain.Session
import com.restapi.domain.Session.database
import com.restapi.domain.Session.objectMapper
import io.javalin.http.BadRequestResponse
import io.javalin.http.ContentType
import io.javalin.http.Context
import io.javalin.http.UnauthorizedResponse
import io.javalin.security.RouteRole
import org.apache.http.client.methods.HttpGet
import org.apache.http.impl.client.HttpClients
import org.apache.http.util.EntityUtils
import org.jose4j.jwk.HttpsJwks
import org.jose4j.jwt.consumer.JwtConsumerBuilder
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver
import org.slf4j.LoggerFactory
import java.net.URI
import java.net.URLEncoder
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
import java.nio.charset.StandardCharsets
import java.time.LocalDateTime
import java.time.ZoneId
import java.time.format.DateTimeFormatter
import java.util.*
import java.util.concurrent.ConcurrentHashMap
/**
 * Encodes [formData] as an `application/x-www-form-urlencoded` string.
 *
 * Every key and value is URL-encoded as UTF-8, each pair is rendered as `key=value` in the map's
 * iteration order, and the pairs are joined with `&`. An empty map yields an empty string.
 */
private fun getFormDataAsString(formData: Map<String, String>): String =
    formData.map { (name, raw) ->
        URLEncoder.encode(name, StandardCharsets.UTF_8) + "=" + URLEncoder.encode(raw, StandardCharsets.UTF_8)
    }.joinToString(separator = "&")
/**
 * OpenID Connect helpers for the IAM realm configured in [appConfig] (`iamUrl`, `iamRealm`, `iamClient`,
 * `iamClientRedirectUri`), plus the Javalin handlers for the login redirect, code exchange, token refresh
 * and logout. Issued tokens are persisted as [AuthTokenCache] rows so they can later be refreshed or
 * marked as logged out.
 */
object Auth {
    private val logger = LoggerFactory.getLogger("Auth")

    // One shared client for every IAM call. java.net.http.HttpClient is meant to be reused; creating one
    // per request (as before) builds a fresh client and its resources every time.
    private val httpClient: HttpClient = HttpClient.newHttpClient()

    private val authCache = ConcurrentHashMap<String, AuthEndpoint>()

    /**
     * Returns the realm's OpenID discovery document, fetched on first use and cached afterwards.
     * If the fetch or parse throws, computeIfAbsent records no mapping, so the next call retries.
     */
    private fun getAuthEndpoint(): AuthEndpoint {
        return authCache.computeIfAbsent("AUTH") {
            val wellKnown = "${appConfig.iamUrl()}/realms/${appConfig.iamRealm()}/.well-known/openid-configuration"
            val req = HttpRequest.newBuilder().uri(URI.create(wellKnown)).GET().build()
            objectMapper.readValue<AuthEndpoint>(httpClient.send(req, HttpResponse.BodyHandlers.ofString()).body())
        }
    }

    // The JWKS and both consumers are built lazily. As eager property initialisers they made the first
    // touch of `Auth` perform network calls, and a failure there left the object permanently unusable
    // (ExceptionInInitializerError, then NoClassDefFoundError); a `lazy` whose initialiser throws is simply
    // retried on the next access. Both consumers share one JWKS instance instead of fetching keys twice.
    private val httpsJwks by lazy { HttpsJwks(getAuthEndpoint().jwksUri) }

    /** Full validation: JWKS signature, expiry (30 s clock skew), subject, `account` audience, realm issuer. */
    private val jwtConsumer by lazy {
        JwtConsumerBuilder()
            .setRequireExpirationTime()
            .setAllowedClockSkewInSeconds(30)
            .setRequireSubject()
            .setExpectedAudience("account")
            .setExpectedIssuer(getAuthEndpoint().issuer)
            .setVerificationKeyResolver(HttpsJwksVerificationKeyResolver(httpsJwks))
            .build()
    }

    /** Skips the claim validators (so expired tokens still decode) but keeps the JWKS key resolver. */
    private val jwtConsumerSkipValidate by lazy {
        JwtConsumerBuilder()
            .setSkipAllValidators()
            .setVerificationKeyResolver(HttpsJwksVerificationKeyResolver(httpsJwks))
            .build()
    }

    /**
     * Processes [authToken] with the realm's JWKS and maps its claims to an [AuthUser].
     *
     * Reads the `preferred_username`, `tenant`, `plantIds` and `realm_access.roles` claims; a missing or
     * differently-typed claim fails the cast with an exception. As a best-effort side effect, the plant
     * names for `plantIds` are synced from RMC (see [syncPlants]).
     *
     * @param skipValidate when true, claim validators (expiry, audience, issuer, subject) are skipped so an
     *   expired token can still be decoded; used by [logout] and [refreshToken].
     * @throws org.jose4j.jwt.consumer.InvalidJwtException if jose4j rejects the token.
     */
    @Suppress("UNCHECKED_CAST")
    fun validateAuthToken(authToken: String, skipValidate: Boolean = false): AuthUser {
        val consumer = if (skipValidate) jwtConsumerSkipValidate else jwtConsumer
        val jwtClaims = consumer.process(authToken).jwtClaims
        val claims = jwtClaims.claimsMap
        val userId = claims["preferred_username"] as String
        val tenant = claims["tenant"] as String
        val plantIds = claims["plantIds"] as List<String>
        val roles = (claims["realm_access"] as Map<String, Any>)["roles"] as List<String>
        val expiry = Date(jwtClaims.expirationTime.valueInMillis)
        syncPlants(plantIds)
        return AuthUser(
            userName = userId,
            tenant = getTenantWithCompany(userId, tenant),
            roles = roles,
            token = authToken,
            expiry = LocalDateTime.ofInstant(expiry.toInstant(), ZoneId.systemDefault()),
            plantIds = plantIds,
        )
    }

    /**
     * Best effort: for each id in [plantIds], fetches the plant name from RMC `/plant` and inserts or
     * updates the matching [Plant] row. Non-200 responses are skipped; any exception is logged and
     * swallowed so authentication is never blocked by RMC.
     */
    private fun syncPlants(plantIds: List<String>) {
        try {
            HttpClients.createDefault().use { client ->
                for (plantId in plantIds) {
                    val existing = database.find(Plant::class.java).where().eq("plantId", plantId).findOne()
                    val url = "${appConfig.integrationRmc()}/plant?id=${URLEncoder.encode(plantId, StandardCharsets.UTF_8)}"
                    client.execute(HttpGet(url)).use { response ->
                        if (response.statusLine.statusCode == 200) {
                            val plantName = EntityUtils.toString(response.entity)
                            if (existing == null) {
                                database.save(Plant().apply {
                                    this.plantId = plantId
                                    this.plantName = plantName
                                })
                            } else {
                                existing.plantName = plantName
                                existing.save()
                            }
                        }
                    }
                }
            }
        } catch (e: Exception) {
            logger.warn("Exception in syncing plants", e)
        }
    }

    // Tenant resolved per user name; entries are never evicted.
    private val userToTenant = ConcurrentHashMap<String, String>()

    /**
     * Returns the tenant to use for [userId]; the result is cached per user for the life of the process.
     *
     * NOTE(review): the RMC `/tenant` lookup is performed, but its response is discarded, so this always
     * returns [tenant] unchanged, exactly as before. Appending the company (which the name suggests was
     * intended) would change the tenant of every user and of anything keyed by it. That is left as an
     * explicit decision for the owners rather than changed silently here.
     */
    private fun getTenantWithCompany(userId: String, tenant: String): String {
        return userToTenant.computeIfAbsent(userId) {
            fetchTenantCompany(userId)
            tenant
        }
    }

    /** Returns the RMC `/tenant` response body for [userId], or null on a non-200 status or any exception. */
    private fun fetchTenantCompany(userId: String): String? = try {
        HttpClients.createDefault().use { client ->
            val url = "${appConfig.integrationRmc()}/tenant?id=${URLEncoder.encode(userId, StandardCharsets.UTF_8)}"
            client.execute(HttpGet(url)).use { response ->
                if (response.statusLine.statusCode == 200) EntityUtils.toString(response.entity) else null
            }
        }
    } catch (e: Exception) {
        logger.warn("Exception in fetching tenant for $userId", e)
        null
    }

    /** Responds with `Session.jwk()` as JSON. */
    fun keys(ctx: Context) {
        ctx.json(Session.jwk())
    }

    /** Responds with the cached OpenID discovery document as JSON. */
    fun endPoint(ctx: Context) {
        ctx.json(getAuthEndpoint())
    }

    /**
     * Redirects the browser to the realm's authorization endpoint to start the authorization-code flow.
     *
     * Query parameters are URL-encoded, so a redirect URI containing `?`, `&` or `#` survives intact.
     * NOTE(review): `state` is a fixed value and is not checked in [code], so it offers no CSRF
     * protection; it should be a random per-login value that is verified on the callback.
     */
    fun init(ctx: Context) {
        val endpoint = getAuthEndpoint().authorizationEndpoint
        val query = getFormDataAsString(
            mapOf(
                "response_type" to "code",
                "client_id" to appConfig.iamClient(),
                "redirect_uri" to appConfig.iamClientRedirectUri(),
                "scope" to "profile",
                "state" to "1234zyx",
            )
        )
        ctx.redirect("$endpoint?$query")
    }

    /**
     * Exchanges the `code` query parameter for tokens, validates the access token, records it as an
     * [AuthTokenCache] row and responds with the access token as plain text.
     *
     * The optional `redirectUrl` and `client` query parameters override the configured redirect URI and
     * client id.
     *
     * @throws BadRequestResponse if `code` is missing.
     * @throws UnauthorizedResponse if the token endpoint rejects the exchange.
     */
    fun code(ctx: Context) {
        val code = ctx.queryParam("code") ?: throw BadRequestResponse("not proper")
        val redirectUri = ctx.queryParam("redirectUrl") ?: appConfig.iamClientRedirectUri()
        val iamClient = ctx.queryParam("client") ?: appConfig.iamClient()
        val atResponse = requestToken(
            mapOf(
                "code" to code,
                "redirect_uri" to redirectUri,
                "client_id" to iamClient,
                "grant_type" to "authorization_code",
            )
        )
        val parsed = validateAuthToken(atResponse.accessToken)
        // A single timestamp keeps both expiry columns consistent with each other.
        val now = LocalDateTime.now()
        database.save(AuthTokenCache().apply {
            this.userId = parsed.userName
            this.authToken = atResponse.accessToken
            this.expiresAt = now.plusSeconds(atResponse.expiresIn.toLong())
            this.refreshToken = atResponse.refreshToken
            this.refreshExpiresAt = now.plusSeconds(atResponse.refreshExpiresIn.toLong())
            this.refreshHistory = arrayListOf()
        })
        ctx.result(atResponse.accessToken).contentType(ContentType.TEXT_PLAIN)
    }

    /**
     * Marks every [AuthTokenCache] row holding the caller's bearer token as logged out and responds
     * `{"status": true}`. The token is decoded with claim validation skipped, so an expired token can
     * still log out. Without an `Authorization` header the handler returns without writing a body.
     */
    fun logout(ctx: Context) {
        val authToken = bearerToken(ctx) ?: return
        val authUser = validateAuthToken(authToken, skipValidate = true)
        logger.warn("User ${authUser.userName} is logging out")
        database.updateAll(
            database.find(AuthTokenCache::class.java)
                .where()
                .eq("authToken", authToken)
                .findList()
                .onEach { it.loggedOut = true }
        )
        ctx.json(mapOf("status" to true))
    }

    /**
     * Returns a usable access token for the caller as plain text.
     *
     * The caller's current token must match a stored, non-expired, non-logged-out [AuthTokenCache] row
     * whose refresh window is still open. If the access token has not expired it is returned unchanged;
     * otherwise it is refreshed at the token endpoint, the row is updated and a [RefreshHistory] entry is
     * appended.
     *
     * @throws UnauthorizedResponse if there is no bearer token, the refresh window has passed, or the
     *   token endpoint rejects the refresh.
     * @throws BadRequestResponse if `client`/`redirectUri` are missing or no matching row exists.
     */
    fun refreshToken(ctx: Context) {
        val authToken = bearerToken(ctx) ?: throw UnauthorizedResponse()
        val authUser = validateAuthToken(authToken, skipValidate = true)
        val client = ctx.queryParam("client") ?: throw BadRequestResponse("client not sent")
        val redirectUri = ctx.queryParam("redirectUri") ?: throw BadRequestResponse("redirectUri not sent")
        // Exact, in-query match: JWTs are case-sensitive. Tokens are deliberately kept out of logs and
        // error messages, because they are bearer credentials.
        val foundOldAt = database.find(AuthTokenCache::class.java)
            .where()
            .eq("userId", authUser.userName)
            .eq("authToken", authToken)
            .eq("expired", false)
            .eq("loggedOut", false)
            .gt("refreshExpiresAt", LocalDateTime.now())
            .findList()
            .firstOrNull() ?: throw BadRequestResponse("we did not find an entry for this auth token")
        val createdAt = foundOldAt.createdAt
        val expiresAt = foundOldAt.expiresAt
        val rtExpiresAt = foundOldAt.refreshExpiresAt
        val now = LocalDateTime.now()
        logger.warn("can we refresh the token for ${authUser.userName}, created = $createdAt expires = $expiresAt, refresh Till = $rtExpiresAt")
        // The access token is still valid: hand the same one back.
        if (expiresAt.isAfter(now)) {
            ctx.result(authToken).contentType(ContentType.TEXT_PLAIN)
            return
        }
        // The refresh window has also passed: the user has to log in again.
        if (!rtExpiresAt.isAfter(now)) {
            logger.warn("We can't refresh the token for ${authUser.userName}, as refresh-time [$rtExpiresAt] is expired")
            throw UnauthorizedResponse()
        }
        logger.warn("We can refresh the token for ${authUser.userName}, expires = $expiresAt, refresh Till = $rtExpiresAt")
        val atResponse = requestToken(
            mapOf(
                "refresh_token" to foundOldAt.refreshToken,
                "redirect_uri" to redirectUri,
                "client_id" to client,
                "grant_type" to "refresh_token",
            )
        )
        // One timestamp, so the stored expiry and the history's newExpiryAt agree exactly.
        val refreshedAt = LocalDateTime.now()
        val newExpiresAt = refreshedAt.plusSeconds(atResponse.expiresIn.toLong())
        foundOldAt.authToken = atResponse.accessToken
        foundOldAt.expiresAt = newExpiresAt
        foundOldAt.refreshExpiresAt = refreshedAt.plusSeconds(atResponse.refreshExpiresIn.toLong())
        foundOldAt.refreshToken = atResponse.refreshToken
        foundOldAt.refreshHistory = (foundOldAt.refreshHistory ?: arrayListOf()).apply {
            add(RefreshHistory().apply {
                oldAt = authUser.token
                oldExpiryAt = expiresAt.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)
                newAt = atResponse.accessToken
                newExpiryAt = newExpiresAt.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)
                this.createdAt = refreshedAt.format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)
            })
        }
        database.update(foundOldAt)
        ctx.result(atResponse.accessToken).contentType(ContentType.TEXT_PLAIN)
    }

    /**
     * POSTs [form] (form-urlencoded) to the realm's token endpoint and parses the JSON reply.
     *
     * @throws UnauthorizedResponse if the endpoint answers with a non-2xx status, instead of letting the
     *   error body fail JSON mapping as a server error.
     */
    private fun requestToken(form: Map<String, String>): AuthTokenResponse {
        val request = HttpRequest.newBuilder()
            .uri(URI.create(getAuthEndpoint().tokenEndpoint))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .POST(HttpRequest.BodyPublishers.ofString(getFormDataAsString(form)))
            .build()
        val response = httpClient.send(request, HttpResponse.BodyHandlers.ofString())
        if (response.statusCode() !in 200..299) {
            logger.warn("Token endpoint rejected ${form["grant_type"]} request with status ${response.statusCode()}")
            throw UnauthorizedResponse()
        }
        return objectMapper.readValue<AuthTokenResponse>(response.body())
    }

    /**
     * Extracts the token from an `Authorization` header of the form `Bearer <token>` or
     * `Bearer: <token>`. Returns null when the header is absent.
     */
    private fun bearerToken(ctx: Context): String? =
        ctx.header("Authorization")?.replace("Bearer ", "")?.replace("Bearer: ", "")?.trim()
}
/**
 * An authenticated caller, as produced by `Auth.validateAuthToken` from a processed JWT.
 *
 * @property userName value of the token's `preferred_username` claim.
 * @property tenant tenant derived from the token's `tenant` claim.
 * @property roles the token's `realm_access.roles` claim.
 * @property token the raw token string the user was built from.
 * @property expiry the token's expiration time, in the system default time zone.
 * @property plantIds the token's `plantIds` claim.
 */
data class AuthUser(
    val userName: String,
    val tenant: String,
    val roles: List<String>,
    val token: String,
    val expiry: LocalDateTime,
    val plantIds: List<String>,
)
/** Actions that a [Role.Standard] role is constructed from. Declaration order fixes each ordinal. */
enum class Action {
    CREATE,
    VIEW,
    UPDATE,
    DELETE,
    APPROVE,
    ADMIN,
}
/**
 * Closed set of role descriptors. Instances are wrapped by [Roles], which is the Javalin route role.
 */
sealed class Role {
    /** Role described by one or more [Action]s. */
    open class Standard(vararg val action: Action) : Role()

    /** Singleton marker role `Entity`. */
    data object Entity : Role()

    /** Role described by explicit role names. */
    open class Explicit(vararg val roles: String) : Role()

    /** Singleton marker role `DbOps`. */
    data object DbOps : Role()
}
/** Javalin [RouteRole] that carries the [Role] descriptors attached to a route. */
open class Roles(
    vararg val roles: Role,
) : RouteRole