diff --git a/compose/cli/command.py b/compose/cli/command.py
index e1ae690c0..7cea91a2d 100644
--- a/compose/cli/command.py
+++ b/compose/cli/command.py
@@ -10,6 +10,7 @@ import six
 from . import errors
 from . import verbose_proxy
 from .. import config
+from .. import parallel
 from ..config.environment import Environment
 from ..const import API_VERSIONS
 from ..project import Project
@@ -23,6 +24,8 @@ log = logging.getLogger(__name__)
 
 def project_from_options(project_dir, options):
     environment = Environment.from_env_file(project_dir)
+    set_parallel_limit(environment)
+
     host = options.get('--host')
     if host is not None:
         host = host.lstrip('=')
@@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
     )
 
 
+def set_parallel_limit(environment):
+    parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
+    if parallel_limit:
+        try:
+            parallel_limit = int(parallel_limit)
+        except ValueError:
+            raise errors.UserError(
+                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(
+                    environment.get('COMPOSE_PARALLEL_LIMIT')
+                )
+            )
+        if parallel_limit <= 1:
+            raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
+        parallel.GlobalLimit.set_global_limit(parallel_limit)
+
+
 def get_config_from_options(base_dir, options):
     environment = Environment.from_env_file(base_dir)
     config_path = get_config_path_from_options(
diff --git a/compose/const.py b/compose/const.py
index 2ac08b89a..6e5902cad 100644
--- a/compose/const.py
+++ b/compose/const.py
@@ -18,6 +18,7 @@ LABEL_VERSION = 'com.docker.compose.version'
 LABEL_VOLUME = 'com.docker.compose.volume'
 LABEL_CONFIG_HASH = 'com.docker.compose.config-hash'
 NANOCPUS_SCALE = 1000000000
+PARALLEL_LIMIT = 64
 
 SECRETS_PATH = '/run/secrets'
 
diff --git a/compose/parallel.py b/compose/parallel.py
index f271561ff..3c0098c05 100644
--- a/compose/parallel.py
+++ b/compose/parallel.py
@@ -15,6 +15,7 @@ from six.moves.queue import Queue
 from compose.cli.colors import green
 from compose.cli.colors import red
 from compose.cli.signals import ShutdownException
+from compose.const import PARALLEL_LIMIT
 from compose.errors import HealthCheckFailed
 from compose.errors import NoHealthCheckConfigured
 from compose.errors import OperationFailedError
@@ -26,6 +27,20 @@ log = logging.getLogger(__name__)
 STOP = object()
 
 
+class GlobalLimit(object):
+    """Simple class to hold a global semaphore limiter for a project. This class
+    should be treated as a singleton that is instantiated when the project is.
+    """
+
+    global_limiter = Semaphore(PARALLEL_LIMIT)
+
+    @classmethod
+    def set_global_limit(cls, value):
+        if value is None:
+            value = PARALLEL_LIMIT
+        cls.global_limiter = Semaphore(value)
+
+
 def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):
     """Runs func on objects in parallel while ensuring that func is
     ran on object only after it is ran on all its dependencies.
@@ -173,7 +188,7 @@ def producer(obj, func, results, limiter):
     The entry point for a producer thread which runs func on a single object.
     Places a tuple on the results queue once func has either returned or raised.
     """
-    with limiter:
+    with limiter, GlobalLimit.global_limiter:
         try:
             result = func(obj)
             results.put((obj, result, None))
diff --git a/tests/unit/parallel_test.py b/tests/unit/parallel_test.py
index 3a60f01a6..4ebc24d8c 100644
--- a/tests/unit/parallel_test.py
+++ b/tests/unit/parallel_test.py
@@ -1,11 +1,13 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import unittest
 from threading import Lock
 
 import six
 from docker.errors import APIError
 
+from compose.parallel import GlobalLimit
 from compose.parallel import parallel_execute
 from compose.parallel import parallel_execute_iter
 from compose.parallel import ParallelStreamWriter
@@ -31,91 +33,113 @@ def get_deps(obj):
     return [(dep, None) for dep in deps[obj]]
 
 
-def test_parallel_execute():
-    results, errors = parallel_execute(
-        objects=[1, 2, 3, 4, 5],
-        func=lambda x: x * 2,
-        get_name=six.text_type,
-        msg="Doubling",
-    )
+class ParallelTest(unittest.TestCase):
 
-    assert sorted(results) == [2, 4, 6, 8, 10]
-    assert errors == {}
+    def test_parallel_execute(self):
+        results, errors = parallel_execute(
+            objects=[1, 2, 3, 4, 5],
+            func=lambda x: x * 2,
+            get_name=six.text_type,
+            msg="Doubling",
+        )
 
+        assert sorted(results) == [2, 4, 6, 8, 10]
+        assert errors == {}
 
-def test_parallel_execute_with_limit():
-    limit = 1
-    tasks = 20
-    lock = Lock()
+    def test_parallel_execute_with_limit(self):
+        limit = 1
+        tasks = 20
+        lock = Lock()
 
-    def f(obj):
-        locked = lock.acquire(False)
-        # we should always get the lock because we're the only thread running
-        assert locked
-        lock.release()
-        return None
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
 
-    results, errors = parallel_execute(
-        objects=list(range(tasks)),
-        func=f,
-        get_name=six.text_type,
-        msg="Testing",
-        limit=limit,
-    )
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+            limit=limit,
+        )
 
-    assert results == tasks * [None]
-    assert errors == {}
+        assert results == tasks * [None]
+        assert errors == {}
 
+    def test_parallel_execute_with_global_limit(self):
+        GlobalLimit.set_global_limit(1)
+        self.addCleanup(GlobalLimit.set_global_limit, None)
+        tasks = 20
+        lock = Lock()
 
-def test_parallel_execute_with_deps():
-    log = []
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
 
-    def process(x):
-        log.append(x)
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+        )
 
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
+        assert results == tasks * [None]
+        assert errors == {}
 
-    assert sorted(log) == sorted(objects)
+    def test_parallel_execute_with_deps(self):
+        log = []
 
-    assert log.index(data_volume) < log.index(db)
-    assert log.index(db) < log.index(web)
-    assert log.index(cache) < log.index(web)
+        def process(x):
+            log.append(x)
 
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
 
-def test_parallel_execute_with_upstream_errors():
-    log = []
+        assert sorted(log) == sorted(objects)
 
-    def process(x):
-        if x is data_volume:
-            raise APIError(None, None, "Something went wrong")
-        log.append(x)
+        assert log.index(data_volume) < log.index(db)
+        assert log.index(db) < log.index(web)
+        assert log.index(cache) < log.index(web)
 
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
+    def test_parallel_execute_with_upstream_errors(self):
+        log = []
 
-    assert log == [cache]
+        def process(x):
+            if x is data_volume:
+                raise APIError(None, None, "Something went wrong")
+            log.append(x)
 
-    events = [
-        (obj, result, type(exception))
-        for obj, result, exception
-        in parallel_execute_iter(objects, process, get_deps, None)
-    ]
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
 
-    assert (cache, None, type(None)) in events
-    assert (data_volume, None, APIError) in events
-    assert (db, None, UpstreamError) in events
-    assert (web, None, UpstreamError) in events
+        assert log == [cache]
+
+        events = [
+            (obj, result, type(exception))
+            for obj, result, exception
+            in parallel_execute_iter(objects, process, get_deps, None)
+        ]
+
+        assert (cache, None, type(None)) in events
+        assert (data_volume, None, APIError) in events
+        assert (db, None, UpstreamError) in events
+        assert (web, None, UpstreamError) in events
 
 
 def test_parallel_execute_alignment(capsys):