1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
| import kotlinx.datetime.Clock
import io.ktor.client.statement.bodyAsText
import io.ktor.http.HttpMethod
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlinx.serialization.serializer
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import java.security.KeyFactory
import java.security.interfaces.RSAPrivateKey
import java.security.spec.PKCS8EncodedKeySpec
import java.util.Base64
/**
 * Obtains OAuth access tokens through the JWT bearer grant
 * (`urn:ietf:params:oauth:grant-type:jwt-bearer`).
 *
 * A short-lived JWT is signed locally with the application's RSA private key. It is then
 * exchanged at `/api/permission/oauth2/token` for an access token.
 */
object JWTService {
    /** PEM encodings of an RSA private key that this service can read. */
    private enum class KeyFormat { PKCS1, PKCS8 }

    /** Lifetime of the locally signed JWT assertion, in seconds (1 hour). */
    private const val JWT_TTL_SECONDS = 3600L

    /** Access-token lifetime requested when the config does not specify one, in seconds. */
    private const val DEFAULT_DURATION_SECONDS = 900

    /** Parser for the token response; tolerates fields that [JWTToken] does not model. */
    private val responseJson = Json { ignoreUnknownKeys = true }

    /** Source of randomness for the `jti` claim. */
    private val secureRandom = java.security.SecureRandom()

    /**
     * DER prefix of a PKCS#8 `PrivateKeyInfo` for RSA: `version` INTEGER 0, followed by the
     * `AlgorithmIdentifier` SEQUENCE { rsaEncryption OID 1.2.840.113549.1.1.1, NULL }.
     */
    private val pkcs8RsaPrefix = byteArrayOf(
        0x02, 0x01, 0x00,
        0x30, 0x0D,
        0x06, 0x09, 0x2A, 0x86.toByte(), 0x48, 0x86.toByte(), 0xF7.toByte(), 0x0D, 0x01, 0x01, 0x01,
        0x05, 0x00
    )

    /**
     * Signs a JWT for [config] and exchanges it for an OAuth access token.
     *
     * @param config application credentials and token-request parameters.
     * @param options optional per-request settings forwarded to the HTTP client.
     * @return the token parsed from the endpoint's response.
     * @throws IllegalArgumentException if the private key has neither a PKCS#1
     *   (`BEGIN RSA PRIVATE KEY`) nor a PKCS#8 (`BEGIN PRIVATE KEY`) PEM header, or if the
     *   algorithm is not RS256, RS384 or RS512.
     */
    suspend fun getJWTToken(
        config: JWTTokenConfig,
        options: RequestOptions? = null
    ): JWTToken {
        val trimmedPrivateKey = config.privateKey.trim()
        val keyFormat = detectKeyFormat(trimmedPrivateKey)
            ?: throw IllegalArgumentException(
                "Invalid private key format. Expected PEM format (RSA or PKCS8)"
            )

        // Build the JWT claims directly; every value is a String or Long.
        val now = Clock.System.now().epochSeconds
        val claims = buildMap<String, Any> {
            put("iss", config.appId)
            put("aud", config.aud)
            put("iat", now)
            put("exp", now + JWT_TTL_SECONDS)
            // RFC 7519 §4.1.7 requires a negligible chance of jti reuse. The old value (the
            // timestamp in hex) collided for tokens issued within the same second.
            put("jti", randomJti())
            config.sessionName?.let { put("session_name", it) }
        }

        val token = sign(
            payload = claims,
            privateKey = trimmedPrivateKey,
            keyFormat = keyFormat,
            algorithm = config.algorithm ?: "RS256",
            keyid = config.keyId
        )

        // Exchange the signed JWT for an OAuth access token.
        return doGetJWTToken(
            token = token,
            baseURL = config.baseURL,
            durationSeconds = config.durationSeconds ?: DEFAULT_DURATION_SECONDS,
            scope = config.scope,
            options = options
        )
    }

