|
| 1 | +import uuid |
| 2 | +from time import time, sleep |
| 3 | +import os |
| 4 | +import jwt |
| 5 | +import requests |
| 6 | +import json |
| 7 | +from .secret import Secret |
| 8 | + |
| 9 | + |
| 10 | +class AuthenticationCache(): |
| 11 | + def __init__(self): |
| 12 | + self.tokens = {} |
| 13 | + |
| 14 | + # Number of consecutive authentication tests before considering it worked. |
| 15 | + # This is required because apigee might be using load balancing. |
| 16 | + self.consecutive_tests = 4 |
| 17 | + |
| 18 | + # Number of seconds before trying a test again |
| 19 | + self.time_between_tests = 10 |
| 20 | + |
| 21 | + # Number of attempts before giving up. It can take up to 5 minutes for the |
| 22 | + # addition of a product to an application to take effect. |
| 23 | + # time_between_tests * max_tests = 300 seconds = 5 minutes. |
| 24 | + self.max_tests = 30 |
| 25 | + |
| 26 | + # How long the token will stay valid |
| 27 | + self.token_validity = 180 |
| 28 | + |
| 29 | + def generate_authentication(self, env, base_url): |
| 30 | + |
| 31 | + # For the test_url, note that we don't need a message_id that actually exists in |
| 32 | + # the backend. The test will only check that the API doesn't return a 401, |
| 33 | + # a 404 response means the authentication is working. |
| 34 | + test_url = f"{base_url}/v1/messages/message_id" |
| 35 | + |
| 36 | + if env in ["internal-dev", "ref"]: |
| 37 | + api_key = os.environ["NON_PROD_API_KEY"] |
| 38 | + private_key = os.environ["NON_PROD_PRIVATE_KEY"] |
| 39 | + url = "https://internal-dev.api.service.nhs.uk/oauth2/token" |
| 40 | + kid = "local" |
| 41 | + elif env == "int": |
| 42 | + api_key = os.environ.get("INTEGRATION_API_KEY") |
| 43 | + private_key = os.environ.get("INTEGRATION_PRIVATE_KEY") |
| 44 | + url = "https://int.api.service.nhs.uk/oauth2/token" |
| 45 | + kid = "local" |
| 46 | + elif env == "prod": |
| 47 | + api_key = os.environ.get("PRODUCTION_API_KEY") |
| 48 | + private_key = os.environ.get("PRODUCTION_PRIVATE_KEY") |
| 49 | + url = "https://api.service.nhs.uk/oauth2/token" |
| 50 | + kid = "prod-1" |
| 51 | + else: |
| 52 | + raise ValueError("Unknown value: ", env) |
| 53 | + |
| 54 | + _, latest_token_expiry = self.tokens.get(env, (None, 0)) |
| 55 | + |
| 56 | + # Generate new token if latest token will expire in 15 seconds |
| 57 | + if env not in self.tokens or latest_token_expiry < int(time()) + 15: |
| 58 | + self.tokens[env] = self.generate_and_test_new_token(api_key, private_key, url, kid, test_url) |
| 59 | + |
| 60 | + bearer_token = self.tokens[env][0] |
| 61 | + return Secret(bearer_token) |
| 62 | + |
| 63 | + def generate_and_test_new_token(self, api_key, private_key, url, kid, test_url): |
| 64 | + new_token = None |
| 65 | + valid_auth = False |
| 66 | + |
| 67 | + for i in range(self.max_tests): |
| 68 | + print(f"Testing new token, attemp #{i+1}") |
| 69 | + if new_token is None: |
| 70 | + new_token = self.generate_new_token(api_key, private_key, url, kid) |
| 71 | + time_since_new_token = int(time()) |
| 72 | + |
| 73 | + if self.test_token(test_url, new_token[0]): |
| 74 | + valid_auth = True |
| 75 | + break |
| 76 | + |
| 77 | + # The test failed, give apigee some time to update its cache. |
| 78 | + sleep(self.time_between_tests) |
| 79 | + |
| 80 | + if int(time()) - time_since_new_token > (self.token_validity / 2): |
| 81 | + # Token about to expire, generate a new one |
| 82 | + new_token = None |
| 83 | + |
| 84 | + if valid_auth: |
| 85 | + print("Token generated successfully") |
| 86 | + return new_token |
| 87 | + |
| 88 | + print("Could not generate token") |
| 89 | + raise RuntimeError("Could not generate token") |
| 90 | + |
| 91 | + def generate_new_token(self, api_key, private_key, url, kid): |
| 92 | + pk_pem = None |
| 93 | + with open(private_key, "r") as f: |
| 94 | + pk_pem = f.read() |
| 95 | + |
| 96 | + token_expiry = int(time()) + self.token_validity |
| 97 | + |
| 98 | + claims = { |
| 99 | + "sub": api_key, |
| 100 | + "iss": api_key, |
| 101 | + "jti": str(uuid.uuid4()), |
| 102 | + "aud": url, |
| 103 | + "exp": token_expiry, |
| 104 | + } |
| 105 | + additional_headers = {"kid": kid} |
| 106 | + |
| 107 | + j = jwt.encode( |
| 108 | + claims, pk_pem, algorithm="RS512", headers=additional_headers |
| 109 | + ) |
| 110 | + |
| 111 | + resp = requests.post(url, headers={ |
| 112 | + "Content-Type": "application/x-www-form-urlencoded" |
| 113 | + }, data={ |
| 114 | + "grant_type": "client_credentials", |
| 115 | + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", |
| 116 | + "client_assertion": j |
| 117 | + } |
| 118 | + ) |
| 119 | + details = json.loads(resp.content) |
| 120 | + |
| 121 | + return (f"Bearer {details.get('access_token')}", token_expiry) |
| 122 | + |
| 123 | + def test_token(self, test_url, token): |
| 124 | + for _ in range(self.consecutive_tests): |
| 125 | + resp = requests.get( |
| 126 | + test_url, |
| 127 | + headers={ |
| 128 | + "Authorization": token, |
| 129 | + "Accept": "*/*", |
| 130 | + "Content-Type": "application/json" |
| 131 | + }, |
| 132 | + ) |
| 133 | + if resp.status_code == 401: |
| 134 | + return False |
| 135 | + |
| 136 | + return True |
0 commit comments