Skip to content

Commit 0a7d4f8

Browse files
thodson-usgs and claude
committed
Address /simplify findings on get_ratings
Six fixes from a multi-agent review pass on PR #269: 1. Narrow the per-feature except handler from a bare Exception to (RequestException, ValueError, OSError). The previous catch-all would swallow programming bugs (KeyError on a malformed feature, AttributeError, ...) and silently drop rows. 2. Escape single quotes in CQL string literals via the standard doubling rule. Most monitoring-location IDs cannot contain a quote, but the function takes arbitrary strings — defending against malformed filters and potential injection regardless. New _quote_cql_str helper, applied to monitoring_location_id and file_type. New unit test pins the behaviour. 3. Promote file_type to a Literal["exsa", "base", "corr"] and derive _VALID_FILE_TYPES from it via typing.get_args, so the runtime guard and the type hint can never drift apart. 4. Rename the `datetime` parameter to `time` to match the convention used by every sibling waterdata getter (and to stop shadowing the stdlib `datetime` module). The parameter is still passed through as the STAC `datetime` query string under the hood; that's now documented explicitly. 5. Switch the multi-type local filter from a substring check on the asset URL to feature["properties"]["file_type"]. Substring matching on URLs would false-match if a host or path ever contained one of the literal type names; STAC features carry the typed property already. 6. Skip the on-disk RDB write when file_path is None. tempfile.mkdtemp leaks (no automatic cleanup), and df.attrs["url"] already records the source — so by default we now return only the parsed frame. Users who want a local copy can pass file_path=...; the contract for that path is unchanged. While here, dropped a "discrete analogue to the OGC waterdata getters" sentence from the module docstring (it was internal-architecture WHAT-narrating, not user-relevant guidance). All 11 ratings tests pass (one new test for the CQL quote-escaping). 
Live verification on USGS-01104475 confirms file_path=None still returns the parsed frame with df.attrs populated, multi-type via property filter returns the expected 4 tables, and the renamed `time=` parameter still drives the STAC datetime filter. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 34b3d14 commit 0a7d4f8

2 files changed

Lines changed: 69 additions & 53 deletions

File tree

dataretrieval/waterdata/ratings.py

Lines changed: 59 additions & 51 deletions
Original file line number · Diff line number · Diff line change
@@ -6,10 +6,6 @@
66
service overview at https://api.waterdata.usgs.gov/docs/stac/ and the
77
WDFN announcement at https://waterdata.usgs.gov/blog/wdfn-rating-curves/.
88
9-
This is the discrete analogue to the OGC waterdata getters; it lives in
10-
its own module because the transport layer (STAC search + RDB download)
11-
differs from the OGC collections used by the rest of the package.
12-
139
The R analogue is ``read_waterdata_ratings`` in
1410
https://github.com/DOI-USGS/dataRetrieval/.
1511
"""
@@ -18,8 +14,7 @@
1814

1915
import logging
2016
import os
21-
import tempfile
22-
from typing import Any
17+
from typing import Any, Literal, get_args
2318

2419
import pandas as pd
2520
import requests
@@ -36,7 +31,19 @@
3631
logger = logging.getLogger(__name__)
3732

3833
STAC_URL = f"{BASE_URL}/stac/v0"
39-
_VALID_FILE_TYPES: tuple[str, ...] = ("exsa", "base", "corr")
34+
35+
RATING_FILE_TYPE = Literal["exsa", "base", "corr"]
36+
_VALID_FILE_TYPES = get_args(RATING_FILE_TYPE)
37+
38+
39+
def _quote_cql_str(value: str) -> str:
40+
"""Escape a string for inclusion in a single-quoted CQL literal.
41+
42+
CQL escapes a single quote by doubling it. Most monitoring-location IDs
43+
can never contain a quote, but the function accepts arbitrary strings,
44+
so we defend against malformed filters / injection regardless.
45+
"""
46+
return value.replace("'", "''")
4047

4148

4249
def _build_filter(
@@ -47,7 +54,7 @@ def _build_filter(
4754
4855
Mirrors R's logic: only pin ``file_type`` when a single value was given,
4956
so a multi-type request returns every matching site and the file-type
50-
filtering happens client-side from the per-feature URLs.
57+
filtering happens client-side from the per-feature properties.
5158
"""
5259
parts: list[str] = []
5360
if monitoring_location_id is not None:
@@ -56,16 +63,16 @@ def _build_filter(
5663
if isinstance(monitoring_location_id, str)
5764
else list(monitoring_location_id)
5865
)
59-
joined = "', '".join(ids)
66+
joined = "', '".join(_quote_cql_str(i) for i in ids)
6067
parts.append(f"monitoring_location_id IN ('{joined}')")
6168
if file_type is not None:
62-
parts.append(f"file_type = '{file_type}'")
69+
parts.append(f"file_type = '{_quote_cql_str(file_type)}'")
6370
return " AND ".join(parts) if parts else None
6471

