2024-04-29 18:39:56 +05:30

280 lines
11 KiB
Kotlin

package com.restapi.config
import com.fasterxml.jackson.module.kotlin.readValue
import com.restapi.config.AppConfig.Companion.appConfig
import com.restapi.domain.Plant
import com.restapi.domain.Session
import com.restapi.domain.Session.objectMapper
import io.javalin.http.BadRequestResponse
import io.javalin.http.ContentType
import io.javalin.http.Context
import io.javalin.http.UnauthorizedResponse
import io.javalin.security.RouteRole
import org.apache.http.client.methods.HttpGet
import org.apache.http.impl.client.HttpClients
import org.apache.http.util.EntityUtils
import org.jose4j.jwk.HttpsJwks
import org.jose4j.jwt.consumer.JwtConsumerBuilder
import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver
import org.slf4j.LoggerFactory
import java.net.URI
import java.net.URLEncoder
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
import java.nio.charset.StandardCharsets
import java.time.LocalDateTime
import java.time.ZoneId
import java.util.*
import java.util.concurrent.ConcurrentHashMap
/**
 * Prefix of the Redis list keys holding the token responses issued to a user.
 * The full key is this prefix immediately followed by the user name (e.g. `AUTH_TOKEN_V2alice`).
 */
const val AUTH_TOKEN = "AUTH_TOKEN_V2"
/**
 * Serialises [formData] as an `application/x-www-form-urlencoded` body: `k1=v1&k2=v2…`.
 *
 * Every key and value is percent-encoded as UTF-8 with [URLEncoder] (so a space becomes `+`).
 * Pairs appear in the map's iteration order, and an empty map yields an empty string.
 */
private fun getFormDataAsString(formData: Map<String, String>): String =
    formData.map { (name, raw) ->
        URLEncoder.encode(name, StandardCharsets.UTF_8) + "=" + URLEncoder.encode(raw, StandardCharsets.UTF_8)
    }.joinToString(separator = "&")
/**
 * Authentication endpoints and helpers backed by an OpenID Connect provider.
 *
 * The discovery URL follows the `{iamUrl}/realms/{realm}` layout; presumably this is Keycloak, which the
 * `realm_access` claim read below also suggests.
 *
 * Token responses issued by [code] and [refreshToken] are pushed to a per-user Redis list (key =
 * [AUTH_TOKEN] + user name). [refreshToken] searches that list for the refresh token that belongs to a
 * client's access token.
 */
object Auth {
    private val logger = LoggerFactory.getLogger("Auth")
    private val authCache = ConcurrentHashMap<String, AuthEndpoint>()

    // One shared client: java.net.http.HttpClient is immutable once built and intended to be reused across
    // requests. It must be declared before jwtConsumer, whose initializer calls getAuthEndpoint() while this
    // object is still being initialized (property initializers run in textual order).
    private val httpClient: HttpClient = HttpClient.newHttpClient()

    /** Cache for [getTenantWithCompany]. It is keyed by user name only; the tenant claim is not part of the key. */
    private val userToTenant = ConcurrentHashMap<String, String>()

    /**
     * Returns the provider's OpenID discovery document.
     *
     * The document is fetched from `{iamUrl}/realms/{iamRealm}/.well-known/openid-configuration` on first use
     * and cached for the lifetime of the process.
     */
    fun getAuthEndpoint(): AuthEndpoint {
        return authCache.computeIfAbsent("AUTH") {
            val wellKnown = "${appConfig.iamUrl()}/realms/${appConfig.iamRealm()}/.well-known/openid-configuration"
            val req = HttpRequest.newBuilder().uri(URI.create(wellKnown)).GET().build()
            objectMapper.readValue<AuthEndpoint>(
                httpClient.send(req, HttpResponse.BodyHandlers.ofString()).body()
            )
        }
    }

    // Full validation: exp (with 30 s clock skew), sub, aud == "account", iss == the discovered issuer,
    // and a signature checked against the provider's JWKS.
    private val jwtConsumer =
        JwtConsumerBuilder().setRequireExpirationTime().setAllowedClockSkewInSeconds(30).setRequireSubject()
            .setExpectedAudience("account")
            .setExpectedIssuer(getAuthEndpoint().issuer)
            .setVerificationKeyResolver(HttpsJwksVerificationKeyResolver(HttpsJwks(getAuthEndpoint().jwksUri))).build()

    // All claim validators are skipped, so an expired token is accepted (needed by refreshToken).
    // setSkipSignatureVerification() is not called; the JWKS key resolver is still configured.
    private val jwtConsumerSkipValidate =
        JwtConsumerBuilder().setSkipAllValidators()
            .setVerificationKeyResolver(HttpsJwksVerificationKeyResolver(HttpsJwks(getAuthEndpoint().jwksUri))).build()

