Skip to content

Commit 425e7f9

Browse files
karpetrosyanstainless-app[bot]
authored andcommitted
feat(client): add support for short-lived tokens (#1185)
1 parent e3f45b2 commit 425e7f9

21 files changed

Lines changed: 2204 additions & 8 deletions

README.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,83 @@ OpenAIClient clientWithOptions = client.withOptions(optionsBuilder -> {
156156

157157
The `withOptions()` method does not affect the original client or service.
158158

159+
### Workload identity authentication
160+
161+
Workload identity authentication allows applications running in cloud environments (Kubernetes, Azure, GCP) to authenticate using short-lived tokens issued by the cloud provider, instead of long-lived API keys.
162+
163+
#### Basic setup
164+
165+
```java
166+
import com.openai.auth.*;
167+
import com.openai.client.OpenAIClient;
168+
import com.openai.client.okhttp.OpenAIOkHttpClient;
169+
170+
SubjectTokenProvider provider = K8sServiceAccountTokenProvider.builder().build();
171+
172+
WorkloadIdentity workloadIdentity = WorkloadIdentity.builder()
173+
.clientId("your-client-id")
174+
.identityProviderId("your-identity-provider-id")
175+
.serviceAccountId("your-service-account-id")
176+
.provider(provider)
177+
.build();
178+
179+
OpenAIClient client = OpenAIOkHttpClient.builder()
180+
.workloadIdentity(workloadIdentity)
181+
.build();
182+
```
183+
184+
#### Kubernetes service account token provider
185+
186+
```java
187+
// Use default token path (/var/run/secrets/kubernetes.io/serviceaccount/token)
188+
SubjectTokenProvider provider = K8sServiceAccountTokenProvider.builder().build();
189+
```
190+
191+
```java
192+
// Or specify a custom token path
193+
SubjectTokenProvider provider = K8sServiceAccountTokenProvider.builder()
194+
.tokenPath("/custom/path/to/token")
195+
.build();
196+
```
197+
198+
#### Azure Managed Identity provider
199+
200+
```java
201+
import com.openai.auth.*;
202+
203+
// Use defaults (resource: https://management.azure.com/, api-version: 2018-02-01)
204+
SubjectTokenProvider provider = AzureManagedIdentityTokenProvider.builder()
205+
.build();
206+
```
207+
208+
```java
209+
import com.openai.auth.*;
210+
211+
// Or customize
212+
SubjectTokenProvider provider = AzureManagedIdentityTokenProvider.builder()
213+
.resource("https://management.azure.com/")
214+
.apiVersion("2018-02-01")
215+
.build();
216+
```
217+
218+
#### GCP ID token provider
219+
220+
```java
221+
import com.openai.auth.*;
222+
223+
SubjectTokenProvider provider = GcpIdTokenProvider.builder()
224+
.build();
225+
```
226+
227+
```java
228+
import com.openai.auth.*;
229+
230+
// Or customize the audience
231+
SubjectTokenProvider provider = GcpIdTokenProvider.builder()
232+
.audience("https://api.openai.com/v1")
233+
.build();
234+
```
235+
159236
## Requests and responses
160237

161238
To send a request to the OpenAI API, build an instance of some `Params` class and pass it to the corresponding client method. When the response is received, it will be deserialized into an instance of a Java class.

openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClient.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package com.openai.client.okhttp
44

55
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.openai.auth.WorkloadIdentity
67
import com.openai.azure.AzureOpenAIServiceVersion
78
import com.openai.azure.AzureUrlPathMode
89
import com.openai.client.OpenAIClient
@@ -277,6 +278,14 @@ class OpenAIOkHttpClient private constructor() {
277278

278279
fun credential(credential: Credential) = apply { clientOptions.credential(credential) }
279280

281+
fun workloadIdentity(workloadIdentity: WorkloadIdentity?) = apply {
282+
clientOptions.workloadIdentity(workloadIdentity)
283+
}
284+
285+
/** Alias for calling [Builder.workloadIdentity] with `workloadIdentity.orElse(null)`. */
286+
fun workloadIdentity(workloadIdentity: Optional<WorkloadIdentity>) =
287+
workloadIdentity(workloadIdentity.getOrNull())
288+
280289
fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply {
281290
clientOptions.azureServiceVersion(azureServiceVersion)
282291
}

openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OpenAIOkHttpClientAsync.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package com.openai.client.okhttp
44

55
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.openai.auth.WorkloadIdentity
67
import com.openai.azure.AzureOpenAIServiceVersion
78
import com.openai.azure.AzureUrlPathMode
89
import com.openai.client.OpenAIClientAsync
@@ -277,6 +278,14 @@ class OpenAIOkHttpClientAsync private constructor() {
277278

278279
fun credential(credential: Credential) = apply { clientOptions.credential(credential) }
279280

281+
fun workloadIdentity(workloadIdentity: WorkloadIdentity?) = apply {
282+
clientOptions.workloadIdentity(workloadIdentity)
283+
}
284+
285+
/** Alias for calling [Builder.workloadIdentity] with `workloadIdentity.orElse(null)`. */
286+
fun workloadIdentity(workloadIdentity: Optional<WorkloadIdentity>) =
287+
workloadIdentity(workloadIdentity.getOrNull())
288+
280289
fun azureServiceVersion(azureServiceVersion: AzureOpenAIServiceVersion) = apply {
281290
clientOptions.azureServiceVersion(azureServiceVersion)
282291
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package com.openai.auth
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties
4+
import com.fasterxml.jackson.annotation.JsonProperty
5+
import com.fasterxml.jackson.databind.json.JsonMapper
6+
import com.fasterxml.jackson.module.kotlin.jacksonTypeRef
7+
import com.openai.core.http.HttpClient
8+
import com.openai.core.http.HttpMethod
9+
import com.openai.core.http.HttpRequest
10+
import com.openai.errors.SubjectTokenProviderException
11+
import java.util.concurrent.CompletableFuture
12+
import java.util.concurrent.CompletionException
13+
14+
private const val DEFAULT_AUDIENCE = "https://management.azure.com/"
15+
private const val DEFAULT_AZURE_API_VERSION = "2018-02-01"
16+
private const val AZURE_IMDS_BASE_URL = "http://169.254.169.254/metadata/identity/oauth2/token"
17+
18+
/**
19+
* A [SubjectTokenProvider] that fetches an identity token from the Azure Instance Metadata Service
20+
* (IMDS).
21+
*
22+
* It calls the local IMDS endpoint and returns the `access_token` from the response.
23+
*/
24+
class AzureManagedIdentityTokenProvider
25+
private constructor(private val resource: String, private val apiVersion: String) :
26+
SubjectTokenProvider {
27+
28+
override fun tokenType(): SubjectTokenType = SubjectTokenType.JWT
29+
30+
override fun getToken(httpClient: HttpClient, jsonMapper: JsonMapper): String {
31+
val request =
32+
HttpRequest.builder()
33+
.method(HttpMethod.GET)
34+
.baseUrl(AZURE_IMDS_BASE_URL)
35+
.putHeader("Metadata", "true")
36+
.putQueryParam("api-version", apiVersion)
37+
.putQueryParam("resource", resource)
38+
.build()
39+
40+
return try {
41+
val response = httpClient.execute(request)
42+
response.use {
43+
if (response.statusCode() != 200) {
44+
throw SubjectTokenProviderException(
45+
provider = "azure-imds",
46+
message = "IMDS returned status ${response.statusCode()}",
47+
)
48+
}
49+
50+
val result =
51+
jsonMapper.readValue(response.body(), jacksonTypeRef<AzureIMDSResponse>())
52+
53+
if (result.accessToken.isEmpty()) {
54+
throw SubjectTokenProviderException(
55+
provider = "azure-imds",
56+
message = "IMDS response missing 'access_token' field",
57+
)
58+
}
59+
60+
result.accessToken
61+
}
62+
} catch (e: SubjectTokenProviderException) {
63+
throw e
64+
} catch (e: Exception) {
65+
throw SubjectTokenProviderException(
66+
provider = "azure-imds",
67+
message = "failed to fetch token from IMDS",
68+
cause = e,
69+
)
70+
}
71+
}
72+
73+
override fun getTokenAsync(
74+
httpClient: HttpClient,
75+
jsonMapper: JsonMapper,
76+
): CompletableFuture<String> {
77+
val request =
78+
HttpRequest.builder()
79+
.method(HttpMethod.GET)
80+
.baseUrl(AZURE_IMDS_BASE_URL)
81+
.putHeader("Metadata", "true")
82+
.putQueryParam("api-version", apiVersion)
83+
.putQueryParam("resource", resource)
84+
.build()
85+
86+
return httpClient
87+
.executeAsync(request)
88+
.thenApply { response ->
89+
response.use {
90+
if (response.statusCode() != 200) {
91+
throw SubjectTokenProviderException(
92+
provider = "azure-imds",
93+
message = "IMDS returned status ${response.statusCode()}",
94+
)
95+
}
96+
97+
val result =
98+
jsonMapper.readValue(response.body(), jacksonTypeRef<AzureIMDSResponse>())
99+
100+
if (result.accessToken.isEmpty()) {
101+
throw SubjectTokenProviderException(
102+
provider = "azure-imds",
103+
message = "IMDS response missing 'access_token' field",
104+
)
105+
}
106+
107+
result.accessToken
108+
}
109+
}
110+
.exceptionally { e ->
111+
val cause = if (e is CompletionException) e.cause ?: e else e
112+
if (cause is SubjectTokenProviderException) throw cause
113+
throw SubjectTokenProviderException(
114+
provider = "azure-imds",
115+
message = "failed to fetch token from IMDS",
116+
cause = cause,
117+
)
118+
}
119+
}
120+
121+
@JsonIgnoreProperties(ignoreUnknown = true)
122+
private data class AzureIMDSResponse(@JsonProperty("access_token") val accessToken: String)
123+
124+
companion object {
125+
@JvmStatic fun builder() = Builder()
126+
}
127+
128+
class Builder internal constructor() {
129+
130+
private var resource: String = DEFAULT_AUDIENCE
131+
private var apiVersion: String = DEFAULT_AZURE_API_VERSION
132+
133+
/**
134+
* The Azure resource URI to request a token for (default: `https://management.azure.com/`).
135+
*/
136+
fun resource(resource: String) = apply { this.resource = resource }
137+
138+
/** The IMDS API version to use (default: `2018-02-01`). */
139+
fun apiVersion(apiVersion: String) = apply { this.apiVersion = apiVersion }
140+
141+
fun build(): AzureManagedIdentityTokenProvider =
142+
AzureManagedIdentityTokenProvider(resource, apiVersion)
143+
}
144+
}

0 commit comments

Comments
 (0)