# -*- coding: utf-8 -*-
from __future__ import print_function
import time
import logging
import itertools
import traceback
import multiprocessing
import pickle as pickle
from collections import namedtuple
import six
from six.moves import queue
import cloudpickle
from .util import istask
from . import tracked
from .helpers import apply_sh
logger = logging.getLogger(__name__)
class TaskResult(namedtuple(
        "TaskResult", ["task_no", "error", "dep_keys", "dep_compares"])):
    """The result of a task's execution.

    :param task_no: The number of the task that produced this result.
    :type task_no: int

    :param error: The error generated by executing the task. None if
      the task was successful.
    :type error: str or None

    :param dep_keys: The list of ``.name`` attributes from the objects
      of type :class:`anadama2.tracked.Base` associated with
      this task. These keys are used with the storage backend to save
      successful task results.
    :type dep_keys: list of str

    :param dep_compares: The list of the results of the ``compare()``
      method from the objects of type
      :class:`anadama2.tracked.Base` associated with this
      task. These values are used with the storage backend to save
      successful task results.
    :type dep_compares: list of list of str
    """
    pass
class TaskFailed(Exception):
    """Raised or recorded when executing a task's actions fails.

    Carries the number of the offending task in ``task_no`` so the
    failure can be matched back to the task that produced it.
    """

    def __init__(self, msg, task_no):
        super(TaskFailed, self).__init__(msg)
        self.task_no = task_no
class BaseRunner(object):
    """Common base for all task runners.

    Holds the workflow/run context in ``ctx`` and a ``quit_early`` flag
    that concrete runners consult to stop after the first task failure.
    """

    def __init__(self, run_context):
        self.quit_early = False
        self.ctx = run_context

    def run_tasks(self, task_idx_deque):
        """Execute the tasks named by index in *task_idx_deque*.

        Concrete subclasses must override this method.
        """
        raise NotImplementedError()
class DryRunner(BaseRunner):
    """Runner that only prints what would be done, executing nothing.

    Tasks are listed in ascending task-number order along with their
    dependencies, targets, and actions.
    """

    # maps tracked-dependency classes to human-readable labels
    _typemap = {
        tracked.TrackedString: "String",
        tracked.TrackedVariable: "Variable",
        tracked.TrackedFile: "File",
        tracked.HugeTrackedFile: "Big File",
        tracked.TrackedDirectory: "Directory",
        tracked.TrackedFilePattern: "File Pattern",
        tracked.TrackedExecutable: "Executable",
        tracked.TrackedFunction: "Python Function",
    }
    sublist_template = " - {} ({})"

    def run_tasks(self, task_idx_deque):
        """Print a description of every task in *task_idx_deque*."""
        # index tasks by task number, coercing to int when possible so
        # that the listing sorts numerically
        by_number = {}
        for idx in task_idx_deque:
            task = self.ctx.tasks[idx]
            key = task.task_no
            try:
                key = int(key)
            except (ValueError, TypeError):
                pass
            by_number[key] = task
        # print the tasks from first to last
        for key, task in sorted(by_number.items()):
            six.print_("{} - {}".format(key, task.name))
            six.print_(" Dependencies ({})".format(len(task.depends)))
            for item in task.depends:
                six.print_(self._depformat(item))
            six.print_(" Targets ({})".format(len(task.targets)))
            for item in task.targets:
                six.print_(self._depformat(item))
            six.print_(" Actions ({})".format(len(task.actions)))
            for item in task.actions:
                six.print_(self._actionformat(item))
            six.print_("------------------")

    def _depformat(self, d):
        """Format one dependency or target for display."""
        if istask(d):
            return " - Task {} - {}".format(d.task_no, d.name)
        label = self._typemap.get(type(d), str(type(d)))
        return self.sublist_template.format(d.name, label)

    def _actionformat(self, action):
        """Format one action: a function's name or a shell command."""
        if six.callable(action):
            return self.sublist_template.format(action.__name__, "function")
        return self.sublist_template.format(action, "command")
