Source code for parsons.aws.lambda_distribute

import csv
from io import TextIOWrapper, BytesIO, StringIO
import logging
import sys
import traceback
import time

from parsons.aws.aws_async import get_func_task_path, import_and_get_task, run as maybe_async_run
from parsons.aws.s3 import S3
from parsons.etl.table import Table
from parsons.utilities.check_env import check

logger = logging.getLogger(__name__)


class DistributeTaskException(Exception):
    pass


class TestStorage:
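    """
    In-memory stand-in for S3Storage, used when storage='local' so the
    distribution flow can be exercised in tests without AWS access.
    """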

    def __init__(self):
        self.data = {}

    def put_object(self, bucket, key, object_bytes):
        self.data[key] = object_bytes

    def get_range(self, bucket, key, rangestart, rangeend):
        return self.data[key][rangestart:rangeend]


class S3Storage:
    """
    These methods are pretty specialized, so we keep them
    inside this file rather than s3.py
    """

    def __init__(self):
        self.s3 = S3()

    def put_object(self, bucket, key, object_bytes, **kwargs):
        return self.s3.client.put_object(Bucket=bucket, Key=key, Body=object_bytes, **kwargs)

    def get_range(self, bucket, key, rangestart, rangeend):
        """
        Gets an explicit byte-range of an S3 file
        """
        # The HTTP Range header is INCLUSIVE of the end byte, unlike Python
        # slicing: data[2:4] returns 2 bytes, but "Range: bytes=2-4" returns
        # 3. So we subtract 1 from rangeend.
        response = self.s3.client.get_object(
            Bucket=bucket, Key=key,
            Range='bytes={}-{}'.format(rangestart, rangeend - 1))
        return response['Body'].read()
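
    # Illustrative example (hypothetical bucket/key names):
    #   S3Storage().get_range('my-bucket', 'some/key.csv', 0, 10)
    # sends "Range: bytes=0-9" and returns the object's first 10 bytes.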


FAKE_STORAGE = TestStorage()
S3_TEMP_KEY_PREFIX = "Parsons_DistributeTask"


def distribute_task_csv(csv_bytes_utf8, func_to_run, bucket,
                        header=None,
                        func_kwargs=None,
                        func_class=None,
                        func_class_kwargs=None,
                        catch=False,
                        group_count=100,
                        storage='s3'):
    """
    The same as distribute_task, but instead of a table, the
    first argument is bytes of a csv encoded into utf8.
    This function is used by distribute_task() which you should use instead.
    """
    global FAKE_STORAGE
    func_name = get_func_task_path(func_to_run, func_class)
    row_chunks = csv_bytes_utf8.split(b'\n')
    cursor = 0
    row_ranges = []
    # gather start/end bytes for each row
    for rowc in row_chunks:
        rng = [cursor]
        cursor = cursor + len(rowc) + 1  # +1 is the \n character
        rng.append(cursor)
        row_ranges.append(rng)

    # group the rows and get start/end bytes for each group
    group_ranges = []
    # the table's csv writer appends a terminal \r\n, which leaves a trailing
    # empty chunk after the split above, so we iterate to len - 1
    for grpstep in range(0, len(row_ranges) - 1, group_count):
        end = min(len(row_ranges) - 1, grpstep + group_count - 1)
        group_ranges.append((row_ranges[grpstep][0], row_ranges[end][1]))
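
    # Worked example (illustrative): with
    #   csv_bytes_utf8 == b"1,2\r\n3,4\r\n5,6\r\n7,8\r\n"   (four 5-byte rows)
    # and group_count=2, the loops above produce
    #   row_ranges   == [[0, 5], [5, 10], [10, 15], [15, 20], [20, 21]]
    #   group_ranges == [(0, 10), (10, 20)]
    # i.e. each group is a contiguous byte range covering up to group_count rows.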

    # upload data
    filename = hash(time.time())
    storagekey = f"{S3_TEMP_KEY_PREFIX}/{filename}.csv"
    groupcount = len(group_ranges)
    logger.debug(f'distribute_task_csv storagekey {storagekey} w/ {groupcount} groups')

    response = None
    if storage == 's3':
        response = S3Storage().put_object(bucket, storagekey, csv_bytes_utf8)
    else:
        response = FAKE_STORAGE.put_object(bucket, storagekey, csv_bytes_utf8)

    # start processes
    results = [
        maybe_async_run(
            process_task_portion,
            [bucket, storagekey, grp[0], grp[1], func_name, header,
             storage, func_kwargs, catch, func_class_kwargs],
            # if we are using local storage, then it must be run locally, as well
            # (good for testing/debugging)
            remote_aws_lambda_function_name='FORCE_LOCAL' if storage == 'local' else None
        )
        for grp in group_ranges]
    return {'DEBUG_ONLY': 'results may vary depending on context/platform',
            'results': results,
            'put_response': response}
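
# Illustrative direct call (normally you would use distribute_task() below);
# my_row_task and the bucket name are placeholders:
#   distribute_task_csv(b"1,2\r\n3,4\r\n", my_row_task, 'my-temp-bucket',
#                       header=['a', 'b'], group_count=1, storage='local')
# With storage='local' the bytes round-trip through FAKE_STORAGE and each
# group is processed locally rather than in a Lambda sub-invocation.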


def distribute_task(table, func_to_run,
                    bucket=None,
                    func_kwargs=None,
                    func_class=None,
                    func_class_kwargs=None,
                    catch=False,
                    group_count=100,
                    storage='s3'):
    """
    Distribute processing rows in a table across multiple AWS Lambda invocations.

    If you are running the processing of a table inside AWS Lambda, then you
    are limited by how many rows can be processed within the Lambda's time
    limit (at time-of-writing, a maximum of 15 minutes).

    Based on experience and some napkin math, with the same data that would
    allow 1000 rows to be processed inside a single AWS Lambda instance,
    this method allows 10 MILLION rows to be processed.

    Rather than converting the table to SQS or other options, the fastest way
    is to upload the table to S3, and then invoke multiple Lambda
    sub-invocations, each of which can be sent a byte-range of the data in
    the S3 CSV file to process.

    Using this method requires some setup. You have three tasks:

    1. Define the function to process rows; its first argument must take
       your table's data (though only a subset of rows will be passed),
       e.g. `def task_for_distribution(table, **kwargs):`
    2. Where you would have run `task_for_distribution(my_table, **kwargs)`,
       instead call `distribute_task(my_table, task_for_distribution, func_kwargs=kwargs)`
       (either setting env var S3_TEMP_BUCKET or passing a bucket= parameter).
    3. Set up your Lambda handler to include :py:meth:`parsons.aws.event_command`
       (or run and deploy your lambda with `Zappa <https://github.com/Miserlou/Zappa>`_).

    To test locally, include the argument `storage="local"`, which will test
    the distribute_task function, but run the task sequentially and in local memory.

    A minimalistic example Lambda handler might look something like this:

    .. code-block:: python
       :emphasize-lines: 9,10

       from parsons.aws import event_command, distribute_task

       def process_table(table, foo, bar=None):
           for row in table:
               do_sloooooow_thing(row, foo, bar)

       def handler(event, context):
           ## ADD THESE TWO LINES TO TOP OF HANDLER:
           if event_command(event, context):
               return
           table = FakeDatasource.load_to_table(username='123', password='abc')
           # table is so big that running
           # process_table(table, foo=789, bar='baz') would timeout
           # so instead we:
           distribute_task(table, process_table,
                           bucket='my-temp-s3-bucket',
                           func_kwargs={'foo': 789, 'bar': 'baz'})

    `Args:`
        table: Parsons Table
            Table of data you wish to distribute processing across Lambda
            invocations of the `func_to_run` argument.
        func_to_run: function
            The function you want to run, whose first argument will be a
            subset of table.
        bucket: str
            The bucket name to use for the s3 upload to process the whole table.
            Not required if you set environment variable ``S3_TEMP_BUCKET``.
        func_kwargs: dict
            If the function has other arguments to pass along with `table`,
            then provide them as a dict here. They must all be JSON-able.
        func_class: class
            If the function is a classmethod or function on a class, then
            pass the pure class here. E.g. if you passed
            `ActionKit.bulk_upload_table`, then you would pass `ActionKit` here.
        func_class_kwargs: dict
            If it is a class function, and the class must be instantiated,
            then pass the kwargs to instantiate the class here. E.g. if you
            passed `ActionKit.bulk_upload_table` as the function, then you
            would pass {'domain': ..., 'username': ... etc} here.
            This must all be JSON-able data.
        catch: bool
            Lambda will retry running an event three times if there's an
            exception -- if you want to prevent this, set `catch=True` and
            then it will catch any errors and stop retries.
            The error will be in CloudWatch logs with string "Distribute Error".
            This might be important if row-actions are not idempotent and your
            own function might fail, causing repeats.
        group_count: int
            Set this to how many rows to process with each Lambda invocation
            (Default: 100).
        storage: str
            Debugging option: Defaults to "s3". To test distribution locally
            without s3, set to "local".
    `Returns:`
        Debug information -- do not rely on the output, as it will change
        depending on how this method is invoked.
    """
    if storage not in ('s3', 'local'):
        raise DistributeTaskException('storage argument must be s3 or local')
    bucket = check('S3_TEMP_BUCKET', bucket)
    csvdata = StringIO()
    outcsv = csv.writer(csvdata)
    outcsv.writerows(table.table.data())
    return distribute_task_csv(csvdata.getvalue().encode('utf-8-sig'),
                               func_to_run,
                               bucket,
                               header=table.columns,
                               func_kwargs=func_kwargs,
                               func_class=func_class,
                               func_class_kwargs=func_class_kwargs,
                               catch=catch,
                               group_count=group_count,
                               storage=storage)

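
# A sketch of the func_class pattern from the docstring above (all values
# here are placeholders, not a confirmed ActionKit call signature):
#   distribute_task(table, ActionKit.bulk_upload_table,
#                   bucket='my-temp-s3-bucket',
#                   func_class=ActionKit,
#                   func_class_kwargs={'domain': ..., 'username': ...})
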
def process_task_portion(bucket, storagekey, rangestart, rangeend, func_name, header,
                         storage='s3', func_kwargs=None, catch=False, func_class_kwargs=None):
    global FAKE_STORAGE
    logger.debug(f'process_task_portion func_name {func_name}, '
                 f'storagekey {storagekey}, byterange {rangestart}-{rangeend}')
    func = import_and_get_task(func_name, func_class_kwargs)
    if storage == 's3':
        filedata = S3Storage().get_range(bucket, storagekey, rangestart, rangeend)
    else:
        filedata = FAKE_STORAGE.get_range(bucket, storagekey, rangestart, rangeend)

    lines = list(csv.reader(TextIOWrapper(BytesIO(filedata), encoding='utf-8-sig')))
    table = Table([header] + lines)
    if catch:
        try:
            func(table, **func_kwargs)
        except Exception:
            # In Lambda you can search for '"Distribute Error"' in the logs
            type_, value_, traceback_ = sys.exc_info()
            err_traceback_str = '\n'.join(traceback.format_exception(type_, value_, traceback_))
            return {'Exception': 'Distribute Error',
                    'error': err_traceback_str,
                    'rangestart': rangestart,
                    'rangeend': rangeend,
                    'func_name': func_name,
                    'bucket': bucket,
                    'storagekey': storagekey}
    else:
        func(table, **func_kwargs)