# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import multiprocessing
import os
import sys
import traceback
from typing import Optional
[docs]
class ErrorFromWorker:
    """Report describing an exception that was raised in a worker process.

    The exception is stored as strings (message and formatted stack trace),
    together with the PID of the process that built the report, since
    strings are easy to send over a pipe.
    """

    def __init__(self, exc_cls, exc, tb):
        """Build the report from the triple describing a raised exception.

        Args:
            exc_cls: class of the exception
            exc: exception instance raised from the worker process
            tb: traceback object associated with the exception
        """
        # PID of the process in which this object is constructed
        self.pid = os.getpid()
        # Short, human readable message
        self.error_message = str(exc)
        # Complete stack trace, flattened into a single string
        formatted_lines = traceback.format_exception(exc_cls, exc, tb)
        self.stacktrace_message = "".join(formatted_lines)

    @property
    def stacktrace(self):
        """Formatted stack trace, prefixed with the PID stored in the report."""
        return "[PID={0.pid}] {0.stacktrace_message}".format(self)

    def __str__(self):
        """Return only the error message, without the stack trace."""
        return self.error_message
[docs]
class Task:
    """Callable wrapper that traps every ``Exception`` raised by a function
    and returns it as an ``ErrorFromWorker`` object instead.

    A wrapper class is used rather than a decorator because the class is
    pickleable, while a decorator with an inner closure is not.
    """

    def __init__(self, func):
        # Function to be invoked when the task is called
        self.func = func

    def __call__(self, *args, **kwargs):
        """Call the wrapped function, forwarding all the arguments.

        Returns:
            The return value of the wrapped function or, if it raised an
            ``Exception``, an ``ErrorFromWorker`` describing the failure.
        """
        try:
            return self.func(*args, **kwargs)
        except Exception:
            # Hand the failure back as a value, so that the caller can
            # inspect it instead of having the exception propagate from here
            return ErrorFromWorker(*sys.exc_info())
[docs]
def imap_unordered(
    f, list_of_args, *, processes: int, maxtaskperchild: Optional[int] = None, debug=False
):
    """Wrapper around multiprocessing.Pool.imap_unordered.

    This is a generator: results are yielded one at a time as they become
    available.

    The work is done serially in the current process, without creating a
    pool, on macOS and Windows, or when there is at most one item to process.
    In that case any exception raised by ``f`` propagates unchanged.

    Args:
        f: function to apply
        list_of_args: sized collection of arguments; each element is passed
            to ``f`` as a single positional argument
        processes: maximum number of processes allowed
        maxtaskperchild: number of tasks to be executed by a child before being
            killed and substituted
        debug: if False, raise an exception containing just the error messages
            from workers, if True an exception with complete stacktraces

    Raises:
        RuntimeError: if any error occurred in the worker processes
    """
    # An empty input takes the serial path too: there is no point in starting
    # a pool of processes when there is nothing to distribute.
    run_serially = sys.platform in ("darwin", "win32") or len(list_of_args) <= 1
    if run_serially:
        yield from map(f, list_of_args)
        return

    with multiprocessing.Pool(processes, maxtasksperchild=maxtaskperchild) as p:
        # Task turns exceptions raised in a worker into ErrorFromWorker values,
        # which are re-raised here as RuntimeError in the calling process
        for result in p.imap_unordered(Task(f), list_of_args):
            if isinstance(result, ErrorFromWorker):
                raise RuntimeError(result.stacktrace if debug else str(result))
            yield result