import asyncio
import random
import ssl
import sys
import threading
import traceback
from contextlib import suppress
from urllib.parse import urlencode

import backoff
import certifi
from opencage.geocoder import OpenCageGeocode, OpenCageGeocodeError
from tqdm import tqdm


class OpenCageBatchGeocoder():
    """Geocode many input rows concurrently and write results in input order.

    ``options`` is an object whose attributes configure a run. The attributes
    read by this class are ``api_key``, ``api_domain``, ``command``,
    ``input_columns``, ``add_columns``, ``headers``, ``limit``, ``workers``,
    ``timeout``, ``retries``, ``extra_params``, ``dry_run``, ``no_progress``,
    ``quiet`` and ``verbose``.
    """

    def __init__(self, options):
        """Store ``options`` and build the SSL context used for API requests."""
        self.options = options
        self.sslcontext = ssl.create_default_context(cafile=certifi.where())
        # Row id (1-based) of the next row that is allowed to be written.
        self.write_counter = 1

    def __call__(self, *args, **kwargs):
        """Run :meth:`geocode` to completion in a fresh event loop."""
        asyncio.run(self.geocode(*args, **kwargs))

    async def geocode(self, input, output):
        """Read rows from ``input``, geocode them and write them to ``output``.

        Args:
            input: Iterator yielding rows (sequences of strings); ``next()`` is
                called on it when ``options.headers`` is set.
                Presumably a ``csv.reader`` -- confirm against the caller.
            output: Object exposing ``writerow(row)``.

        With ``options.dry_run`` the input is only read and validated; no API
        request is made and nothing is written.
        """
        # Row ids restart at 1 on every run, so the ordering counter must too;
        # otherwise a second run on the same instance would write out of order.
        self.write_counter = 1

        if not self.options.dry_run:
            test = await self.test_request()
            if test['error']:
                self.log(test['error'])
                return

        header_columns = None
        if self.options.headers:
            header_columns = next(input, None)
            if header_columns is None:
                return

        # asyncio treats a maxsize <= 0 as "unbounded"; also tolerate None.
        queue = asyncio.Queue(maxsize=self.options.limit or 0)

        await self.read_input(input, queue)

        if self.options.dry_run:
            return

        if self.options.headers:
            output.writerow(header_columns + self.options.add_columns)

        # None (rather than a falsy placeholder) marks "no progress bar", so
        # closing the bar never depends on the bar object's own truthiness.
        progress_bar = None
        if not (self.options.no_progress or self.options.quiet):
            progress_bar = tqdm(total=queue.qsize(), position=0, desc="Addresses geocoded", dynamic_ncols=True)

        tasks = [
            asyncio.create_task(self.worker(output, queue, progress_bar))
            for _ in range(self.options.workers)
        ]

        try:
            # Blocks until every queued item has been marked done by a worker.
            await queue.join()
        finally:
            # Workers loop forever on queue.get(); cancel them and wait for the
            # cancellation to finish so no task is left pending at loop shutdown.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

            if progress_bar is not None:
                progress_bar.close()

    async def test_request(self):
        """Send one probe request to check that the API is reachable.

        Returns:
            ``{'error': None, 'free': bool}`` on success, where ``free`` is
            True when the response reports a rate limit of 2500, or
            ``{'error': exc}`` when the request raised.
        """
        try:
            async with OpenCageGeocode(self.options.api_key, domain=self.options.api_domain, sslcontext=self.sslcontext) as geocoder:
                result = await geocoder.geocode_async('Kendall Sq, Cambridge, MA', raw_response=True)

                free = False
                # The response may carry no rate information at all.
                with suppress(KeyError):
                    free = result['rate']['limit'] == 2500

                return {'error': None, 'free': free}
        except Exception as exc:
            return {'error': exc}

    async def read_input(self, input, queue):
        """Fill ``queue`` with one work item per row of ``input``.

        Reading stops as soon as the queue reports full, which caps the number
        of rows at the queue's ``maxsize``.

        Raises:
            Exception: If an empty row is encountered.
        """
        for line_number, row in enumerate(input, start=1):
            if not row:
                raise Exception(f"Empty line in input file at line number {line_number}, aborting")

            item = await self.read_one_line(row, line_number)
            await queue.put(item)

            if queue.full():
                break

    async def read_one_line(self, row, row_id):
        """Turn one input row into a work item.

        The address is built from the first two columns for the ``reverse``
        command, from ``options.input_columns`` (1-based) when set, and from
        the whole row otherwise. Problems are logged, not raised.

        Returns:
            Dict with ``row_id``, ``address`` (comma-joined string) and
            ``original_columns`` (the row as passed in).
        """
        if self.options.command == 'reverse':
            input_columns = [1, 2]
        elif self.options.input_columns:
            input_columns = self.options.input_columns
        else:
            input_columns = None

        if input_columns:
            address = []
            try:
                for column in input_columns:
                    # input_columns option uses 1-based indexing
                    address.append(row[column - 1])
            except IndexError:
                # Keep whatever columns were collected so far.
                self.log(f"Missing input column {column} in {row}")
        else:
            address = row

        if self.options.command == 'reverse' and len(address) != 2:
            self.log(f"Expected two comma-separated values for reverse geocoding, got {address}")

        return {'row_id': row_id, 'address': ','.join(address), 'original_columns': row}

    async def worker(self, output, queue, progress):
        """Consume work items from ``queue`` forever, geocoding each one.

        Runs until cancelled. Exceptions from a single item are printed to
        stderr and do not stop the worker.
        """
        while True:
            item = await queue.get()

            try:
                await self.geocode_one_address(output, item['row_id'], item['address'], item['original_columns'])

                if progress:
                    progress.update(1)
            except Exception as exc:
                traceback.print_exception(exc, file=sys.stderr)
            finally:
                # Always acknowledge the item so queue.join() can return.
                queue.task_done()

    async def geocode_one_address(self, output, row_id, address, original_columns):
        """Geocode one address and write its output row.

        Timeouts are retried with exponential backoff, bounded by
        ``options.timeout`` and ``options.retries``. Any other request error is
        reported and the row is written with empty result columns. A row is
        always written, because later rows wait for it.
        """
        def on_backoff(details):
            if not self.options.quiet:
                sys.stderr.write("Backing off {wait:0.1f} seconds afters {tries} tries "
                                 "calling function {target} with args {args} and kwargs "
                                 "{kwargs}\n".format(**details))

        @backoff.on_exception(backoff.expo,
                              asyncio.TimeoutError,
                              max_time=self.options.timeout,
                              max_tries=self.options.retries,
                              on_backoff=on_backoff)
        async def _geocode_one_address():
            async with OpenCageGeocode(self.options.api_key, domain=self.options.api_domain, sslcontext=self.sslcontext) as geocoder:
                geocoding_results = None
                params = {'no_annotations': 1, **(self.options.extra_params or {})}

                try:
                    if self.options.command == 'reverse':
                        # Values are forwarded positionally, in input order.
                        lat, lng = address.split(',')
                        geocoding_results = await geocoder.reverse_geocode_async(lat, lng, **params)
                    else:
                        geocoding_results = await geocoder.geocode_async(address, **params)
                except asyncio.TimeoutError:
                    # Must escape so the backoff decorator can retry; the
                    # generic handler below would otherwise swallow it.
                    raise
                except OpenCageGeocodeError as exc:
                    self.log(str(exc))
                except Exception as exc:
                    traceback.print_exception(exc, file=sys.stderr)

                geocoding_result = None
                try:
                    if geocoding_results is not None and len(geocoding_results):
                        geocoding_result = geocoding_results[0]

                    if self.options.verbose:
                        self.log({
                            'row_id': row_id,
                            'thread_id': threading.get_native_id(),
                            'request': geocoder.url + '?' + urlencode(geocoder._parse_request(address, params)),
                            'response': geocoding_result
                        })
                except Exception as exc:
                    # A logging problem must not prevent the row being written.
                    traceback.print_exception(exc, file=sys.stderr)

                try:
                    await self.write_one_geocoding_result(output, row_id, address, geocoding_result, original_columns)
                except Exception as exc:
                    traceback.print_exception(exc, file=sys.stderr)

        try:
            await _geocode_one_address()
        except asyncio.TimeoutError as exc:
            # Retries are exhausted. Still emit the row (with empty result
            # columns): rows behind it wait on write_counter and would
            # otherwise never be written.
            self.log(f"Giving up on row {row_id} after repeated timeouts: {exc!r}")
            try:
                await self.write_one_geocoding_result(output, row_id, address, None, original_columns)
            except Exception as write_exc:
                traceback.print_exception(write_exc, file=sys.stderr)

    @staticmethod
    def _lookup_column(geocoding_result, column):
        """Return the value for ``column`` from a result, or ``''`` if absent.

        Looks at the top level of the result first, then inside its
        ``components`` and ``geometry`` sections (either may be missing).
        """
        if geocoding_result is None:
            return ''
        if column in geocoding_result:
            return geocoding_result[column]
        for section in ('components', 'geometry'):
            values = geocoding_result.get(section) or {}
            if column in values:
                return values[column]
        return ''

    async def write_one_geocoding_result(self, output, row_id, address, geocoding_result, original_columns=None):
        """Write one output row once all rows with a lower id are written.

        The row is ``original_columns`` followed by one value per entry in
        ``options.add_columns``. ``original_columns`` is copied, never
        modified.

        Args:
            output: Object exposing ``writerow(row)``.
            row_id: 1-based position of the row; rows are written in this order.
            address: Unused here; kept for interface compatibility.
            geocoding_result: Result mapping, or None for "no result".
            original_columns: Input columns to echo; defaults to none.
        """
        row = list(original_columns) if original_columns is not None else []
        row.extend(self._lookup_column(geocoding_result, column) for column in self.options.add_columns)

        # Enforce that row are written ordered. That means we might wait for other threads
        # to finish a task and make the overall process slower. Alternative would be to
        # use a second queue, or keep some results in memory.
        while row_id > self.write_counter:
            if self.options.verbose:
                self.log(f"Want to write row {row_id}, but write_counter is at {self.write_counter}")
            await asyncio.sleep(random.uniform(0.01, 0.1))

        if self.options.verbose:
            self.log(f"Writing row {row_id}")
        try:
            output.writerow(row)
        finally:
            # Advance even if the write failed, so the rows waiting behind
            # this one are not blocked forever.
            self.write_counter += 1

    def log(self, message):
        """Write ``message`` to stderr unless ``options.quiet`` is set."""
        if not self.options.quiet:
            sys.stderr.write(f"{message}\n")
# 0 commit comments