class SerialLocalRunner(BaseRunner):
    """Executes tasks one at a time in the current process."""

    def run_tasks(self, task_idx_deque):
        """Run each task in *task_idx_deque*, popping from the right.

        Tasks whose parents failed are recorded as failures without
        being run. Stops after the first error when ``quit_early``.
        """
        logger.debug("Running %i tasks locally and serially",
                     len(task_idx_deque))
        while task_idx_deque:
            task_no = task_idx_deque.pop()
            parents = set(self.ctx.dag.predecessors(task_no))
            failed = parents.intersection(self.ctx.failed_tasks)
            if failed:
                # a parent failed; this task can never run
                self.ctx._handle_task_result(
                    parent_failed_result(task_no, next(iter(failed))))
                continue
            self.ctx._handle_task_started(task_no)
            self.ctx._reporter.task_running(task_no)
            self.ctx._reporter.task_command(task_no)
            result = _run_task_locally(self.ctx.tasks[task_no])
            self.ctx._handle_task_result(result)
            if self.quit_early and result.error:
                logger.debug("Quitting early")
                break
def worker_run_loop(work_q, result_q, run_task, reporter=None, lock=None):
    """Main loop for a worker: pull work off *work_q*, run it, push results.

    Each work item is a ``(pkl, extra)`` pair: a pickled task plus extra
    per-task arguments handed to *run_task*. A dict item of the form
    ``{"stop": True}`` is a sentinel that cleanly ends the loop; IOError
    or EOFError from the queue also ends it.

    :param work_q: Queue yielding ``(pkl, extra)`` work items.
    :param result_q: Queue that receives one result per work item.
    :param run_task: Callable ``run_task(task, extra)`` executing the
        deserialized task and returning its result.
    :param reporter: Optional reporter notified before each task runs.
    :param lock: Optional lock serializing reporter calls across workers.
    """
    logger.debug("Starting worker")
    while True:
        try:
            logger.debug("Getting work")
            pkl, extra = work_q.get()
            logger.debug("Got work")
        except IOError as e:
            # BUGFIX: ``e.message`` does not exist on Python 3
            # (BaseException.message was removed), so logging it raised
            # AttributeError; use str(e), which works on 2 and 3.
            logger.debug("Received IOError (%s) errno %s from work_q",
                         str(e), e.errno)
            break
        except EOFError:
            logger.debug("Received EOFError from work_q")
            break
        # a {"stop": True} dict is the shutdown sentinel
        if type(pkl) is dict and pkl.get("stop", False):
            logger.debug("Received sentinel, stopping")
            break
        try:
            logger.debug("Deserializing task")
            task = pickle.loads(pkl)
            logger.debug("Task deserialized")
        except Exception as e:
            # report deserialization failure instead of crashing the worker
            result_q.put_nowait(exception_result(e))
            logger.debug("Failed to deserialize task")
            continue
        logger.debug("Running task %s with %s", task.task_no, run_task)
        if reporter:
            if lock is not None:
                lock.acquire()
            reporter.task_running(task.task_no)
            reporter.task_command(task.task_no)
            if lock is not None:
                lock.release()
        result = run_task(task, extra)
        logger.debug("Finished running task; "
                     "putting results on result_q")
        result_q.put_nowait(result)
        logger.debug("Result put on result_q. Back to get more work.")
def _run_task_locally(task, extra=None):
    """Execute each of *task*'s actions in-process.

    Returns a successful TaskResult (via ``_get_task_result``) when every
    action runs cleanly, or an exception result wrapping a TaskFailed on
    the first action that raises. *extra* is accepted for interface
    compatibility with other run_task callables and is unused here.
    """
    # apply_sh converts plain command strings into callable actions
    actions = apply_sh(task.actions)
    for action_no, act in enumerate(actions):
        logger.debug("Executing task %i action %i", task.task_no, action_no)
        try:
            act(task)
        except Exception:
            msg = ("Error executing action {}. "
                   "Original Exception: \n{}")
            return exception_result(TaskFailed(
                msg.format(action_no, traceback.format_exc()),
                task.task_no))
        logger.debug("Completed executing task %i action %i",
                     task.task_no, action_no)
    return _get_task_result(task)
