import csv
from io import TextIOWrapper, BytesIO, StringIO
import logging
import sys
import traceback
import time
from parsons.aws.aws_async import (
get_func_task_path,
import_and_get_task,
run as maybe_async_run,
)
from parsons.aws.s3 import S3
from parsons.etl.table import Table
from parsons.utilities.check_env import check
logger = logging.getLogger(__name__)


class DistributeTaskException(Exception):
    pass


class TestStorage:
    """In-memory stand-in for S3Storage, used when storage="local" (testing/debugging)."""

    def __init__(self):
        self.data = {}

    def put_object(self, bucket, key, object_bytes):
        self.data[key] = object_bytes

    def get_range(self, bucket, key, rangestart, rangeend):
        return self.data[key][rangestart:rangeend]


class S3Storage:
"""
These methods are pretty specialized, so we keep them
inside this file rather than s3.py
"""
def __init__(self, use_env_token=True):
self.s3 = S3(use_env_token=use_env_token)
def put_object(self, bucket, key, object_bytes, **kwargs):
return self.s3.client.put_object(Bucket=bucket, Key=key, Body=object_bytes, **kwargs)
def get_range(self, bucket, key, rangestart, rangeend):
"""
Gets an explicit byte-range of an S3 file
"""
        # The HTTP Range header is INCLUSIVE of the end byte, unlike Python slicing:
        # data[2:4] returns 2 bytes, but "Range: bytes=2-4" returns 3 bytes.
        # So we subtract 1 from rangeend to match Python slice semantics.
response = self.s3.client.get_object(
Bucket=bucket, Key=key, Range="bytes={}-{}".format(rangestart, rangeend - 1)
)
return response["Body"].read()
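

# Usage sketch (hypothetical bucket/key names): S3Storage().get_range("my-bucket",
# "path/file.csv", 0, 100) downloads exactly the first 100 bytes of the object
# (sent to S3 as "Range: bytes=0-99").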


FAKE_STORAGE = TestStorage()
S3_TEMP_KEY_PREFIX = "Parsons_DistributeTask"


def distribute_task_csv(
csv_bytes_utf8,
func_to_run,
bucket,
header=None,
func_kwargs=None,
func_class=None,
func_class_kwargs=None,
catch=False,
group_count=100,
storage="s3",
use_s3_env_token=True,
):
"""
The same as distribute_task, but instead of a table, the
first argument is bytes of a csv encoded into utf8.
This function is used by distribute_task() which you should use instead.
"""
global FAKE_STORAGE
func_name = get_func_task_path(func_to_run, func_class)
row_chunks = csv_bytes_utf8.split(b"\n")
cursor = 0
row_ranges = []
# gather start/end bytes for each row
for rowc in row_chunks:
rng = [cursor]
cursor = cursor + len(rowc) + 1 # +1 is the \n character
rng.append(cursor)
row_ranges.append(rng)
# group the rows and get start/end bytes for each group
group_ranges = []
    # the csv writer appends a terminal \r\n, so split() leaves a trailing empty
    # chunk; exclude it from group starts with len - 1
for grpstep in range(0, len(row_ranges) - 1, group_count):
end = min(len(row_ranges) - 1, grpstep + group_count - 1)
group_ranges.append((row_ranges[grpstep][0], row_ranges[end][1]))
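    # Worked example (illustrative only): for csv_bytes_utf8 = b"a,b\r\n1,2\r\n3,4\r\n"
    # and group_count=2:
    #   row_chunks   -> [b"a,b\r", b"1,2\r", b"3,4\r", b""]
    #   row_ranges   -> [[0, 5], [5, 10], [10, 15], [15, 16]]
    #   group_ranges -> [(0, 10), (10, 16)]
    # The last range may extend one byte past the data; the S3 range read simply
    # returns the bytes that exist.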
# upload data
filename = hash(time.time())
storagekey = f"{S3_TEMP_KEY_PREFIX}/{filename}.csv"
groupcount = len(group_ranges)
logger.debug(f"distribute_task_csv storagekey {storagekey} w/ {groupcount} groups")
response = None
if storage == "s3":
response = S3Storage(use_env_token=use_s3_env_token).put_object(
bucket, storagekey, csv_bytes_utf8
)
else:
response = FAKE_STORAGE.put_object(bucket, storagekey, csv_bytes_utf8)
# start processes
results = [
maybe_async_run(
process_task_portion,
[
bucket,
storagekey,
grp[0],
grp[1],
func_name,
header,
storage,
func_kwargs,
catch,
func_class_kwargs,
use_s3_env_token,
],
# if we are using local storage, then it must be run locally, as well
# (good for testing/debugging)
remote_aws_lambda_function_name="FORCE_LOCAL" if storage == "local" else None,
)
for grp in group_ranges
]
return {
"DEBUG_ONLY": "results may vary depending on context/platform",
"results": results,
"put_response": response,
}
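

# Illustrative local test (hypothetical names): with storage="local", the CSV bytes go
# to the in-memory FAKE_STORAGE and each group is processed in the current process,
# so no AWS resources are touched:
#
#   distribute_task_csv(
#       b"1,2\r\n3,4\r\n",
#       my_row_handler,          # must be an importable, module-level function
#       "unused-bucket-name",    # ignored by local storage, but still required
#       header=["a", "b"],
#       func_kwargs={},
#       storage="local",
#   )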


def distribute_task(
table,
func_to_run,
bucket=None,
func_kwargs=None,
func_class=None,
func_class_kwargs=None,
catch=False,
group_count=100,
storage="s3",
use_s3_env_token=True,
):
"""
Distribute processing rows in a table across multiple AWS Lambda invocations.
`Args:`
table: Parsons Table
Table of data you wish to distribute processing across Lambda invocations
of `func_to_run` argument.
func_to_run: function
The function you want to run whose
first argument will be a subset of table
        bucket: str
            The name of the S3 bucket used to upload the table data for processing.
            Not required if you set the environment variable ``S3_TEMP_BUCKET``.
func_kwargs: dict
If the function has other arguments to pass along with `table`
then provide them as a dict here. They must all be JSON-able.
func_class: class
If the function is a classmethod or function on a class,
then pass the pure class here.
E.g. If you passed `ActionKit.bulk_upload_table`,
then you would pass `ActionKit` here.
func_class_kwargs: dict
If it is a class function, and the class must be instantiated,
then pass the kwargs to instantiate the class here.
E.g. If you passed `ActionKit.bulk_upload_table` as the function,
then you would pass {'domain': ..., 'username': ... etc} here.
This must all be JSON-able data.
        catch: bool
            Lambda will retry running an event up to three times if there's an
            exception. To prevent those retries, set `catch=True`: errors will be
            caught and retries stopped, and the error will appear in CloudWatch
            logs with the string "Distribute Error". This matters when row actions
            are not idempotent and a failure in your function would otherwise
            cause rows to be processed again.
group_count: int
Set this to how many rows to process with each Lambda invocation (Default: 100)
storage: str
Debugging option: Defaults to "s3". To test distribution locally without s3,
set to "local".
        use_s3_env_token: bool
            If storage is set to "s3", sets the use_env_token parameter on the S3 storage.
`Returns:`
Debug information -- do not rely on the output, as it will change
depending on how this method is invoked.
"""
if storage not in ("s3", "local"):
raise DistributeTaskException("storage argument must be s3 or local")
bucket = check("S3_TEMP_BUCKET", bucket)
csvdata = StringIO()
outcsv = csv.writer(csvdata)
outcsv.writerows(table.table.data())
return distribute_task_csv(
csvdata.getvalue().encode("utf-8-sig"),
func_to_run,
bucket,
header=table.columns,
func_kwargs=func_kwargs,
func_class=func_class,
func_class_kwargs=func_class_kwargs,
catch=catch,
group_count=group_count,
storage=storage,
use_s3_env_token=use_s3_env_token,
)
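

# Usage sketch (hypothetical names; `process_rows` must be a module-level, importable
# function, and `bucket` or S3_TEMP_BUCKET must point at a writable S3 bucket):
#
#   def process_rows(table, tag=None):
#       for row in table:
#           ...  # act on each row of the subset table
#
#   distribute_task(
#       my_parsons_table,
#       process_rows,
#       bucket="my-temp-bucket",
#       func_kwargs={"tag": "batch-1"},
#       group_count=100,
#   )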


def process_task_portion(
bucket,
storagekey,
rangestart,
rangeend,
func_name,
header,
storage="s3",
func_kwargs=None,
catch=False,
func_class_kwargs=None,
use_s3_env_token=True,
):
    global FAKE_STORAGE
    # func_kwargs may be None (the default); **None below would raise a TypeError
    if func_kwargs is None:
        func_kwargs = {}
logger.debug(
f"process_task_portion func_name {func_name}, "
f"storagekey {storagekey}, byterange {rangestart}-{rangeend}"
)
func = import_and_get_task(func_name, func_class_kwargs)
if storage == "s3":
filedata = S3Storage(use_env_token=use_s3_env_token).get_range(
bucket, storagekey, rangestart, rangeend
)
else:
filedata = FAKE_STORAGE.get_range(bucket, storagekey, rangestart, rangeend)
lines = list(csv.reader(TextIOWrapper(BytesIO(filedata), encoding="utf-8-sig")))
table = Table([header] + lines)
if catch:
try:
func(table, **func_kwargs)
except Exception:
# In Lambda you can search for '"Distribute Error"' in the logs
type_, value_, traceback_ = sys.exc_info()
err_traceback_str = "\n".join(traceback.format_exception(type_, value_, traceback_))
return {
"Exception": "Distribute Error",
"error": err_traceback_str,
"rangestart": rangestart,
"rangeend": rangeend,
"func_name": func_name,
"bucket": bucket,
"storagekey": storagekey,
}
else:
func(table, **func_kwargs)