From d651412ba4c5f762477b5993debae07c6340d1ba Mon Sep 17 00:00:00 2001
From: xuetaowave <cxt@cxt-win>
Date: Sun, 18 Feb 2024 20:08:41 +0800
Subject: [PATCH] Pass an explicit 10-second timeout to dist.init_process_group

Both entry points previously relied on the default process-group timeout
of timedelta(minutes=30), so a bad rendezvous (wrong MASTER_ADDR or
MASTER_PORT, or a rank that never starts) hung for half an hour before
failing. Cap it at 10 seconds so setup errors surface quickly.

---
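Notes (this section sits between the "---" and the diffstat, so git am
drops it): a minimal standalone sketch of the failure mode this guards
against. The address, port and world size below are illustrative, not
taken from the repo:

    import os
    from datetime import timedelta

    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '19500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '2')  # rank 1 never starts

    # With the default timeout this call blocks for 30 minutes waiting
    # for the missing rank; with the explicit timeout it raises after
    # roughly 10 seconds instead.
    dist.init_process_group(backend='gloo', timeout=timedelta(seconds=10))
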
 multi_gpu.py | 3 ++-
 train.py     | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/multi_gpu.py b/multi_gpu.py
index 54aabe9..49c8625 100644
--- a/multi_gpu.py
+++ b/multi_gpu.py
@@ -1,4 +1,5 @@
 import os
+from datetime import timedelta
 
 os.environ['MASTER_ADDR'] = '0'
 os.environ['RANK'] = '0'
@@ -9,5 +10,5 @@ os.environ['WANDB_START_METHOD'] = 'thread'
 os.environ['MASTER_PORT'] = '19500'
 
 import torch.distributed as dist
-dist.init_process_group()
+dist.init_process_group(timeout=timedelta(seconds=10))
 pass
\ No newline at end of file
diff --git a/train.py b/train.py
index 67f8dbd..d3fb74d 100644
--- a/train.py
+++ b/train.py
@@ -46,6 +46,7 @@
 
 import os
 import time
+from datetime import timedelta
 import numpy as np
 import argparse
 import h5py
@@ -558,7 +559,7 @@ if __name__ == '__main__':
   local_rank = 0
   if params['world_size'] > 1:
     dist.init_process_group(backend='nccl',
-                            init_method='env://')
+                            init_method='env://', timeout=timedelta(seconds=10))
     local_rank = int(os.environ["LOCAL_RANK"])
     args.gpu = local_rank
     world_rank = dist.get_rank()
-- 
GitLab
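
Note on the train.py change: with backend='nccl' the timeout passed to
init_process_group also becomes the default timeout for collectives on
the group, so 10 seconds may be tight for real training steps, and
depending on the PyTorch version its enforcement relies on blocking
wait (NCCL_BLOCKING_WAIT, later TORCH_NCCL_BLOCKING_WAIT) or the
async-error-handling watchdog. An illustrative launch for the
world_size > 1 path (script flags omitted):

    torchrun --nproc_per_node=2 train.py <script args>

torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
which init_method='env://' reads.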