def _get_task_result(task):
    """Build the TaskResult for a successfully executed *task*.

    Collects each target's ``.name`` and ``compare()`` values so the
    storage backend can later record the task as up to date. If reading
    a target's comparison values raises (e.g. the target was never
    produced), the task is reported as failed instead.

    :param task: The task whose targets are inspected.
    :returns: A :class:`TaskResult` with ``error=None`` on success, or
        an exception result naming the broken target.
    """
    targ_keys, targ_compares = list(), list()
    for target in task.targets:
        targ_keys.append(target.name)
        try:
            targ_compares.append(list(target.compare()))
        except Exception:
            # the original caught (Exception, EnvironmentError), but
            # EnvironmentError is an Exception subclass, so one clause
            # covers both
            msg = "Failed to produce target `{}'. Original exception: {}"
            return exception_result(
                TaskFailed(msg.format(target, traceback.format_exc()),
                           task.task_no)
            )
    return TaskResult(task.task_no, None, targ_keys, targ_compares)
class ParallelLocalWorker(multiprocessing.Process):
    """Subprocess that pulls pickled tasks off a queue and runs them locally."""

    def __init__(self, work_q, result_q, lock, reporter):
        super(ParallelLocalWorker, self).__init__()
        self.reporter = reporter
        self.lock = lock
        self.work_q = work_q
        self.result_q = result_q
        self.logger = logger

    @staticmethod
    def appropriate_q_class(*args, **kwargs):
        """Queue flavor matching this worker type: a multiprocessing queue."""
        return multiprocessing.Queue(*args, **kwargs)

    @staticmethod
    def appropriate_lock():
        """Lock flavor matching this worker type: a multiprocessing lock."""
        return multiprocessing.Lock()

    def run(self):
        """Process entry point: loop over the work queue until stopped."""
        return worker_run_loop(self.work_q, self.result_q,
                               _run_task_locally, self.reporter, self.lock)