    /**
     * Processes [authToken] and maps its claims to an [AuthUser].
     *
     * Reads the `preferred_username`, `tenant`, `plantIds` and `realm_access.roles` claims. A missing or
     * differently-typed claim fails with a cast exception. As a best-effort side effect, the listed plants
     * are synced from the RMC integration service (see [syncPlants]).
     *
     * @param skipValidate when true, claim validation (expiry, audience, issuer, subject) is skipped.
     *   [refreshToken] uses this to accept an already-expired access token.
     * @return the user, with `tenant` possibly suffixed by [getTenantWithCompany].
     * @throws org.jose4j.jwt.consumer.InvalidJwtException if jose4j rejects the token.
     */
    fun validateAuthToken(authToken: String, skipValidate: Boolean = false): AuthUser {
        val consumer = if (skipValidate) jwtConsumerSkipValidate else jwtConsumer
        val claims = consumer.process(authToken).jwtClaims
        val claimsMap = claims.claimsMap
        val userId = claimsMap["preferred_username"] as String
        val tenant = claimsMap["tenant"] as String
        @Suppress("UNCHECKED_CAST")
        val plantIds = claimsMap["plantIds"] as List<String>
        @Suppress("UNCHECKED_CAST")
        val roles = (claimsMap["realm_access"] as Map<String, Any>)["roles"] as List<String>
        // Read before the plant sync, as the original did, so a token without `exp` fails before any I/O.
        val expiry = LocalDateTime.ofInstant(Date(claims.expirationTime.valueInMillis).toInstant(), ZoneId.systemDefault())
        syncPlants(plantIds)
        return AuthUser(
            userName = userId,
            tenant = getTenantWithCompany(userId, tenant),
            roles = roles,
            token = authToken,
            expiry = expiry,
            plantIds = plantIds,
        )
    }

    /**
     * Best-effort copy of plant names from `GET {integrationRmc}/plant?id=…` into the local `Plant` table.
     *
     * Only the id and name are synced. A non-200 response leaves that plant untouched. Any exception is
     * logged and swallowed, so a sync failure never blocks authentication (but it does stop the remaining
     * plants from syncing on this call).
     */
    private fun syncPlants(plantIds: List<String>) {
        try {
            HttpClients.createDefault().use { http ->
                for (plantId in plantIds) {
                    val existing = Session.database.find(Plant::class.java).where().eq("plantId", plantId).findOne()
                    http.execute(HttpGet("${appConfig.integrationRmc()}/plant?id=${encode(plantId)}")).use { r ->
                        if (r.statusLine.statusCode == 200) {
                            val fetchedName = EntityUtils.toString(r.entity)
                            if (existing == null) {
                                Session.database.save(Plant().apply {
                                    this.plantId = plantId
                                    this.plantName = fetchedName
                                })
                            } else {
                                existing.apply {
                                    this.plantName = fetchedName
                                    this.save()
                                }
                            }
                        }
                    }
                }
            }
        } catch (e: Exception) {
            logger.warn("Exception in syncing plants", e)
        }
    }

    /**
     * Returns [tenant] with the company name from `GET {integrationRmc}/tenant?id={userId}` appended.
     *
     * Any completed lookup is cached per user: a 200 caches the suffixed value, while any other status
     * caches the bare tenant. A lookup that throws is not cached, so the next call retries; the bare tenant
     * is returned meanwhile.
     *
     * The original body discarded the HTTP result (the suffixed string was the value of an `if` without
     * `else`), so it always returned the bare tenant.
     * NOTE(review): enabling the suffix changes tenant values for users whose company lookup succeeds.
     * Confirm that existing tenant-scoped data is compatible.
     */
    private fun getTenantWithCompany(userId: String, tenant: String): String {
        userToTenant[userId]?.let { return it }
        // The HTTP call happens outside computeIfAbsent so it never blocks other users' map updates.
        val company = fetchCompany(userId) ?: return tenant
        return userToTenant.computeIfAbsent(userId) { "$tenant$company" }
    }

    /**
     * Looks up the company for [userId] from the RMC integration service.
     *
     * @return the response body on 200, `""` on any other status, or `null` if the call threw.
     */
    private fun fetchCompany(userId: String): String? =
        try {
            HttpClients.createDefault().use { http ->
                http.execute(HttpGet("${appConfig.integrationRmc()}/tenant?id=${encode(userId)}")).use { r ->
                    if (r.statusLine.statusCode == 200) EntityUtils.toString(r.entity) else ""
                }
            }
        } catch (e: Exception) {
            logger.warn("Exception in resolving company for user $userId", e)
            null
        }

