From 8c03e30549da81b4df3959fcdfaba363aec409b9 Mon Sep 17 00:00:00 2001
From: PinkPanther-ny <alvincny529@gmail.com>
Date: Mon, 21 Mar 2022 17:00:55 +0800
Subject: [PATCH] Refactor framework structure.

---
 src/models/load_model.py     | 39 ++++++++------------
 src/preprocess/preprocess.py |  8 ++--
 src/settings/configs.py      | 71 +++++++++++++++++-------------------
 src/utils/utils.py           | 42 ++++++++++-----------
 4 files changed, 74 insertions(+), 86 deletions(-)

diff --git a/src/models/load_model.py b/src/models/load_model.py
index 22b15d0..680df74 100644
--- a/src/models/load_model.py
+++ b/src/models/load_model.py
@@ -1,41 +1,34 @@
-# OMP_NUM_THREADS=2 python -m torch.distributed.run --nproc_per_node 4 90plus.py
-
-import datetime
-import os
 import torch
-import torch.nn as nn
-import torch.optim as optim
-import gc
-from subprocess import call
-
-import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
+
 from ..settings import configs, model
-from ..utils import Timer, find_best3, eval_total
+from ..utils import find_best_n_model
 
 def ini_model():
-    local_rank = configs.LOCAL_RANK
     global model
     # Load model to gpu
-    device = torch.device("cuda", local_rank)
-    configs.DEVICE = device
     # Check if load specific model or load best model in model folder
     if configs.LOAD_MODEL:
         if configs.LOAD_BEST:
-            configs.MODEL_NAME = find_best3(local_rank)
+            configs.MODEL_NAME = find_best_n_model(configs._LOCAL_RANK)
         try:
-            print(configs.MODEL_DIR + configs.MODEL_NAME)
-            model.load_state_dict(torch.load(configs.MODEL_DIR + configs.MODEL_NAME))
+            model.load_state_dict(torch.load(configs._MODEL_DIR + configs.MODEL_NAME, map_location=configs._DEVICE))
+            configs._LOAD_SUCCESS = True
 
-        except FileNotFoundError or IsADirectoryError:
-            print(f"{configs.MODEL_NAME} Model not found!")
-    
+        except FileNotFoundError:
+            if configs._LOCAL_RANK == 0:
+                print(f"[\"{configs.MODEL_NAME}\"] Model not found! Fall back to untrained model.\n")
+            configs._LOAD_SUCCESS = False
+        except IsADirectoryError:
+            if configs._LOCAL_RANK == 0:
+                print("IsADirectoryError! Fall back to untrained model.\n")
+            configs._LOAD_SUCCESS = False
+
     # Move loaded model with parameters to gpus
     # Then warp with DDP, reducer will be constructed too.
-    model.to(device)
+    model.to(configs._DEVICE)
     if configs.DDP_ON:
-        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+        model = DDP(model, device_ids=[configs._LOCAL_RANK], output_device=configs._LOCAL_RANK)
     
-
     return model
     
\ No newline at end of file
diff --git a/src/preprocess/preprocess.py b/src/preprocess/preprocess.py
index 3a1d29a..7128ef5 100644
--- a/src/preprocess/preprocess.py
+++ b/src/preprocess/preprocess.py
@@ -38,14 +38,14 @@ class Preprocessor:
         if self.loader is not None:
             return self.loader
 
-        data_dir = configs.DATA_DIR
+        data_dir = configs._DATA_DIR
         batch_size = configs.BATCH_SIZE
         n_workers = configs.NUM_WORKERS
 
         train_set = CIFAR10(root=data_dir, train=True,
-                            download=True, transform=self.transform_train)
+                            download=False, transform=self.transform_train)
         test_set = CIFAR10(root=data_dir, train=False,
-                           download=True, transform=self.transform_test)
+                           download=False, transform=self.transform_test)
         if configs.DDP_ON:
             train_sampler = DistributedSampler(train_set)
             train_loader = DataLoader(train_set, batch_size=batch_size,
@@ -78,6 +78,6 @@ class Preprocessor:
             fig.add_subplot(wid, wid, i + 1)
             plt.imshow((np.transpose(loader.dataset[index][0].numpy(), (1, 2, 0))))
             plt.axis('off')
-            plt.title(configs.CLASSES[loader.dataset[index][1]])
+            plt.title(configs._CLASSES[loader.dataset[index][1]])
 
         fig.show()
diff --git a/src/settings/configs.py b/src/settings/configs.py
index 048d36c..6b0e55d 100644
--- a/src/settings/configs.py
+++ b/src/settings/configs.py
@@ -1,52 +1,51 @@
 import os
 import json
 
-os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, [0, 1, 2, 3, 4, 5, 6, 7]))
+import torch
 
 
 class Config:
     def __init__(self, *dict_config) -> None:
         # ==============================================
         # GLOBAL SETTINGS