class ParallelLocalRunner(BaseRunner):
    """Runs tasks in parallel on the local machine.

    Tasks are serialized with cloudpickle, pushed onto a shared work
    queue, and executed by a pool of :class:`ParallelLocalWorker`
    subprocesses; one result per task comes back on a shared result
    queue.
    """

    def __init__(self, run_context, jobs):
        """
        :param run_context: The workflow/context whose tasks are run.
        :param jobs: Number of worker processes to create.
        :type jobs: int
        """
        super(ParallelLocalRunner, self).__init__(run_context)
        # one shared work queue and one shared result queue for all workers
        self.work_q = multiprocessing.Queue()
        self.result_q = multiprocessing.Queue()
        # lock handed to workers to serialize reporter output
        self.lock = multiprocessing.Lock()
        self.reporter = run_context._reporter
        self.workers = [ ParallelLocalWorker(self.work_q, self.result_q, self.lock, self.reporter)
                         for _ in range(jobs) ]
        # workers are started lazily on the first _fill_work_q() pass
        self.started = False

    def run_tasks(self, task_idx_deque):
        """Execute every task in *task_idx_deque* via the worker pool.

        Alternates between topping up the work queue and blocking for
        one result at a time, handing each result to the context, until
        all tasks are accounted for (or a failure occurs with
        ``quit_early`` set).
        """
        self.task_idx_deque = task_idx_deque
        logger.debug("Running %i tasks in parallel with %i workers locally",
                     len(task_idx_deque), len(self.workers))
        # n_to_do counts tasks not yet resolved (completed or failed)
        self.n_to_do = len(task_idx_deque)
        while True:
            self._fill_work_q()
            logger.debug("Tasks left to do: %s", self.n_to_do)
            if self.n_to_do <= 0:
                logger.debug("No new tasks added to work queues and"
                             " the number of completed tasks equals the"
                             " number of tasks to do. Job's done")
                break
            try:
                # blocks until some worker finishes a task
                result = self.result_q.get()
            except (SystemExit, KeyboardInterrupt):
                logger.info("Terminating due to SystemExit or Ctrl-C")
                self.terminate()
                raise
            except Exception as e:
                logger.error("Terminating due to unhandled exception")
                logger.exception(e)
                self.terminate()
                raise
            else:
                self.n_to_do -= 1
                self.ctx._handle_task_result(result)
                if self.quit_early and result.error:
                    logger.debug("Quitting early.")
                    self.terminate()
                    break
        self.cleanup()

    def _fill_work_q(self):
        """Move every currently-runnable task from the deque onto work_q.

        Tasks with a failed parent are resolved immediately as failures;
        tasks with unfinished parents are pushed back (left) for a later
        pass. Starts the worker processes the first time through.
        """
        logger.debug("Filling work_q")
        for _ in range(len(self.task_idx_deque)):
            idx = self.task_idx_deque.pop()
            parents = set(self.ctx.dag.predecessors(idx))
            failed_parents = parents.intersection(self.ctx.failed_tasks)
            if failed_parents:
                # a parent failed, so this task can never run
                self.ctx._handle_task_result(
                    parent_failed_result(idx, failed_parents.pop()))
                self.n_to_do -= 1
                continue
            elif parents.difference(self.ctx.completed_tasks):
                # has undone parents, come back again later
                self.task_idx_deque.appendleft(idx)
                continue
            try:
                pkl = cloudpickle.dumps(self.ctx.tasks[idx])
            except Exception as e:
                msg = ("Unable to serialize task `{}'. "
                       "Original error was `{}'.")
                raise ValueError(msg.format(self.ctx.tasks[idx], e))
            logger.debug("Adding task %i to work_q", idx)
            self.ctx._handle_task_started(idx)
            self.work_q.put((pkl, None))
            logger.debug("Added task %i to work_q", idx)
        if not self.started:
            logger.debug("Starting up workers")
            for w in self.workers:
                w.start()
            self.started = True

    def terminate(self):
        """Forcefully stop all workers, draining work_q first.

        NOTE(review): reaches into multiprocessing.Queue internals
        (``_rlock``, ``_reader``) to drain unread items before killing
        the workers — presumably so the queue's feeder thread cannot
        block shutdown; confirm against the multiprocessing docs.
        """
        logger.debug("Terminating all workers")
        self.work_q._rlock.acquire()
        logger.debug("got work_q readlock")
        while self.work_q._reader.poll():
            logger.debug("draining work_q")
            try:
                self.work_q._reader.recv()
            except EOFError:
                break
            # yield the CPU between reads
            time.sleep(0)
        for worker in self.workers:
            logger.debug("terminating worker %s", worker)
            worker.terminate()
        for worker in self.workers:
            logger.debug("joining worker %s", worker)
            worker.join()
        logger.debug("releasing readlock")
        self.work_q._rlock.release()
        logger.debug("termination complete")

    def cleanup(self):
        """Gracefully shut down: one stop sentinel per worker, then join."""
        logger.debug("cleaning up parallellocalrunner")
        for w in self.workers:
            logger.debug("giving stop sentinel to worker %s", w)
            self.work_q.put(({"stop": True}, None))
        for w in self.workers:
            logger.debug("joining worker %s", w)
            w.join()
        logger.debug("successfully cleaned up parallellocalrunner")
