Source code for parsons.aws.lambda_distribute

import csv
from io import TextIOWrapper, BytesIO, StringIO
import logging
import sys
import traceback
import time

from parsons.aws.aws_async import get_func_task_path, import_and_get_task, run as maybe_async_run
from parsons.aws.s3 import S3
from parsons.etl.table import Table
from parsons.utilities.check_env import check

logger = logging.getLogger(__name__)


class DistributeTaskException(Exception):
    pass


class TestStorage:

    def __init__(self):
        self.data = {}

    def put_object(self, bucket, key, object_bytes):
        self.data[key] = object_bytes

    def get_range(self, bucket, key, rangestart, rangeend):
        return self.data[key][rangestart:rangeend]


class S3Storage:
    """
    These methods are pretty specialized, so we keep them
    inside this file rather than s3.py
    """

    def __init__(self):
        self.s3 = S3()

    def put_object(self, bucket, key, object_bytes, **kwargs):
        return self.s3.client.put_object(Bucket=bucket, Key=key, Body=object_bytes, **kwargs)

    def get_range(self, bucket, key, rangestart, rangeend):
        """
        Gets an explicit byte-range of an S3 file
        """
        # The HTTP Range header is INCLUSIVE of its end byte, unlike Python
        # slicing: data[2:4] returns 2 bytes, but 'Range: bytes=2-4' returns 3.
        # So we subtract 1 from rangeend to keep it exclusive, like Python.
        response = self.s3.client.get_object(
            Bucket=bucket, Key=key,
            Range='bytes={}-{}'.format(rangestart, rangeend - 1))
        return response['Body'].read()
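
# Illustrative sketch (not part of the original module): get_range() keeps
# Python's exclusive-end semantics even though the HTTP Range header itself is
# inclusive. Assuming credentials and a bucket exist (names here are
# hypothetical), the following would hold:
#
#     storage = S3Storage()
#     storage.put_object('my-temp-bucket', 'example-key', b'abcdef')
#     storage.get_range('my-temp-bucket', 'example-key', 2, 4)  # -> b'cd' == b'abcdef'[2:4]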


FAKE_STORAGE = TestStorage()
S3_TEMP_KEY_PREFIX = "Parsons_DistributeTask"


def distribute_task_csv(csv_bytes_utf8, func_to_run, bucket,
                        header=None,
                        func_kwargs=None,
                        func_class=None,
                        func_class_kwargs=None,
                        catch=False,
                        group_count=100,
                        storage='s3'):
    """
    The same as distribute_task(), but the first argument is the raw bytes of
    a UTF-8 encoded CSV rather than a Parsons Table.
    This function is used by distribute_task(), which you should generally call instead.
    """
    global FAKE_STORAGE
    func_name = get_func_task_path(func_to_run, func_class)
    row_chunks = csv_bytes_utf8.split(b'\n')
    cursor = 0
    row_ranges = []
    # gather start/end bytes for each row
    for rowc in row_chunks:
        rng = [cursor]
        cursor = cursor + len(rowc) + 1  # +1 is the \n character
        rng.append(cursor)
        row_ranges.append(rng)

    # group the rows and get start/end bytes for each group
    group_ranges = []
    # the table's csv writer terminates the data with \r\n, so splitting leaves
    # an empty final chunk; stopping at len-1 keeps a group from starting there
    for grpstep in range(0, len(row_ranges) - 1, group_count):
        end = min(len(row_ranges) - 1, grpstep + group_count - 1)
        group_ranges.append((row_ranges[grpstep][0], row_ranges[end][1]))
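    # Worked example (illustrative, not from the original source): for
    # csv_bytes_utf8 = b'1,a\r\n2,b\r\n3,c\r\n' and group_count=2:
    #   row_chunks   = [b'1,a\r', b'2,b\r', b'3,c\r', b'']
    #   row_ranges   = [[0, 5], [5, 10], [10, 15], [15, 16]]
    #   group_ranges = [(0, 10), (10, 16)]
    # so the first invocation reads rows 1-2 and the second reads row 3
    # (the storage layer clamps the final range to the end of the object).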

    # upload data
    filename = hash(time.time())
    storagekey = f"{S3_TEMP_KEY_PREFIX}/{filename}.csv"
    groupcount = len(group_ranges)
    logger.debug(f'distribute_task_csv storagekey {storagekey} w/ {groupcount} groups')

    response = None
    if storage == 's3':
        response = S3Storage().put_object(bucket, storagekey, csv_bytes_utf8)
    else:
        response = FAKE_STORAGE.put_object(bucket, storagekey, csv_bytes_utf8)

    # start processes
    results = [
        maybe_async_run(
            process_task_portion,
            [bucket, storagekey, grp[0], grp[1], func_name, header,
             storage, func_kwargs, catch, func_class_kwargs],
            # if we are using local storage, then it must be run locally, as well
            # (good for testing/debugging)
            remote_aws_lambda_function_name='FORCE_LOCAL' if storage == 'local' else None
        )
        for grp in group_ranges]
    return {'DEBUG_ONLY': 'results may vary depending on context/platform',
            'results': results,
            'put_response': response}


def distribute_task(table, func_to_run,
                    bucket=None,
                    func_kwargs=None,
                    func_class=None,
                    func_class_kwargs=None,
                    catch=False,
                    group_count=100,
                    storage='s3'):
    """
    Distribute processing rows in a table across multiple AWS Lambda invocations.

    `Args:`
        table: Parsons Table
            Table of data you wish to distribute processing across
            Lambda invocations of the `func_to_run` argument.
        func_to_run: function
            The function you want to run, whose first argument
            will be a subset of `table`.
        bucket: str
            The bucket name to use for the S3 upload of the whole table.
            Not required if you set the environment variable ``S3_TEMP_BUCKET``.
        func_kwargs: dict
            If the function has other arguments to pass along with `table`,
            then provide them as a dict here. They must all be JSON-able.
        func_class: class
            If the function is a classmethod or a function on a class,
            then pass the pure class here.
            E.g. if you passed `ActionKit.bulk_upload_table`,
            then you would pass `ActionKit` here.
        func_class_kwargs: dict
            If it is a class function and the class must be instantiated,
            then pass the kwargs to instantiate the class here.
            E.g. if you passed `ActionKit.bulk_upload_table` as the function,
            then you would pass {'domain': ..., 'username': ... etc} here.
            This must all be JSON-able data.
        catch: bool
            Lambda will retry running an event three times if there is an
            exception. If you want to prevent this, set `catch=True` and it
            will catch any errors and stop retries.
            The error will be in CloudWatch logs with the string "Distribute Error".
            This might be important if row-actions are not idempotent
            and your own function might fail, causing repeats.
        group_count: int
            Set this to how many rows to process with each Lambda invocation (Default: 100)
        storage: str
            Debugging option: Defaults to "s3". To test distribution locally
            without s3, set to "local".
    `Returns:`
        Debug information -- do not rely on the output, as it will change
        depending on how this method is invoked.
    """
    if storage not in ('s3', 'local'):
        raise DistributeTaskException('storage argument must be s3 or local')

    bucket = check('S3_TEMP_BUCKET', bucket)
    csvdata = StringIO()
    outcsv = csv.writer(csvdata)
    outcsv.writerows(table.table.data())
    return distribute_task_csv(csvdata.getvalue().encode('utf-8-sig'),
                               func_to_run,
                               bucket,
                               header=table.columns,
                               func_kwargs=func_kwargs,
                               func_class=func_class,
                               func_class_kwargs=func_class_kwargs,
                               catch=catch,
                               group_count=group_count,
                               storage=storage)
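
# Example usage (illustrative only; not part of the original module). Assumes
# `process_chunk` lives in an importable module so each Lambda invocation can
# re-import it by path, and that a temp bucket is available via the `bucket`
# argument or the S3_TEMP_BUCKET environment variable:
#
#     from parsons import Table
#     from parsons.aws import distribute_task
#
#     def process_chunk(chunk):
#         # receives a Parsons Table of up to `group_count` rows
#         for row in chunk:
#             ...
#
#     tbl = Table([{'id': 1, 'email': 'a@example.com'},
#                  {'id': 2, 'email': 'b@example.com'}])
#     distribute_task(tbl, process_chunk, bucket='my-temp-bucket')  # hypothetical bucket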

def process_task_portion(bucket, storagekey, rangestart, rangeend, func_name, header,
                         storage='s3', func_kwargs=None, catch=False,
                         func_class_kwargs=None):
    global FAKE_STORAGE
    logger.debug(f'process_task_portion func_name {func_name}, '
                 f'storagekey {storagekey}, byterange {rangestart}-{rangeend}')
    func = import_and_get_task(func_name, func_class_kwargs)
    # guard against the default of None so `**func_kwargs` below does not fail
    func_kwargs = func_kwargs or {}

    # fetch just this portion's byte range of the uploaded csv
    if storage == 's3':
        filedata = S3Storage().get_range(bucket, storagekey, rangestart, rangeend)
    else:
        filedata = FAKE_STORAGE.get_range(bucket, storagekey, rangestart, rangeend)

    # rebuild a Parsons Table from the header plus this portion's rows
    lines = list(csv.reader(TextIOWrapper(BytesIO(filedata), encoding='utf-8-sig')))
    table = Table([header] + lines)

    if catch:
        try:
            func(table, **func_kwargs)
        except Exception:
            # In Lambda you can search for '"Distribute Error"' in the logs
            type_, value_, traceback_ = sys.exc_info()
            err_traceback_str = '\n'.join(traceback.format_exception(type_, value_, traceback_))
            return {'Exception': 'Distribute Error',
                    'error': err_traceback_str,
                    'rangestart': rangestart,
                    'rangeend': rangeend,
                    'func_name': func_name,
                    'bucket': bucket,
                    'storagekey': storagekey}
    else:
        func(table, **func_kwargs)
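
# Local testing sketch (illustrative only; `my_table` and `my_func` are
# hypothetical): with storage='local', the CSV bytes go to the in-memory
# FAKE_STORAGE and each portion runs in-process instead of on AWS Lambda,
# which is handy for debugging. A bucket name (or S3_TEMP_BUCKET) is still
# required even though no real S3 bucket is touched:
#
#     distribute_task(my_table, my_func, bucket='unused-bucket', storage='local')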