+        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, [0, 1, 2, 3, 4, 5, 6, 7]))
+        
         self.DDP_ON: bool = True
+        self.MIX_PRECISION: bool = True
 
         self.BATCH_SIZE: int = 512
-        self.LEARNING_RATE: float = 1e-3
-        self.LEARNING_RATE_UPDATE_EPOCH: int = 30
-        self.LEARNING_RATE_UPDATE_RATE: float = 0.12
-        self.LEARNING_RATE_END: float = 1e-5
+        self.LEARNING_RATE: float = 1e-4
         self.TOTAL_EPOCHS: int = 5000
 
-        self.OPT_USE_ADAM: bool = True
-
         self.LOAD_MODEL: bool = True
-        self.MODEL_NAME: str = "10X92.pth"
-        self.LOAD_BEST: bool = False
-        self.EPOCH_TO_LOAD_BEST: int = 15
-
-        self.MODEL_SAVE_THRESHOLD: float = 0
-
-        self.NUM_WORKERS: int = 4
-        self.N_LOGS_PER_EPOCH: int = 0
+        self.MODEL_NAME: str = "92_35.pth"
+        self.LOAD_BEST: bool = True
+        self.N_LOGS_PER_EPOCH: int = 3
 
         # ==============================================
         # SPECIAL SETTINGS
-        self.EPOCHS_PER_EVAL: int = 2
-
-        self.ADAM_SGD_SWITCH: bool = True
-        self.EPOCHS_PER_SWITCH: int = 30
-
+        self.EPOCHS_PER_EVAL: int = 1
+        self.NUM_WORKERS: int = 4
+        self.MODEL_DIR_NAME: str = "/models_v100/"
+        
         # ==============================================