    /**
     * Calls the OAuth token endpoint with the signed JWT.
     *
     * @param token the signed JWT. It is passed both to the [APIClient] constructor and to `request`.
     * @param baseURL API base URL override, or `null` to let [APIClient] choose.
     * @param durationSeconds requested access-token lifetime, sent as `duration_seconds`.
     * @param scope optional scope, sent unchanged when non-null.
     * @param options optional per-request settings forwarded to the HTTP client.
     */
    private suspend fun doGetJWTToken(
        token: String,
        baseURL: String?,
        durationSeconds: Int,
        scope: String?,
        options: RequestOptions? = null
    ): JWTToken {
        val api = APIClient(token = token, baseURL = baseURL)
        val payload = buildJsonObject {
            put("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
            put("duration_seconds", durationSeconds)
            scope?.let { put("scope", it) }
        }
        val response = api.request(HttpMethod.Post, "/api/permission/oauth2/token", token, payload, options)
        return responseJson.decodeFromString(serializer<JWTToken>(), response.bodyAsText())
    }

    /** Returns the key encoding named by [pem]'s header, or `null` when neither header is present. */
    private fun detectKeyFormat(pem: String): KeyFormat? = when {
        pem.contains("BEGIN RSA PRIVATE KEY") -> KeyFormat.PKCS1
        pem.contains("BEGIN PRIVATE KEY") -> KeyFormat.PKCS8
        else -> null
    }

    /** Returns 32 random bytes, hex-encoded, for use as a `jti` claim. */
    private fun randomJti(): String {
        val bytes = ByteArray(32)
        secureRandom.nextBytes(bytes)
        return bytes.joinToString("") { "%02x".format(it) }
    }

    /**
     * Signs [payload] as a JWT with an RSA private key.
     *
     * @param payload claims to embed; values must be String, Int, Long, Double or Boolean.
     * @param privateKey PEM text of the RSA private key.
     * @param keyFormat encoding named by the PEM header.
     * @param algorithm JWS algorithm name: RS256, RS384 or RS512 (case-insensitive).
     * @param keyid value of the `kid` JWT header.
     * @return the compact serialized JWT.
     * @throws IllegalArgumentException for an unsupported algorithm or claim value type.
     */
    private fun sign(
        payload: Map<String, Any>,
        privateKey: String,
        keyFormat: KeyFormat,
        algorithm: String,
        keyid: String
    ): String {
        val rsaKey = parsePrivateKey(privateKey, keyFormat)
        // Honour the configured algorithm; an earlier version silently forced RS256.
        val alg = when (algorithm.uppercase()) {
            "RS256" -> Algorithm.RSA256(null, rsaKey)
            "RS384" -> Algorithm.RSA384(null, rsaKey)
            "RS512" -> Algorithm.RSA512(null, rsaKey)
            else -> throw IllegalArgumentException(
                "Unsupported JWT algorithm: $algorithm (expected RS256, RS384 or RS512)"
            )
        }
        val builder = JWT.create().withKeyId(keyid)
        payload.forEach { (key, value) ->
            when (value) {
                is String -> builder.withClaim(key, value)
                is Int -> builder.withClaim(key, value)
                is Long -> builder.withClaim(key, value)
                is Double -> builder.withClaim(key, value)
                is Boolean -> builder.withClaim(key, value)
                // Fail loudly rather than silently dropping a claim.
                else -> throw IllegalArgumentException(
                    "Unsupported claim type for '$key': ${value.javaClass.name}"
                )
            }
        }
        return builder.sign(alg)
    }

    /**
     * Decodes a PEM RSA private key.
     *
     * PKCS#1 keys are first wrapped into a PKCS#8 structure, because `PKCS8EncodedKeySpec`
     * cannot read a bare PKCS#1 `RSAPrivateKey`.
     */
    private fun parsePrivateKey(pem: String, keyFormat: KeyFormat): RSAPrivateKey {
        val base64Body = pem
            .replace("-----BEGIN PRIVATE KEY-----", "")
            .replace("-----END PRIVATE KEY-----", "")
            .replace("-----BEGIN RSA PRIVATE KEY-----", "")
            .replace("-----END RSA PRIVATE KEY-----", "")
            // Remove every whitespace character, not only '\n'. CRLF line endings or
            // indentation would otherwise make the strict Base64 decoder throw.
            .filterNot { it.isWhitespace() }
        val derBytes = Base64.getDecoder().decode(base64Body) // requires API level 26
        val pkcs8Bytes = when (keyFormat) {
            KeyFormat.PKCS8 -> derBytes
            KeyFormat.PKCS1 -> wrapPkcs1InPkcs8(derBytes)
        }
        val keyFactory = KeyFactory.getInstance("RSA")
        return keyFactory.generatePrivate(PKCS8EncodedKeySpec(pkcs8Bytes)) as RSAPrivateKey
    }

    /** Wraps a DER PKCS#1 `RSAPrivateKey` in a PKCS#8 `PrivateKeyInfo` SEQUENCE. */
    private fun wrapPkcs1InPkcs8(pkcs1: ByteArray): ByteArray {
        val octetString = byteArrayOf(0x04) + derLength(pkcs1.size) + pkcs1
        val body = pkcs8RsaPrefix + octetString
        return byteArrayOf(0x30) + derLength(body.size) + body
    }

    /** Encodes [length] as a DER definite length: short form below 128, long form otherwise. */
    private fun derLength(length: Int): ByteArray {
        if (length < 0x80) return byteArrayOf(length.toByte())
        var remaining = length
        var bytes = ByteArray(0)
        while (remaining > 0) {
            bytes = byteArrayOf((remaining and 0xFF).toByte()) + bytes
            remaining = remaining ushr 8
        }
        return byteArrayOf((0x80 or bytes.size).toByte()) + bytes
    }
}
/**
 * Token payload decoded from the OAuth token endpoint's JSON response.
 *
 * @property accessToken value of the `access_token` field.
 * @property tokenType value of the `token_type` field.
 * @property expiresIn raw `expires_in` value. NOTE(review): confirm against the API docs
 *   whether this is a lifetime in seconds or an absolute timestamp.
 */
@Serializable
data class JWTToken(
    @SerialName("access_token") val accessToken: String,
    @SerialName("token_type") val tokenType: String,
    @SerialName("expires_in") val expiresIn: Long
)
/**
 * Parameters for requesting an access token via `JWTService.getJWTToken`.
 *
 * @property appId OAuth application id; sent as the JWT `iss` claim.
 * @property privateKey PEM-encoded RSA private key, with a `BEGIN RSA PRIVATE KEY` or
 *   `BEGIN PRIVATE KEY` header.
 * @property aud JWT `aud` claim. Replace the default with your own audience if needed.
 * @property algorithm JWS signing algorithm name; `null` falls back to RS256.
 * @property keyId public key id, sent as the JWT `kid` header.
 * @property sessionName optional `session_name` claim.
 * @property baseURL API base URL override; `null` uses the client default.
 * @property durationSeconds requested access-token lifetime; `null` falls back to 900.
 * @property scope optional scope for the token request.
 *
 * [toString] masks [privateKey], so logging a config does not leak the key.
 */
@Serializable
data class JWTTokenConfig(
    val appId: String,
    val privateKey: String,
    val aud: String = "api.coze.com", // replace with your own audience if needed
    val algorithm: String? = "RS256",
    val keyId: String,
    val sessionName: String? = null,
    val baseURL: String? = null,
    val durationSeconds: Int? = 900,
    val scope: String? = null
) {
    /**
     * Same shape as the generated data-class `toString`, except that [privateKey] is redacted.
     * Without this override, printing or logging a config would expose the private key.
     */
    override fun toString(): String =
        "JWTTokenConfig(appId=$appId, privateKey=<redacted>, aud=$aud, algorithm=$algorithm, " +
            "keyId=$keyId, sessionName=$sessionName, baseURL=$baseURL, " +
            "durationSeconds=$durationSeconds, scope=$scope)"
}
|