    /** Responds with the value of `Session.jwk()` as JSON. */
    fun keys(ctx: Context) {
        ctx.json(Session.jwk())
    }

    /** Responds with the cached OpenID discovery document as JSON. */
    fun endPoint(ctx: Context) {
        ctx.json(getAuthEndpoint())
    }

    /**
     * Starts the authorization-code flow by redirecting to the provider's authorization endpoint.
     *
     * The query values are URL-encoded, so a redirect URI containing `?`/`&` survives the round trip.
     * NOTE(review): `state` is a fixed literal, so it gives no CSRF protection. Consider a per-request
     * random value that [code] verifies.
     */
    fun init(ctx: Context) {
        val query = getFormDataAsString(
            mapOf(
                "response_type" to "code",
                "client_id" to appConfig.iamClient(),
                "redirect_uri" to appConfig.iamClientRedirectUri(),
                "scope" to "profile",
                "state" to "1234zyx",
            )
        )
        ctx.redirect("${getAuthEndpoint().authorizationEndpoint}?$query")
    }

    /**
     * Exchanges the `code` query parameter for tokens, validates the access token and remembers the
     * response for later refresh. Responds with the access token as plain text.
     *
     * Optional query parameters `redirectUrl` and `client` override the configured defaults.
     * NOTE(review): [refreshToken] reads `redirectUri`, not `redirectUrl`. The parameter names differ.
     *
     * @throws BadRequestResponse if `code` is missing.
     * @throws UnauthorizedResponse if the token endpoint rejects the exchange.
     */
    fun code(ctx: Context) {
        val code = ctx.queryParam("code") ?: throw BadRequestResponse("not proper")
        val redirectUri = ctx.queryParam("redirectUrl") ?: appConfig.iamClientRedirectUri()
        val iamClient = ctx.queryParam("client") ?: appConfig.iamClient()
        val atResponse = requestToken(
            mapOf(
                "code" to code,
                "redirect_uri" to redirectUri,
                "client_id" to iamClient,
                "grant_type" to "authorization_code",
            )
        )
        val parsed = validateAuthToken(atResponse.accessToken)
        // Keep track of this response so the client can ask for a renewal later.
        rememberToken(parsed.userName, atResponse)
        ctx.result(atResponse.accessToken).contentType(ContentType.TEXT_PLAIN)
    }

    /**
     * Renews the access token sent in the `Authorization` header.
     *
     * Requires the `client` and `redirectUri` query parameters. The header token's entry is looked up in
     * Redis, and then:
     * - if the access token has not expired yet, it is returned unchanged;
     * - otherwise, if the refresh window (`createdAt + refreshExpiresIn`) is still open, a refresh-token
     *   grant is performed and the new access token is returned and remembered;
     * - otherwise [UnauthorizedResponse] is thrown so the client logs in again.
     *
     * @throws UnauthorizedResponse if the header is missing, the refresh window has passed, or the provider
     *   rejects the refresh.
     * @throws BadRequestResponse if a query parameter is missing or the token is not in the cache.
     */
    fun refreshToken(ctx: Context) {
        val authToken = ctx.header("Authorization")?.replace("Bearer ", "")?.replace("Bearer: ", "")?.trim()
            ?: throw UnauthorizedResponse()
        val authUser = validateAuthToken(authToken, skipValidate = true)
        val client = ctx.queryParam("client") ?: throw BadRequestResponse("client not sent")
        val redirectUri = ctx.queryParam("redirectUri") ?: throw BadRequestResponse("redirectUri not sent")
        val key = tokenKey(authUser.userName)
        val found = Session.redis.llen(key)
        logger.warn("for user ${authUser.userName}, found from redis, $key => $found entries")
        // Valid list indices are 0 until llen. lpush puts the newest entry at index 0, so a lazy scan
        // usually stops after the first few lookups instead of fetching the whole list.
        val foundOldAt = (0 until found).asSequence()
            .mapNotNull { Session.redis.lindex(key, it) }
            .map { objectMapper.readValue<AuthTokenResponse>(it) }
            .firstOrNull { it.accessToken == authToken }
            ?: throw BadRequestResponse("authToken not found in cache")
        val createdAt = foundOldAt.createdAt ?: throw BadRequestResponse("created at is missing")
        val expiresAt = createdAt.plusSeconds(foundOldAt.expiresIn.toLong())
        val rtExpiresAt = createdAt.plusSeconds(foundOldAt.refreshExpiresIn.toLong())
        val now = LocalDateTime.now()
        logger.warn("can we refresh the token for ${authUser.userName}, created = $createdAt expires = $expiresAt, refresh Till = $rtExpiresAt")
        when {
            // The access token is still valid: hand back the cached one.
            expiresAt.isAfter(now) -> {
                logger.warn("Still valid, the token for ${authUser.userName}, will expire at $expiresAt")
                ctx.result(foundOldAt.accessToken).contentType(ContentType.TEXT_PLAIN)
            }
            // The access token has expired (including exactly at expiresAt) but the refresh window is open.
            now.isBefore(rtExpiresAt) -> {
                logger.warn("We can refresh the token for ${authUser.userName}, expires = $expiresAt, refresh Till = $rtExpiresAt")
                val atResponse = requestToken(
                    mapOf(
                        "refresh_token" to foundOldAt.refreshToken,
                        "redirect_uri" to redirectUri,
                        "client_id" to client,
                        "grant_type" to "refresh_token",
                    )
                )
                val parsed = validateAuthToken(atResponse.accessToken)
                // Store under the same key this method searches. The original used "AUTH_TOKEN_<user>",
                // so a refreshed token could never be refreshed again.
                rememberToken(parsed.userName, atResponse)
                ctx.result(atResponse.accessToken).contentType(ContentType.TEXT_PLAIN)
            }
            // The refresh window has passed as well, so the user must log in again.
            else -> {
                logger.warn("We can't refresh the token for ${authUser.userName}, as refresh-time [$rtExpiresAt] is expired")
                throw UnauthorizedResponse()
            }
        }
    }