6572

6673
def _search(
6774
filter_str: str | None,
68-
datetime_str: str | None,
75+
time_str: str | None,
6976
bbox: list[float] | None,
7077
limit: int,
7178
ssl_check: bool,
@@ -74,8 +81,8 @@ def _search(
7481
params: dict[str, Any] = {"limit": limit}
7582
if filter_str is not None:
7683
params["filter"] = filter_str
77-
if datetime_str is not None:
78-
params["datetime"] = datetime_str
84+
if time_str is not None:
85+
params["datetime"] = time_str
7986
if bbox is not None:
8087
params["bbox"] = ",".join(str(b) for b in bbox)
8188

@@ -103,19 +110,19 @@ def _extract_rdb_comment(rdb: str) -> list[str]:
103110

104111
def _download_and_parse(
105112
feature: dict[str, Any],
106-
file_path: str,
113+
file_path: str | None,
107114
ssl_check: bool,
108115
) -> pd.DataFrame:
109-
"""Fetch the feature's data asset, write it to ``file_path``, parse RDB."""
116+
"""Fetch the feature's data asset, parse RDB, optionally persist to disk."""
110117
url = feature["assets"]["data"]["href"]
111118
fid = feature["id"]
112-
target = os.path.join(file_path, fid)
113119

114120
response = requests.get(url, headers=_default_headers(), verify=ssl_check)
115121
response.raise_for_status()
116122

117-
with open(target, "w") as f:
118-
f.write(response.text)
123+
if file_path is not None:
124+
with open(os.path.join(file_path, fid), "w") as f:
125+
f.write(response.text)
119126

120127
df = _read_rdb(response.text)
121128
df.attrs["comment"] = _extract_rdb_comment(response.text)
@@ -125,9 +132,9 @@ def _download_and_parse(
125132

126133
def get_ratings(
127134
monitoring_location_id: str | list[str] | None = None,
128-
file_type: str | list[str] = "exsa",
135+
file_type: RATING_FILE_TYPE | list[RATING_FILE_TYPE] = "exsa",
129136
file_path: str | None = None,
130-
datetime: str | list[str] | None = None,
137+
time: str | list[str] | None = None,
131138
bbox: list[float] | None = None,
132139
limit: int = 10000,
133140
download_and_parse: bool = True,
@@ -157,18 +164,19 @@ def get_ratings(
157164
One or more identifiers in ``AGENCY-ID`` form (e.g.
158165
``"USGS-01104475"``). If omitted, the spatial / temporal filters
159166
determine the result set.
160-
file_type : string or list of strings, default ``"exsa"``
161-
Which rating file(s) to request. One or more of ``"exsa"``,
162-
``"base"``, ``"corr"``.
167+
file_type : ``"exsa"``, ``"base"``, ``"corr"``, or a list, default ``"exsa"``
168+
Which rating file(s) to request.
163169
file_path : string, optional
164-
Directory the downloaded RDB files are written to. Defaults to a
165-
per-call temporary directory created via :func:`tempfile.mkdtemp`.
166-
datetime : string or list of strings, optional
167-
STAC ``datetime`` filter — a single date / datetime, or an
168-
interval (``"start/end"``, optionally half-bounded with ``..``).
169-
ISO 8601 *durations* (``"P1M"``, ``"PT36H"``, …) are **not**
170-
supported by the rating-curve service; passing one raises
171-
``ValueError``.
170+
Directory the downloaded RDB files are written to. If ``None``
171+
(the default), the parsed ``DataFrame`` is returned without
172+
persisting the bytes to disk; ``df.attrs["url"]`` still records
173+
where each rating came from.
174+
time : string or list of strings, optional
175+
STAC ``datetime`` filter (passed through verbatim under that name)
176+
— a single date / datetime, or an interval (``"start/end"``,
177+
optionally half-bounded with ``..``). ISO 8601 *durations*
178+
(``"P1M"``, ``"PT36H"``, …) are **not** supported by the
179+
rating-curve service; passing one raises ``ValueError``.
172180
bbox : list of numbers, optional
173181
Only features whose geometry intersects the bounding box are
174182
selected. Format: ``[xmin, ymin, xmax, ymax]`` in CRS 4326
@@ -199,7 +207,7 @@ def get_ratings(
199207
------
200208
ValueError
201209
For an unrecognized ``file_type`` value or an ISO 8601 duration in
202-
``datetime``.
210+
``time``.
203211
204212
Examples
205213
--------
@@ -224,10 +232,10 @@ def get_ratings(
224232
... download_and_parse=False,
225233
... )
226234
227-
>>> # Restrict to features modified since seven days ago (no durations)
235+
>>> # Restrict to features in a date range (durations not supported)
228236
>>> features = dataretrieval.waterdata.get_ratings(
229237
... bbox=[-95.0, 40.0, -92.0, 42.0],
230-
... datetime=["2026-04-29", ".."],
238+
... time=["2026-04-29", ".."],
231239
... download_and_parse=False,
232240
... )
233241
@@ -239,46 +247,46 @@ def get_ratings(
239247
f"Invalid file_type {invalid!r}. Valid options: {list(_VALID_FILE_TYPES)}."
240248
)
241249

242-
if datetime is not None:
250+
if time is not None:
243251
# The rating-curve STAC service rejects ISO 8601 durations; surface a
244252
# clear error rather than letting the server return a confusing 4xx.
245-
dt_values = datetime if isinstance(datetime, list) else [datetime]
246-
if any(v is not None and "P" in str(v).upper() for v in dt_values):
253+
time_values = time if isinstance(time, list) else [time]
254+
if any(v is not None and "P" in str(v).upper() for v in time_values):
247255
raise ValueError(
248256
"ISO 8601 durations (e.g. 'P7D') are not supported in "
249-
"`datetime` for the rating-curve service. Provide a date or "
257+
"`time` for the rating-curve service. Provide a date or "
250258
"interval instead."
251259
)
252-
datetime_str = _format_api_dates(datetime, date=False)
260+
time_str = _format_api_dates(time, date=False)
253261
else:
254-
datetime_str = None
262+
time_str = None
255263

256264
# Mirror R: only pin file_type in the server-side filter when one type
257-
# is requested. With multiple types, fetch all and filter URLs locally.
265+
# is requested. With multiple types, fetch all and filter locally.
258266
server_file_type = file_types[0] if len(file_types) == 1 else None
259267
filter_str = _build_filter(monitoring_location_id, server_file_type)
260268

261-
features = _search(filter_str, datetime_str, bbox, limit, ssl_check)
269+
features = _search(filter_str, time_str, bbox, limit, ssl_check)
262270

263271
if not download_and_parse:
264272
return features
265273

266-
if file_path is None:
267-
file_path = tempfile.mkdtemp(prefix="dataretrieval-ratings-")
268-
os.makedirs(file_path, exist_ok=True)
274+
if file_path is not None:
275+
os.makedirs(file_path, exist_ok=True)
269276

270277
out: dict[str, pd.DataFrame] = {}
278+
requested = set(file_types)
271279
for feature in features:
272-
url = feature.get("assets", {}).get("data", {}).get("href", "")
273-
# Skip features whose file type wasn't requested (only relevant when
274-
# `file_type` is a list — single-type requests are already filtered
275-
# server-side).
276-
if not any(ft in url for ft in file_types):
280+
# Multi-type requests skip the server-side file_type filter, so
281+
# filter here on the per-feature property (more reliable than
282+
# substring-matching the URL).
283+
feat_type = feature.get("properties", {}).get("file_type")
284+
if feat_type not in requested:
277285
continue
278286
fid = feature["id"]
279287
try:
280288
out[fid] = _download_and_parse(feature, file_path, ssl_check)
281-
except Exception as e:
289+
except (requests.RequestException, ValueError, OSError) as e:
282290
logger.warning("Failed to download / parse %s: %s", fid, e)
283291

284292
return out

tests/waterdata_ratings_test.py

Lines changed: 10 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -35,15 +35,23 @@ def test_get_ratings_rejects_invalid_file_type():
3535
get_ratings(monitoring_location_id="USGS-01104475", file_type="bogus")
3636

3737

38-
def test_get_ratings_rejects_iso_8601_duration_in_datetime():
38+
def test_get_ratings_rejects_iso_8601_duration_in_time():
3939
"""STAC ratings doesn't accept ISO 8601 durations; surface a clear error."""
4040
with pytest.raises(ValueError, match="durations.*not supported"):
4141
get_ratings(
4242
monitoring_location_id="USGS-01104475",
43-
datetime="P7D",
43+
time="P7D",
4444
)
4545

4646

47+
def test_build_filter_escapes_quotes():
48+
"""Defends against malformed CQL or injection if an ID contains a quote."""
49+
from dataretrieval.waterdata.ratings import _build_filter
50+
51+
f = _build_filter("USGS-x'-y", None)
52+
assert f == "monitoring_location_id IN ('USGS-x''-y')"
53+
54+
4755
_SAMPLE_RDB = """\
4856
# header line one
4957
# header line two

0 commit comments

Comments
 (0)