|
|
@@ -0,0 +1,85 @@ |
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. |
|
|
|
# All rights reserved. |
|
|
|
# |
|
|
|
# This source code is licensed under the license found in the |
|
|
|
# LICENSE file in the root directory of this source tree. |
|
|
|
|
|
|
|
# Author: Alexandre Defossez (adefossez) |
|
|
|
|
|
|
|
""" |
|
|
|
Start multiple processes locally for DDP.
|
|
|
""" |
|
|
|
|
|
|
|
import logging |
|
|
|
import subprocess as sp |
|
|
|
import sys |
|
|
|
|
|
|
|
from hydra import utils |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class ChildrenManager:
    """Context manager supervising a group of worker subprocesses.

    Children are registered with `add` inside the `with` body. When the
    body finishes, `__exit__` polls the children until every one of them
    has exited, or until something goes wrong: the body raised, a child
    returned a non-zero exit code, or a keyboard interrupt was received.
    Children still tracked at that point are terminated and reaped.

    Attributes:
        children: children that have not been observed to exit on their own.
        failed: True once any failure condition has been seen.
        grace_period: seconds a terminated child is given to exit before
            it is killed. `None` waits indefinitely.
    """

    def __init__(self, grace_period=5.0):
        self.children = []
        self.failed = False
        self.grace_period = grace_period

    def add(self, child):
        """Track `child` and store its rank on it as `child.rank`.

        The rank is the number of children tracked before this one.
        `child` must provide `wait(timeout)`, `terminate()` and `kill()`,
        as `subprocess.Popen` does.
        """
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the children, then stop whichever are still tracked.

        Returns None, so an exception raised in the `with` body is logged
        here and then propagates to the caller.
        """
        if exc_value is not None:
            logger.error(
                "An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            while self.children and not self.failed:
                # Iterate over a copy: exited children are removed in the loop.
                for child in list(self.children):
                    try:
                        # Short timeout so every child is polled in turn.
                        exitcode = child.wait(0.1)
                    except sp.TimeoutExpired:
                        continue
                    self.children.remove(child)
                    if exitcode:
                        logger.error(
                            "Worker %s died, killing all workers", child.rank)
                        self.failed = True
        except KeyboardInterrupt:
            logger.error(
                "Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._stop_remaining()
        if not self.failed:
            logger.info("All workers completed successfully")

    def _stop_remaining(self):
        """Terminate every child still tracked and reap it."""
        # Signal all children first so they shut down concurrently.
        for child in self.children:
            child.terminate()
        for child in self.children:
            try:
                child.wait(self.grace_period)
            except sp.TimeoutExpired:
                # The child did not exit after terminate(): force it.
                child.kill()
                child.wait()
|
|
|
|
|
|
|
|
|
|
|
def start_ddp_workers():
    """Launch one worker process per CUDA device, then exit.

    Each worker re-runs the current command line (`sys.argv`) under the
    same interpreter, from Hydra's original working directory, with
    `world_size=...` and `rank=...` appended. Rank 0 keeps the inherited
    standard streams. Every other rank has stdin/stdout/stderr sent to
    DEVNULL and its Hydra file-log name suffixed with `.<rank>`.

    This function never returns: it calls `sys.exit(1)` when no CUDA
    device is found, otherwise `sys.exit(int(manager.failed))` once the
    manager has finished waiting for the workers.
    """
    # Imported here rather than at module level, so importing this module
    # does not import torch.
    import torch as th

    world_size = th.cuda.device_count()
    if world_size == 0:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)

    logger.info(f"Starting {world_size} worker processes for DDP.")
    with ChildrenManager() as manager:
        for rank in range(world_size):
            command = [
                sys.executable,
                *sys.argv,
                f"world_size={world_size}",
                f"rank={rank}",
            ]
            popen_options = {}
            if rank > 0:
                # Silence the non-primary workers' standard streams.
                popen_options = dict.fromkeys(
                    ('stdin', 'stdout', 'stderr'), sp.DEVNULL)
                log_path = utils.HydraConfig().cfg.hydra.job_logging.handlers.file.filename
                log_path = log_path + f".{rank}"
                command.append(
                    "hydra.job_logging.handlers.file.filename=" + log_path)
            worker = sp.Popen(
                command, cwd=utils.get_original_cwd(), **popen_options)
            manager.add(worker)
    sys.exit(int(manager.failed))