Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from collections import deque
from enum import Enum
from typing import (
Expand Down Expand Up @@ -435,7 +436,16 @@ def _create_session(self) -> Session:
elif ssl_client_cert := ssl_client.get(CERT):
session.cert = ssl_client_cert

if auth_config := self.properties.get(AUTH):
if raw_auth := self.properties.get(AUTH):
# When auth is configured via an environment variable (e.g. PYICEBERG_CATALOG__<NAME>__AUTH),
# the value arrives as a JSON string rather than a dict. Decode it before processing.
if isinstance(raw_auth, str):
try:
auth_config: dict[str, Any] = json.loads(raw_auth)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse auth configuration as JSON: {raw_auth!r}") from e
else:
auth_config = raw_auth
auth_type = auth_config.get("type")
if auth_type is None:
raise ValueError("auth.type must be defined")
Expand All @@ -448,6 +458,27 @@ def _create_session(self) -> Session:
if auth_type != CUSTOM and auth_impl:
raise ValueError("auth.impl can only be specified when using custom auth.type")

self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
session.auth = AuthManagerAdapter(self._auth_manager)
elif auth_type := self.properties.get(f"{AUTH}.type"):
# Support flattened env-var style configuration:
# PYICEBERG_CATALOG__<NAME>__AUTH__TYPE=oauth2
# PYICEBERG_CATALOG__<NAME>__AUTH__OAUTH2__CLIENT_ID=id
# The env-var parser maps these to flat properties like "auth.type" and "auth.oauth2.client-id".
# Key names are converted from kebab-case to snake_case to match AuthManager constructor parameters.
auth_impl = self.properties.get(f"{AUTH}.impl")

if auth_type == CUSTOM and not auth_impl:
raise ValueError("auth.impl must be specified when using custom auth.type")

if auth_type != CUSTOM and auth_impl:
raise ValueError("auth.impl can only be specified when using custom auth.type")

type_prefix = f"{AUTH}.{auth_type}."
auth_type_config = {
k[len(type_prefix) :].replace("-", "_"): v for k, v in self.properties.items() if k.startswith(type_prefix)
}

self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
session.auth = AuthManagerAdapter(self._auth_manager)
else:
Expand Down
132 changes: 132 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3167,3 +3167,135 @@ def test_load_table_without_storage_credentials(
)
assert actual.metadata.model_dump() == expected.metadata.model_dump()
assert actual == expected


# Tests for issue #3422: REST catalog auth cannot be configured via environment
# variables unless auth JSON strings are decoded.


def test_rest_catalog_with_basic_auth_as_json_string(rest_mock: Mocker) -> None:
"""When auth arrives as a JSON string (e.g. from an environment variable), it should be decoded correctly."""
import json

rest_mock.get(
f"{TEST_URI}v1/config",
json={"defaults": {}, "overrides": {}},
status_code=200,
)
auth_dict = {
"type": "basic",
"basic": {
"username": "one",
"password": "two",
},
}
catalog_properties = {
"uri": TEST_URI,
"auth": json.dumps(auth_dict),
}
catalog = RestCatalog("rest", **catalog_properties)
assert catalog.uri == TEST_URI

encoded_user_pass = base64.b64encode(b"one:two").decode()
expected_auth_header = f"Basic {encoded_user_pass}"
assert rest_mock.last_request.headers["Authorization"] == expected_auth_header


def test_rest_catalog_with_oauth2_auth_as_json_string(requests_mock: Mocker) -> None:
"""OAuth2 auth configured as a JSON string (e.g. from an environment variable) should work correctly."""
import json

requests_mock.post(
f"{TEST_URI}oauth2/token",
json={
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
"token_type": "Bearer",
"expires_in": 3600,
},
status_code=200,
)
requests_mock.get(
f"{TEST_URI}v1/config",
json={"defaults": {}, "overrides": {}},
status_code=200,
)
auth_dict = {
"type": "oauth2",
"oauth2": {
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"token_url": f"{TEST_URI}oauth2/token",
},
}
catalog_properties = {
"uri": TEST_URI,
"auth": json.dumps(auth_dict),
}
catalog = RestCatalog("rest", **catalog_properties)
assert catalog.uri == TEST_URI


def test_rest_catalog_with_invalid_json_auth_string() -> None:
"""An auth value that is a string but not valid JSON should raise a descriptive ValueError."""
with pytest.raises(ValueError, match="Failed to parse auth configuration as JSON"):
RestCatalog("rest", uri=TEST_URI, auth="not-valid-json")


def test_rest_catalog_with_basic_auth_flat_properties(rest_mock: Mocker) -> None:
"""Auth configured via flattened env-var properties (e.g. PYICEBERG_CATALOG__<NAME>__AUTH__TYPE=basic)
should initialise the correct AuthManager.

The env-var parser converts PYICEBERG_CATALOG__<NAME>__AUTH__TYPE=basic into the flat property
'auth.type' = 'basic' and PYICEBERG_CATALOG__<NAME>__AUTH__BASIC__USERNAME=one into
'auth.basic.username' = 'one'.
"""
rest_mock.get(
f"{TEST_URI}v1/config",
json={"defaults": {}, "overrides": {}},
status_code=200,
)
catalog_properties = {
"uri": TEST_URI,
# Flat properties as produced by the env-var config parser
"auth.type": "basic",
"auth.basic.username": "one",
"auth.basic.password": "two",
}
catalog = RestCatalog("rest", **catalog_properties)
assert catalog.uri == TEST_URI

encoded_user_pass = base64.b64encode(b"one:two").decode()
expected_auth_header = f"Basic {encoded_user_pass}"
assert rest_mock.last_request.headers["Authorization"] == expected_auth_header


def test_rest_catalog_with_oauth2_auth_flat_properties(requests_mock: Mocker) -> None:
"""OAuth2 auth configured via flattened env-var properties should work correctly.

PYICEBERG_CATALOG__<NAME>__AUTH__OAUTH2__CLIENT_ID maps to 'auth.oauth2.client-id'.
The dash is normalised to an underscore ('client_id') when forwarding to OAuth2AuthManager.
"""
requests_mock.post(
f"{TEST_URI}oauth2/token",
json={
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
"token_type": "Bearer",
"expires_in": 3600,
},
status_code=200,
)
requests_mock.get(
f"{TEST_URI}v1/config",
json={"defaults": {}, "overrides": {}},
status_code=200,
)
catalog_properties = {
"uri": TEST_URI,
# Flat properties as produced by the env-var config parser (note: kebab-case keys)
"auth.type": "oauth2",
"auth.oauth2.client-id": "some_client_id",
"auth.oauth2.client-secret": "some_client_secret",
"auth.oauth2.token-url": f"{TEST_URI}oauth2/token",
}
catalog = RestCatalog("rest", **catalog_properties)
assert catalog.uri == TEST_URI