Skip to content

Commit 6e63e27

Browse files
committed
fix: check severity label on reruns, improve error logging, add tests
Signed-off-by: lelia <2418071+lelia@users.noreply.github.com>
1 parent db6d25a commit 6e63e27

2 files changed

Lines changed: 229 additions & 35 deletions

File tree

socket_basics/core/notification/github_pr_notifier.py

Lines changed: 140 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Optional
22
import logging
3+
from urllib.parse import quote
34

45
from socket_basics.core.notification.base import BaseNotifier
56
from socket_basics.core.config import get_github_token, get_github_repository, get_github_pr_number
@@ -118,8 +119,39 @@ def notify(self, facts: Dict[str, Any]) -> None:
118119
# Add labels to PR if enabled
119120
if self.config.get('pr_labels_enabled', True) and pr_number:
120121
labels = self._determine_pr_labels(valid_notifications)
121-
if labels:
122-
self._add_pr_labels(pr_number, labels)
122+
self._reconcile_pr_labels(pr_number, labels)
123+
def _managed_pr_label_config(self) -> Dict[str, str]:
124+
"""Return the managed severity label names configured for PRs."""
125+
return {
126+
'critical': self.config.get('pr_label_critical', 'security: critical'),
127+
'high': self.config.get('pr_label_high', 'security: high'),
128+
'medium': self.config.get('pr_label_medium', 'security: medium'),
129+
'low': self.config.get('pr_label_low', 'security: low'),
130+
}
131+
132+
def _get_label_color_info(self, label: str) -> Optional[tuple[str, str]]:
133+
"""Infer color/description for managed or custom severity labels."""
134+
label_colors = {
135+
self.config.get('pr_label_critical', 'security: critical'): ('D73A4A', 'Critical security vulnerabilities'),
136+
self.config.get('pr_label_high', 'security: high'): ('D93F0B', 'High severity security issues'),
137+
self.config.get('pr_label_medium', 'security: medium'): ('FBCA04', 'Medium severity security issues'),
138+
self.config.get('pr_label_low', 'security: low'): ('E4E4E4', 'Low severity security issues'),
139+
}
140+
color_info = label_colors.get(label)
141+
if color_info:
142+
return color_info
143+
144+
label_lower = label.lower()
145+
if 'critical' in label_lower:
146+
return ('D73A4A', 'Critical security vulnerabilities')
147+
if 'high' in label_lower:
148+
return ('D93F0B', 'High severity security issues')
149+
if 'medium' in label_lower:
150+
return ('FBCA04', 'Medium severity security issues')
151+
if 'low' in label_lower:
152+
return ('E4E4E4', 'Low severity security issues')
153+
return None
154+
123155