    /**
     * POSTs [form] (url-encoded) to the provider's token endpoint and parses the JSON reply.
     *
     * @throws UnauthorizedResponse if the endpoint answers with a non-2xx status (e.g. an invalid or
     *   expired code / refresh token). The provider's error body is logged, not returned.
     */
    private fun requestToken(form: Map<String, String>): AuthTokenResponse {
        val req = HttpRequest.newBuilder()
            .uri(URI.create(getAuthEndpoint().tokenEndpoint))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .POST(HttpRequest.BodyPublishers.ofString(getFormDataAsString(form)))
            .build()
        val response = httpClient.send(req, HttpResponse.BodyHandlers.ofString())
        if (response.statusCode() !in 200..299) {
            logger.warn("token endpoint returned ${response.statusCode()}: ${response.body()}")
            throw UnauthorizedResponse("token request rejected by identity provider")
        }
        return objectMapper.readValue<AuthTokenResponse>(response.body())
    }

    /**
     * Pushes [tokens], stamped with the current time, onto the head of the user's Redis list.
     * NOTE(review): the list is never trimmed, so it grows with every login and refresh.
     */
    private fun rememberToken(userName: String, tokens: AuthTokenResponse) {
        Session.redis.lpush(
            tokenKey(userName),
            objectMapper.writeValueAsString(tokens.copy(createdAt = LocalDateTime.now())),
        )
    }

    /** Redis key of the token list for [userName]. */
    private fun tokenKey(userName: String): String = "$AUTH_TOKEN$userName"

    /** UTF-8 URL-encodes a single query-string value. */
    private fun encode(value: String): String = URLEncoder.encode(value, StandardCharsets.UTF_8)
}
/**
 * The authenticated caller, as built by `Auth.validateAuthToken` from a processed access token.
 *
 * @property userName value of the token's `preferred_username` claim.
 * @property tenant the `tenant` claim, possibly with a company suffix resolved by `Auth`.
 * @property roles the `realm_access.roles` claim.
 * @property token the raw access token the user was derived from.
 * @property expiry the token's `exp`, converted to a local date-time in the system default zone.
 * @property plantIds the `plantIds` claim.
 */
data class AuthUser(
    val userName: String,
    val tenant: String,
    val roles: List<String>,
    val token: String,
    val expiry: LocalDateTime,
    val plantIds: List<String>,
)
/**
 * Permission verbs that a [Role.Standard] route role can carry.
 * How they are enforced is decided by whatever consumes the route roles (not shown in this file).
 */
enum class Action {
    CREATE,
    VIEW,
    UPDATE,
    DELETE,
    APPROVE,
    ADMIN,
}
/**
 * Closed set of authorization requirement kinds a route can declare (grouped via [Roles]).
 * This file only declares the shapes; their interpretation happens elsewhere.
 */
sealed class Role {
    /** Requirement expressed as one or more [Action] verbs. */
    open class Standard(vararg val action: Action) : Role()

    /** Marker requirement with no parameters, named `Entity`. */
    data object Entity : Role()

    /** Requirement listing explicit role names as strings. */
    open class Explicit(vararg val roles: String) : Role()

    /** Marker requirement with no parameters, named `DbOps`. */
    data object DbOps : Role()
}
/**
 * Javalin [RouteRole] bundling any number of [Role] requirements for a single route.
 * It is `open`, so routes may declare named subclasses.
 */
open class Roles(
    vararg val roles: Role,
) : RouteRole