Skip to content

Commit 8d4f016

Browse files
committed
add support for authentication through Azure MSI
1 parent 89297b1 commit 8d4f016

File tree

5 files changed

+582
-0
lines changed

5 files changed

+582
-0
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Release v0.104.0
44

55
### New Features and Improvements
6+
* Add support for authentication through Azure Managed Service Identity (MSI) via the new `azure-msi` credential provider.
67
* Added automatic detection of AI coding agents (Antigravity, Claude Code, Cline, Codex, Copilot CLI, Cursor, Gemini CLI, OpenCode) in the user-agent string. The SDK now appends `agent/<name>` to HTTP request headers when running inside a known AI agent environment.
78

89
### Bug Fixes
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package com.databricks.sdk.core;
2+
3+
import com.databricks.sdk.core.oauth.AzureMsiTokenSource;
4+
import com.databricks.sdk.core.oauth.CachedTokenSource;
5+
import com.databricks.sdk.core.oauth.OAuthHeaderFactory;
6+
import com.databricks.sdk.core.oauth.Token;
7+
import com.databricks.sdk.core.utils.AzureUtils;
8+
import com.databricks.sdk.support.InternalApi;
9+
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import java.util.HashMap;
11+
import java.util.Map;
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
15+
/**
16+
* Adds refreshed Azure Active Directory (AAD) tokens obtained via Azure Managed Service Identity
17+
* (MSI) to every request. This provider authenticates using the Azure Instance Metadata Service
18+
* (IMDS) endpoint, which is available on Azure VMs and other compute resources with managed
19+
* identities enabled.
20+
*/
21+
@InternalApi
22+
public class AzureMsiCredentialsProvider implements CredentialsProvider {
23+
private static final Logger LOG = LoggerFactory.getLogger(AzureMsiCredentialsProvider.class);
24+
private final ObjectMapper mapper = new ObjectMapper();
25+
26+
@Override
27+
public String authType() {
28+
return "azure-msi";
29+
}
30+
31+
@Override
32+
public OAuthHeaderFactory configure(DatabricksConfig config) {
33+
if (!config.isAzure()) {
34+
return null;
35+
}
36+
37+
if (!isAzureUseMsi(config)) {
38+
return null;
39+
}
40+
41+
if (config.getAzureWorkspaceResourceId() == null && config.getHost() == null) {
42+
return null;
43+
}
44+
45+
LOG.debug("Generating AAD token via Azure MSI");
46+
47+
AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);
48+
49+
CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
50+
CachedTokenSource cloud =
51+
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
52+
53+
return OAuthHeaderFactory.fromSuppliers(
54+
inner::getToken,
55+
() -> {
56+
Token token = inner.getToken();
57+
Map<String, String> headers = new HashMap<>();
58+
headers.put("Authorization", "Bearer " + token.getAccessToken());
59+
AzureUtils.addSpManagementToken(cloud, headers);
60+
AzureUtils.addWorkspaceResourceId(config, headers);
61+
return headers;
62+
});
63+
}
64+
65+
/**
66+
* Null-safe check for the azureUseMsi config flag. The underlying field is a boxed Boolean, but
67+
* the getter auto-unboxes to primitive boolean, which would NPE when the field is null. This
68+
* helper treats null as false.
69+
*/
70+
private static boolean isAzureUseMsi(DatabricksConfig config) {
71+
try {
72+
return config.getAzureUseMsi();
73+
} catch (NullPointerException e) {
74+
return false;
75+
}
76+
}
77+
78+
/**
79+
* Creates a CachedTokenSource for the specified Azure resource using MSI authentication.
80+
*
81+
* @param config The DatabricksConfig instance containing the required authentication parameters.
82+
* @param resource The Azure resource for which OAuth tokens need to be fetched.
83+
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
84+
* resource.
85+
*/
86+
CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
87+
AzureMsiTokenSource tokenSource =
88+
new AzureMsiTokenSource(config.getHttpClient(), resource, config.getAzureClientId());
89+
return new CachedTokenSource.Builder(tokenSource)
90+
.setAsyncDisabled(config.getDisableAsyncTokenRefresh())
91+
.build();
92+
}
93+
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ private synchronized void addDefaultCredentialsProviders(DatabricksConfig config
179179
addOIDCCredentialsProviders(config);
180180

181181
providers.add(new AzureGithubOidcCredentialsProvider());
182+
providers.add(new AzureMsiCredentialsProvider());
182183
providers.add(new AzureServicePrincipalCredentialsProvider());
183184
providers.add(new AzureCliCredentialsProvider());
184185
providers.add(new ExternalBrowserCredentialsProvider());
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksException;
4+
import com.databricks.sdk.core.http.HttpClient;
5+
import com.databricks.sdk.core.http.Request;
6+
import com.databricks.sdk.core.http.Response;
7+
import com.databricks.sdk.support.InternalApi;
8+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
9+
import com.fasterxml.jackson.annotation.JsonProperty;
10+
import com.fasterxml.jackson.databind.ObjectMapper;
11+
import java.io.IOException;
12+
import java.time.Instant;
13+
14+
/**
15+
* A {@link TokenSource} that fetches OAuth tokens from the Azure Instance Metadata Service (IMDS)
16+
* endpoint for Managed Service Identity (MSI) authentication.
17+
*
18+
* <p>This token source makes HTTP GET requests to the well-known IMDS endpoint at {@code
19+
* http://169.254.169.254/metadata/identity/oauth2/token} to obtain access tokens for the specified
20+
* Azure resource.
21+
*/
22+
@InternalApi
23+
public class AzureMsiTokenSource implements TokenSource {
24+
25+
private static final String IMDS_ENDPOINT =
26+
"http://169.254.169.254/metadata/identity/oauth2/token";
27+
28+
private final HttpClient httpClient;
29+
private final String resource;
30+
private final String clientId;
31+
private final ObjectMapper mapper = new ObjectMapper();
32+
33+
/** Response structure from the Azure IMDS token endpoint. */
34+
@JsonIgnoreProperties(ignoreUnknown = true)
35+
static class MsiTokenResponse {
36+
@JsonProperty("token_type")
37+
private String tokenType;
38+
39+
@JsonProperty("access_token")
40+
private String accessToken;
41+
42+
@JsonProperty("expires_on")
43+
private String expiresOn;
44+
45+
Token toToken() {
46+
if (accessToken == null || accessToken.isEmpty()) {
47+
throw new DatabricksException("MSI token response missing or empty 'access_token' field");
48+
}
49+
if (tokenType == null || tokenType.isEmpty()) {
50+
throw new DatabricksException("MSI token response missing or empty 'token_type' field");
51+
}
52+
if (expiresOn == null || expiresOn.isEmpty()) {
53+
throw new DatabricksException("MSI token response missing 'expires_on' field");
54+
}
55+
long epoch;
56+
try {
57+
epoch = Long.parseLong(expiresOn);
58+
} catch (NumberFormatException e) {
59+
throw new DatabricksException(
60+
"Invalid 'expires_on' value in MSI token response: " + expiresOn, e);
61+
}
62+
return new Token(accessToken, tokenType, Instant.ofEpochSecond(epoch));
63+
}
64+
}
65+
66+
/**
67+
* Creates a new AzureMsiTokenSource.
68+
*
69+
* @param httpClient The HTTP client to use for requests to the IMDS endpoint.
70+
* @param resource The Azure resource for which to obtain an access token.
71+
* @param clientId The client ID of the managed identity to use. May be null for system-assigned
72+
* identities.
73+
*/
74+
public AzureMsiTokenSource(HttpClient httpClient, String resource, String clientId) {
75+
this.httpClient = httpClient;
76+
this.resource = resource;
77+
this.clientId = clientId;
78+
}
79+
80+
@Override
81+
public Token getToken() {
82+
Request req = new Request("GET", IMDS_ENDPOINT);
83+
req.withQueryParam("api-version", "2018-02-01");
84+
req.withQueryParam("resource", resource);
85+
if (clientId != null && !clientId.isEmpty()) {
86+
req.withQueryParam("client_id", clientId);
87+
}
88+
req.withHeader("Metadata", "true");
89+
90+
Response resp;
91+
try {
92+
resp = httpClient.execute(req);
93+
} catch (IOException e) {
94+
throw new DatabricksException(
95+
"Failed to request MSI token from IMDS endpoint: " + e.getMessage(), e);
96+
}
97+
98+
if (resp.getStatusCode() != 200) {
99+
throw new DatabricksException(
100+
"Failed to request MSI token: status code "
101+
+ resp.getStatusCode()
102+
+ ", response body: "
103+
+ resp.getDebugBody());
104+
}
105+
106+
try {
107+
MsiTokenResponse msiToken = mapper.readValue(resp.getBody(), MsiTokenResponse.class);
108+
return msiToken.toToken();
109+
} catch (IOException e) {
110+
throw new DatabricksException("Failed to parse MSI token response: " + e.getMessage(), e);
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)