124156
def _send_pr_comment(self, facts: Dict[str, Any], title: str, content: str) -> None:
125157
"""Send a single PR comment with title and content."""
@@ -423,19 +455,93 @@ def _ensure_label_exists_with_color(self, label_name: str, color: str, descripti
423455
logger.info('GithubPRNotifier: created label "%s" with color #%s', label_name, color)
424456
return True
425457
else:
426-
logger.warning('GithubPRNotifier: failed to create label "%s": %s',
427-
label_name, create_resp.status_code)
458+
logger.warning(
459+
'GithubPRNotifier: failed to create label "%s": %s %s',
460+
label_name,
461+
create_resp.status_code,
462+
create_resp.text[:200],
463+
)
428464
return False
429465
else:
430-
logger.warning('GithubPRNotifier: unexpected response checking label: %s', resp.status_code)
466+
logger.warning(
467+
'GithubPRNotifier: unexpected response checking label "%s": %s %s',
468+
label_name,
469+
resp.status_code,
470+
resp.text[:200],
471+
)
431472
return False
432473

433474
except Exception as e:
434475
logger.debug('GithubPRNotifier: exception ensuring label exists: %s', e)
435476
return False
436477

478+
def _ensure_pr_labels_exist(self, labels: List[str]) -> None:
479+
"""Ensure desired labels exist in the repository with appropriate colors."""
480+
for label in labels:
481+
color_info = self._get_label_color_info(label)
482+
if color_info:
483+
color, description = color_info
484+
self._ensure_label_exists_with_color(label, color, description)
485+
486+
def _get_current_pr_label_names(self, pr_number: int) -> List[str]:
487+
"""Fetch current label names for the PR."""
488+
if not self.repository:
489+
return []
490+
491+
try:
492+
import requests
493+
headers = {
494+
'Authorization': f'token {self.token}',
495+
'Accept': 'application/vnd.github.v3+json'
496+
}
497+
url = f"{self.api_base}/repos/{self.repository}/issues/{pr_number}/labels"
498+
resp = requests.get(url, headers=headers, timeout=10)
499+
if resp.status_code == 200:
500+
payload = resp.json()
501+
return [label.get('name') for label in payload if isinstance(label, dict) and label.get('name')]
502+
logger.warning(
503+
'GithubPRNotifier: failed to fetch current labels for PR %s: %s %s',
504+
pr_number,
505+
resp.status_code,
506+
resp.text[:200],
507+
)
508+
except Exception as e:
509+
logger.error('GithubPRNotifier: exception fetching current labels: %s', e)
510+
return []
511+
512+
def _remove_pr_label(self, pr_number: int, label: str) -> bool:
513+
"""Remove a label from a PR."""
514+
if not self.repository or not label:
515+
return False
516+
517+
try:
518+
import requests
519+
headers = {
520+
'Authorization': f'token {self.token}',
521+
'Accept': 'application/vnd.github.v3+json'
522+
}
523+
encoded_label = quote(label, safe='')
524+
url = f"{self.api_base}/repos/{self.repository}/issues/{pr_number}/labels/{encoded_label}"
525+
resp = requests.delete(url, headers=headers, timeout=10)
526+
if resp.status_code == 200:
527+
logger.info('GithubPRNotifier: removed label from PR %s: %s', pr_number, label)
528+
return True
529+
if resp.status_code == 404:
530+
logger.debug('GithubPRNotifier: label %s already absent from PR %s', label, pr_number)
531+
return True
532+
logger.warning(
533+
'GithubPRNotifier: failed to remove label "%s" from PR %s: %s %s',
534+
label,
535+
pr_number,
536+
resp.status_code,
537+
resp.text[:200],
538+
)
539+
except Exception as e:
540+
logger.error('GithubPRNotifier: exception removing label %s: %s', label, e)
541+
return False
542+
437543
def _add_pr_labels(self, pr_number: int, labels: List[str]) -> bool:
438-
"""Add labels to a PR, ensuring they exist with appropriate colors.
544+
"""Add missing labels to a PR.
439545
440546
Args:
441547
pr_number: PR number
@@ -447,34 +553,6 @@ def _add_pr_labels(self, pr_number: int, labels: List[str]) -> bool:
447553
if not self.repository or not labels:
448554
return False
449555

450-
# Color mapping for severity labels (matching emoji colors)
451-
label_colors = {
452-
'security: critical': ('D73A4A', 'Critical security vulnerabilities'),
453-
'security: high': ('D93F0B', 'High severity security issues'),
454-
'security: medium': ('FBCA04', 'Medium severity security issues'),
455-
'security: low': ('E4E4E4', 'Low severity security issues'),
456-
}
457-
458-
# Ensure labels exist with correct colors
459-
for label in labels:
460-
# Get color and description if this is a known severity label
461-
color_info = label_colors.get(label)
462-
if color_info:
463-
color, description = color_info
464-
self._ensure_label_exists_with_color(label, color, description)
465-
# For custom label names, use a default color
466-
elif ':' in label:
467-
# Try to infer severity from label name
468-
label_lower = label.lower()
469-
if 'critical' in label_lower:
470-
self._ensure_label_exists_with_color(label, 'D73A4A', 'Critical security vulnerabilities')
471-
elif 'high' in label_lower:
472-
self._ensure_label_exists_with_color(label, 'D93F0B', 'High severity security issues')
473-
elif 'medium' in label_lower:
474-
self._ensure_label_exists_with_color(label, 'FBCA04', 'Medium severity security issues')
475-
elif 'low' in label_lower:
476-
self._ensure_label_exists_with_color(label, 'E4E4E4', 'Low severity security issues')
477-
478556
try:
479557
import requests
480558
headers = {
@@ -490,12 +568,33 @@ def _add_pr_labels(self, pr_number: int, labels: List[str]) -> bool:
490568
logger.info('GithubPRNotifier: added labels to PR %s: %s', pr_number, ', '.join(labels))
491569
return True
492570
else:
493-
logger.warning('GithubPRNotifier: failed to add labels: %s', resp.status_code)
571+
logger.warning('GithubPRNotifier: failed to add labels: %s %s', resp.status_code, resp.text[:200])
494572
return False
495573
except Exception as e:
496574
logger.error('GithubPRNotifier: exception adding labels: %s', e)
497575
return False
498576

577+
def _reconcile_pr_labels(self, pr_number: int, desired_labels: List[str]) -> bool:
578+
"""Reconcile managed severity labels on the PR to match the latest run."""
579+
managed_labels = set(filter(None, self._managed_pr_label_config().values()))
580+
current_labels = set(self._get_current_pr_label_names(pr_number))
581+
desired_label_set = set(filter(None, desired_labels))
582+
583+
stale_labels = sorted(label for label in current_labels if label in managed_labels and label not in desired_label_set)
584+
labels_to_add = sorted(label for label in desired_label_set if label not in current_labels)
585+
586+
success = True
587+
for label in stale_labels:
588+
success = self._remove_pr_label(pr_number, label) and success
589+
590+
if labels_to_add:
591+
self._ensure_pr_labels_exist(labels_to_add)
592+
success = self._add_pr_labels(pr_number, labels_to_add) and success
593+
594+
if not stale_labels and not labels_to_add:
595+
logger.info('GithubPRNotifier: PR %s severity labels already up to date', pr_number)
596+
return success
597+
499598
def _determine_pr_labels(self, notifications: List[Dict[str, Any]]) -> List[str]:
500599
"""Determine which labels to add based on notifications.
501600
@@ -517,13 +616,16 @@ def _determine_pr_labels(self, notifications: List[Dict[str, Any]]) -> List[str]
517616
critical_match = re.search(r'Critical:\s*(\d+)', content)
518617
high_match = re.search(r'High:\s*(\d+)', content)
519618
medium_match = re.search(r'Medium:\s*(\d+)', content)
619+
low_match = re.search(r'Low:\s*(\d+)', content)
520620

521621
if critical_match and int(critical_match.group(1)) > 0:
522622
severities_found.add('critical')
523623
if high_match and int(high_match.group(1)) > 0:
524624
severities_found.add('high')
525625
if medium_match and int(medium_match.group(1)) > 0:
526626
severities_found.add('medium')
627+
if low_match and int(low_match.group(1)) > 0:
628+
severities_found.add('low')
527629

528630
# Map severities to label names (using configurable labels)
529631
labels = []
@@ -536,5 +638,8 @@ def _determine_pr_labels(self, notifications: List[Dict[str, Any]]) -> List[str]
536638
elif 'medium' in severities_found:
537639
label_name = self.config.get('pr_label_medium', 'security: medium')
538640
labels.append(label_name)
641+
elif 'low' in severities_found:
642+
label_name = self.config.get('pr_label_low', 'security: low')
643+
labels.append(label_name)
539644

540645
return labels

tests/test_github_pr_notifier.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from socket_basics.core.notification.github_pr_notifier import GithubPRNotifier
2+
3+
4+
def _notification(summary: str) -> dict:
5+
return {'title': 'Socket SAST JavaScript', 'content': summary}
6+
7+
8+
def test_determine_pr_labels_prefers_highest_current_severity():
9+
notifier = GithubPRNotifier(
10+
{
11+
'repository': 'SocketDev/socket-basics',
12+
'pr_label_critical': 'security: critical',
13+
'pr_label_high': 'security: high',
14+
'pr_label_medium': 'security: medium',
15+
'pr_label_low': 'security: low',
16+
}
17+
)
18+
19+
labels = notifier._determine_pr_labels(
20+
[_notification('Critical: 0 | High: 1 | Medium: 2 | Low: 3')]
21+
)
22+
23+
assert labels == ['security: high']
24+
25+
26+
def test_determine_pr_labels_supports_low_severity():
27+
notifier = GithubPRNotifier(
28+
{
29+
'repository': 'SocketDev/socket-basics',
30+
'pr_label_low': 'security: low',
31+
}
32+
)
33+
34+
labels = notifier._determine_pr_labels(
35+
[_notification('Critical: 0 | High: 0 | Medium: 0 | Low: 2')]
36+
)
37+
38+
assert labels == ['security: low']
39+
40+
41+
def test_reconcile_pr_labels_replaces_stale_managed_severity(monkeypatch):
42+
notifier = GithubPRNotifier(
43+
{
44+
'repository': 'SocketDev/socket-basics',
45+
'pr_label_critical': 'security: critical',
46+
'pr_label_high': 'security: high',
47+
'pr_label_medium': 'security: medium',
48+
'pr_label_low': 'security: low',
49+
}
50+
)
51+
52+
removed: list[str] = []
53+
added: list[str] = []
54+
ensured: list[str] = []
55+
56+
monkeypatch.setattr(notifier, '_get_current_pr_label_names', lambda pr_number: ['security: critical', 'team: backend'])
57+
monkeypatch.setattr(notifier, '_remove_pr_label', lambda pr_number, label: removed.append(label) or True)
58+
monkeypatch.setattr(notifier, '_ensure_pr_labels_exist', lambda labels: ensured.extend(labels))
59+
monkeypatch.setattr(notifier, '_add_pr_labels', lambda pr_number, labels: added.extend(labels) or True)
60+
61+
success = notifier._reconcile_pr_labels(123, ['security: medium'])
62+
63+
assert success is True
64+
assert removed == ['security: critical']
65+
assert ensured == ['security: medium']
66+
assert added == ['security: medium']
67+
68+
69+
def test_reconcile_pr_labels_clears_managed_labels_when_none_desired(monkeypatch):
70+
notifier = GithubPRNotifier(
71+
{
72+
'repository': 'SocketDev/socket-basics',
73+
'pr_label_critical': 'security: critical',
74+
'pr_label_high': 'security: high',
75+
'pr_label_medium': 'security: medium',
76+
'pr_label_low': 'security: low',
77+
}
78+
)
79+
80+
removed: list[str] = []
81+
monkeypatch.setattr(notifier, '_get_current_pr_label_names', lambda pr_number: ['security: high', 'docs'])
82+
monkeypatch.setattr(notifier, '_remove_pr_label', lambda pr_number, label: removed.append(label) or True)
83+
monkeypatch.setattr(notifier, '_ensure_pr_labels_exist', lambda labels: (_ for _ in ()).throw(AssertionError('should not ensure labels')))
84+
monkeypatch.setattr(notifier, '_add_pr_labels', lambda pr_number, labels: (_ for _ in ()).throw(AssertionError('should not add labels')))
85+
86+
success = notifier._reconcile_pr_labels(123, [])
87+
88+
assert success is True
89+
assert removed == ['security: high']

0 commit comments

Comments
 (0)