-        # NOT SUPPOSED TO BE CHANGED OFTEN
-
-        self.WORKING_DIR: str = os.path.dirname(os.path.realpath(__file__))
-        self.MODEL_DIR: str = self.WORKING_DIR + "/models_v100/"
-        self.DATA_DIR: str = self.WORKING_DIR + '/data/'
-        self.CLASSES: tuple = ('plane', 'car', 'bird', 'cat', 'deer',
+        # Private
+        self._WORKING_DIR: str = os.path.dirname(os.path.realpath(__file__))
+        self._MODEL_DIR: str = self._WORKING_DIR + self.MODEL_DIR_NAME
+        self._DATA_DIR: str = self._WORKING_DIR + '/data/'
+        self._CLASSES: tuple = ('plane', 'car', 'bird', 'cat', 'deer',
                                'dog', 'frog', 'horse', 'ship', 'truck')
 
-        self.DEVICE = None
-        self.LOCAL_RANK = None
+        self._DEVICE = None
+        self._LOCAL_RANK = None
+        self._LOAD_SUCCESS: bool = False
+        
+        if self.DDP_ON:
+            self._LOCAL_RANK = int(os.environ["LOCAL_RANK"])
+        else:
+            self._LOCAL_RANK = 0
+        
+        self._DEVICE = torch.device("cuda", self._LOCAL_RANK)
         
         if len(dict_config) != 0:
             d = eval(dict_config[0])
@@ -54,21 +53,21 @@ class Config:
                 setattr(self, k, d[k])
 
     def reset_working_dir(self, main_dir):
-        self.WORKING_DIR: str = os.path.dirname(os.path.realpath(main_dir))
-        self.MODEL_DIR: str = self.WORKING_DIR + "/models_v100/"
-        self.DATA_DIR: str = self.WORKING_DIR + '/data/'
+        self._WORKING_DIR: str = os.path.dirname(os.path.realpath(main_dir))
+        self._MODEL_DIR: str = self._WORKING_DIR + self.MODEL_DIR_NAME
+        self._DATA_DIR: str = self._WORKING_DIR + '/data/'
                 
-        if not os.path.exists(self.MODEL_DIR):
-            os.makedirs(self.MODEL_DIR)
+        # exist_ok avoids a FileExistsError race if several processes reach this point together
+        os.makedirs(self._MODEL_DIR, exist_ok=True)
 
     def save(self, fn='/config.json'):
-        with open(self.WORKING_DIR + fn, 'w') as fp:
+        with open(self._WORKING_DIR + fn, 'w') as fp:
             json.dump(str(self.__dict__), fp, indent=4)
 
     def load(self, fn='/config.json'):
         try:
 
-            with open(self.WORKING_DIR + fn, 'r') as fp:
+            with open(self._WORKING_DIR + fn, 'r') as fp:
                 dict_config = json.load(fp)
                 d = eval(dict_config)
                 for k in dict(d):
@@ -79,5 +78,3 @@ class Config:
 
 
 configs = Config()
-# configs.load()
-# configs.save()
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 90d314b..7264179 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -24,9 +24,8 @@ class Timer:
 
 def eval_total(model, testloader, timer, epoch=-1):
     # Only neccessary to evaluate model on one gpu
-    if configs.LOCAL_RANK != 0:
+    if configs._LOCAL_RANK != 0:
         return
-    device = configs.DEVICE
     model.eval()
     correct = 0
     total = 0
@@ -36,51 +35,50 @@ def eval_total(model, testloader, timer, epoch=-1):
         for data in testloader:
             images, labels = data
             # calculate outputs by running images through the network
-            outputs = model(images.to(device))
+            outputs = model(images.to(configs._DEVICE))
             # the class with the highest energy is what we choose as prediction
             _, predicted = torch.max(outputs.cpu().data, 1)
             total += labels.size(0)
             correct += (predicted == labels).sum().item()
             
-    save_model = 100 * correct / total >= configs.MODEL_SAVE_THRESHOLD
-    print(f"{'''''' if epoch==-1 else '''Epoch ''' + str(epoch) + ''': '''}Accuracy of the network on the {total} test images: {100 * correct / float(total)} % ({'saved' if save_model else 'discarded'})")
+    print(f"{'''''' if epoch==-1 else '''Epoch ''' + str(epoch) + ''': '''}Accuracy of the network on the {total} test images: {100 * correct / float(total)} %")
     t = timer.timeit()
-    print(f"Delta time: {t[0]}, Already: {t[1]}\n")
+    print(f"Evaluate delta time: {t[0]}, Already: {t[1]}\n")
     model.train()
-    if save_model:
-        if configs.DDP_ON:
-            torch.save(model.module.state_dict(), configs.MODEL_DIR + f"{100 * correct / total}".replace('.', '_') + '.pth')
-        else:
-            torch.save(model.state_dict(), configs.MODEL_DIR + f"{100 * correct / total}".replace('.', '_') + '.pth')
+    
+    if configs.DDP_ON:
+        torch.save(model.module.state_dict(), configs._MODEL_DIR + f"{100 * correct / total}".replace('.', '_') + '.pth')
+    else:
+        torch.save(model.state_dict(), configs._MODEL_DIR + f"{100 * correct / total}".replace('.', '_') + '.pth')
 
 
-def find_best3(local_rank, rand=False):
-    files = next(walk(configs.MODEL_DIR), (None, None, []))[2]
+def find_best_n_model(local_rank, n=5, rand=False):
+    files = next(walk(configs._MODEL_DIR), (None, None, []))[2]
     if len(files) == 0:
         return ''
     acc = sorted([float(i.split('.')[0].replace('_', '.')) for i in files], reverse=True)
-    best_acc = acc[:3]
+    best_acc = acc[:n]
     
-    for i in acc[3:]:
+    for i in acc[n:]:
         try:
-            os.remove(configs.MODEL_DIR + "/" + str(i).replace('.', '_') + ".pth")
+            os.remove(configs._MODEL_DIR + "/" + str(i).replace('.', '_') + ".pth")
-        except:
+        except OSError:  # best-effort cleanup; skip files that cannot be removed
             continue
             
         
-    model_name = str(best_acc[randrange(3) if (rand and len(acc[:3]) == 3) else 0]).replace('.', '_') + ".pth"
+    model_name = str(best_acc[randrange(n) if (rand and len(acc[:n]) == n) else 0]).replace('.', '_') + ".pth"
     if local_rank == 0:
-        print(f"Loading one of top 3 best model: {model_name}\n")
+        print(f"Loading one of the top {n} best model: {model_name}\n")
     return "/" + model_name
 
 
-def remove_bad_models():
-    files = next(walk(configs.MODEL_DIR), (None, None, []))[2]
+def remove_bad_models(n=5):
+    files = next(walk(configs._MODEL_DIR), (None, None, []))[2]
     if len(files) == 0:
         return
     acc = sorted([float(i.split('.')[0].replace('_', '.')) for i in files], reverse=True)
-    for i in acc[3:]:
+    for i in acc[n:]:
         try:
-            os.remove(configs.MODEL_DIR + "/" + str(i).replace('.', '_') + ".pth")
+            os.remove(configs._MODEL_DIR + "/" + str(i).replace('.', '_') + ".pth")
-        except:
+        except OSError:  # best-effort cleanup; skip files that cannot be removed
             continue
-- 
GitLab