class GridRunner(BaseRunner):
    """Runs tasks using one or more registered worker types ("the grid").

    Each worker type registered via :meth:`add_worker` gets its own
    (work_q, result_q) pair. Tasks are serialized with cloudpickle and
    placed on the work queue chosen by :meth:`route`; results are
    gathered by polling every result queue in a round-robin cycle.
    """

    def __init__(self, workflow):
        super(GridRunner, self).__init__(workflow)
        # name -> (worker_class, rate), filled by add_worker()
        self._worker_config = dict()
        # name -> (work_q, result_q), filled by _init_workers()
        self._worker_qs = dict()
        self.routes = dict() # task_no -> (worker_type_name, extra_args)
        self.workers = list()
        self.started = False
        # name of the worker type that handles unrouted tasks
        self.default_worker = None

    def add_worker(self, worker_class, name,
                   rate=1, default=False):
        """Register a worker type under *name*.

        :param worker_class: Class to instantiate for this worker type;
            must provide ``appropriate_q_class`` and ``appropriate_lock``
            (see :class:`ParallelLocalWorker`).
        :param name: Key that :meth:`route` uses to select this type.
        :param rate: Stored with the class; used as the number of worker
            instances of this type in ``_init_workers`` and ``cleanup``.
        :param default: If True, tasks with no explicit route go here.
        """
        self._worker_config[name] = (worker_class, rate)
        if default:
            self.default_worker = name

    def run_tasks(self, task_idx_deque):
        """Execute every task in *task_idx_deque* on the grid workers.

        Alternates between topping up the per-type work queues and
        collecting one result at a time, until every task is resolved.
        Honors ``quit_early`` by terminating on the first failure.
        """
        self.task_idx_deque = task_idx_deque
        if not self.workers:
            self._init_workers()
        logger.debug("Running %i tasks in parallel with %i workers"
                     " using the grid",
                     len(task_idx_deque), len(self.workers))
        # n_to_do counts tasks not yet resolved (completed or failed)
        self.n_to_do = len(task_idx_deque)
        while True:
            self._fill_work_qs()
            logger.debug("Tasks left to do: %s", self.n_to_do)
            if self.n_to_do <= 0:
                break
            try:
                result = self._get_result()
            except (SystemExit, KeyboardInterrupt):
                logger.info("Terminating due to SystemExit or Ctrl-C")
                self.terminate()
                raise
            except Exception as e:
                logger.error("Terminating due to unhandled exception")
                logger.exception(e)
                self.terminate()
                raise
            else:
                self.n_to_do -= 1
                self.ctx._handle_task_result(result)
                if self.quit_early and result.error:
                    logger.debug("Quitting early.")
                    self.terminate()
                    break
        self.cleanup()

    def terminate(self):
        """Forcefully stop all workers, draining each work queue first.

        Dispatches on queue flavor by duck typing: an ``_rlock``
        attribute marks a multiprocessing.Queue, a ``mutex`` attribute
        a (thread) queue.Queue.
        """
        for name, (work_q, _) in self._worker_qs.items():
            if hasattr(work_q, "_rlock"):
                self._terminate_mpq(work_q, name)
            elif hasattr(work_q, "mutex"):
                self._terminate_qq(work_q, name)
            else:
                # unknown queue implementation; nothing safe to do
                raise Exception
        for worker in self.workers:
            worker.join()

    def cleanup(self):
        """Gracefully shut down: one stop sentinel per worker, then join."""
        for name, (_, n_procs) in self._worker_config.items():
            for _ in range(n_procs):
                self._worker_qs[name][0].put(({"stop": True}, None))
        for w in self.workers:
            w.join()

    def route(self, task_no):
        """Return ``(worker_type_name, extra_args)`` for *task_no*.

        Falls back to the default worker (with no extra arguments) when
        the task has no explicit route; raises ValueError when neither
        a route nor a default worker exists.
        """
        if task_no in self.routes:
            return self.routes[task_no]
        elif self.default_worker is not None:
            return self.default_worker, None
        else:
            msg = ("GridRunner tried to run task {} but has no "
                   "runner to run it and no default runner is defined")
            raise ValueError(msg.format(task_no))

    def _fill_work_qs(self):
        # Serialize each runnable task and enqueue it on the work queue
        # of the worker type chosen by route(). Starts all workers on
        # the first pass.
        logger.debug("Filling work_qs")
        for _ in range(len(self.task_idx_deque)):
            idx = self._get_next_task()
            if idx is None:
                # task was resolved as failed or deferred for later
                continue
            try:
                pkl = cloudpickle.dumps(self.ctx.tasks[idx])
            except Exception as e:
                msg = ("Unable to serialize task `{}'. "
                       "Original error was `{}'.")
                raise ValueError(msg.format(self.ctx.tasks[idx], e))
            name, extra = self.route(idx)
            logger.debug("Adding task %i to `%s' work_q", idx, name)
            self._worker_qs[name][0].put((pkl, extra))
            self.ctx._handle_task_started(idx)
            logger.debug("Added task %i to `%s' work_q", idx, name)
        if not self.started:
            logger.debug("Starting up workers")
            for w in self.workers:
                logger.debug("Starting worker %s", w)
                w.start()
            self.started = True

    def _init_workers(self):
        # Build queues, a lock, and worker instances for every
        # registered worker type, then prepare the round-robin cycle of
        # result queues used by _get_result().
        threads, procs = list(), list()
        for name, (worker_cls, n_procs) in self._worker_config.items():
            work_q = worker_cls.appropriate_q_class()
            result_q = worker_cls.appropriate_q_class()
            lock = worker_cls.appropriate_lock()
            self._worker_qs[name] = (work_q, result_q)
            isproc = issubclass(worker_cls, multiprocessing.Process)
            l = procs if isproc else threads
            for _ in range(n_procs):
                l.append(worker_cls(work_q, result_q, lock, self.ctx._reporter))
        # processes are placed before threads in the start order
        self.workers = procs+threads # http://stackoverflow.com/a/13115499
        self._qcycle = itertools.cycle(val[1]
                                       for val in self._worker_qs.values())

    def _get_next_task(self):
        # Pop the next task index. Resolve it immediately as failed if
        # any parent failed; defer it (push back on the left) if any
        # parent is not yet complete. Returns None in both cases.
        idx = self.task_idx_deque.pop()
        parents = set(self.ctx.dag.predecessors(idx))
        failed_parents = parents.intersection(self.ctx.failed_tasks)
        if failed_parents:
            self.ctx._handle_task_result(
                parent_failed_result(idx, next(iter(failed_parents)))
            )
            self.n_to_do -= 1
            return None
        elif parents.difference(self.ctx.completed_tasks):
            # has undone parents, come back again later
            self.task_idx_deque.appendleft(idx)
            return None
        if idx is None:
            # defensive: a None index should never be on the deque
            raise Exception
        return idx

    def _get_result(self):
        # Poll each result queue round-robin with non-blocking gets;
        # sleep briefly after a full unproductive sweep to avoid a busy
        # wait.
        while True:
            for _ in range(len(self.workers)):
                try:
                    ret = next(self._qcycle).get(False)
                except queue.Empty:
                    continue
                return ret
            time.sleep(0.05)

    def _terminate_mpq(self, q, name):
        # Drain a multiprocessing.Queue and terminate the workers of
        # this worker type. NOTE(review): relies on Queue internals
        # (_rlock, _reader) — same technique as
        # ParallelLocalRunner.terminate.
        q._rlock.acquire()
        while q._reader.poll():
            try:
                q._reader.recv()
            except EOFError:
                break
            # yield the CPU between reads
            time.sleep(0)
        worker_type = self._worker_config[name][0]
        for worker in self.workers:
            if isinstance(worker, worker_type):
                worker.terminate()
        q._rlock.release()

    def _terminate_qq(self, q, name):
        # Drain a thread-based queue.Queue and join the workers of this
        # worker type. NOTE(review): joins while holding q.mutex, which
        # assumes the workers aren't blocked on this queue — confirm.
        q.mutex.acquire()
        while q.queue:
            q.queue.pop()
        worker_type = self._worker_config[name][0]
        for worker in self.workers:
            if isinstance(worker, worker_type):
                worker.join()
        q.mutex.release()
def default(run_context, jobs):
    """Pick the default local runner for *jobs* parallel slots.

    Two or more jobs get a :class:`ParallelLocalRunner`; anything less
    runs serially in-process.
    """
    if jobs > 1:
        return ParallelLocalRunner(run_context, jobs)
    return SerialLocalRunner(run_context)
# module-level singleton cache for current_grid_runner()
_current_grid_runner = None

def current_grid_runner(context):
    """Return the process-wide GridRunner, creating it on first use."""
    global _current_grid_runner
    if _current_grid_runner is not None:
        return _current_grid_runner
    _current_grid_runner = GridRunner(context)
    return _current_grid_runner
def exception_result(exc):
    """Wrap an exception in a TaskResult that marks the task as failed.

    The task number is read from the exception's ``task_no`` attribute
    when present (e.g. :class:`TaskFailed`); otherwise it is None.
    """
    task_no = getattr(exc, "task_no", None)
    return TaskResult(task_no, str(exc), None, None)
def parent_failed_result(idx, parent_idx):
    """Build a failure TaskResult for task *idx* whose parent failed."""
    reason = "Task failed because parent task `{}' failed".format(parent_idx)
    return TaskResult(idx, reason, None, None)