Skip to content
Snippets Groups Projects
Commit 09e8aa1c authored by nuoyanc's avatar nuoyanc
Browse files

Half way there, and reload gitignore.

parent 76d7a906
Branches
No related tags found
No related merge requests found
Showing
with 862 additions and 0 deletions
data/
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
\ No newline at end of file
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
<component name="FacetManager">
<facet type="Python" name="Python">
<configuration sdkName="" />
</facet>
</component>
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyNamespacePackagesService">
<option name="namespacePackageFolders">
<list>
<option value="$MODULE_DIR$/src" />
</list>
</option>
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="E722" />
</list>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" project-jdk-name="Python 3.9 (base)" project-jdk-type="Python SDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
<component name="UnattendedHostPersistenceState">
<option name="openedFilesInfos">
<list>
<OpenedFileInfo>
<option name="caretOffset" value="2427" />
<option name="fileUrl" value="file://$PROJECT_DIR$/src/preprocess/preprocess.py" />
</OpenedFileInfo>
<OpenedFileInfo>
<option name="caretOffset" value="5280" />
<option name="fileUrl" value="file://$PROJECT_DIR$/main.py" />
</OpenedFileInfo>
<OpenedFileInfo>
<option name="caretOffset" value="2434" />
<option name="fileUrl" value="file://$PROJECT_DIR$/src/settings/configs.py" />
</OpenedFileInfo>
</list>
</option>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/bmlf.iml" filepath="$PROJECT_DIR$/.idea/bmlf.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>
\ No newline at end of file
{
"python.analysis.extraPaths": [
"./src"
]
}
\ No newline at end of file
main.py 0 → 100644
from src.models import ini_model
from src.preprocess import Preprocessor
from src.settings import configs
from src.utils import Timer, find_best3, eval_total
# Launch with: OMP_NUM_THREADS=2 python -m torch.distributed.run --nproc_per_node 4 main.py
import datetime
import os
import torch
import torch.nn as nn
import torch.optim as optim
import gc
from subprocess import call
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def train():
    """Run the full training loop for the model returned by ``ini_model``.

    Steps, in order:
      1. When ``configs.DDP_ON`` is set, read ``LOCAL_RANK`` from the
         environment, bind this process to that GPU and join the NCCL
         process group; otherwise use rank 0.
      2. Build the model and the train/test loaders, then start the timer.
      3. When ``configs.LOAD_MODEL`` is set, evaluate the loaded weights once.
      4. Train for ``configs.TOTAL_EPOCHS`` epochs under autocast with a
         gradient scaler, decaying the learning rate, optionally alternating
         between Adam and SGD, and evaluating every
         ``configs.EPOCHS_PER_EVAL`` epochs.

    Settings are read from ``configs``; ``configs.LOCAL_RANK`` and
    ``configs.LEARNING_RATE`` are written. Returns ``None``.
    """
    # Not part of the module-level imports, so bring it in locally.
    from src.utils import remove_bad_models

    if configs.DDP_ON:
        # DDP backend initialization
        configs.LOCAL_RANK = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(configs.LOCAL_RANK)
        dist.init_process_group(backend='nccl')
    else:
        configs.LOCAL_RANK = 0

    model = ini_model()
    trainloader, testloader = Preprocessor().get_loader()
    # ini_model() stores the target device on configs.
    device = configs.DEVICE
    local_rank = configs.LOCAL_RANK

    # Start timer from here
    timer = Timer()
    timer.timeit()

    if configs.LOAD_MODEL and local_rank == 0:
        print(f"\nVerifying loaded model ({configs.MODEL_NAME})'s accuracy as its name suggested...")
        eval_total(model, testloader, timer)
    if local_rank == 0:
        print(f"Start training! Total {configs.TOTAL_EPOCHS} epochs.\n")

    # Define loss function and optimizers for the following training process
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=configs.LEARNING_RATE)
    sgd = optim.SGD(model.parameters(), lr=configs.LEARNING_RATE, momentum=0.90)
    # Indexed with int(opt_use_adam): 0 -> SGD, 1 -> Adam.
    opts = [sgd, adam]
    opt_use_adam = configs.OPT_USE_ADAM

    # Mixed precision for speed up
    # https://zhuanlan.zhihu.com/p/165152789
    scaler = torch.cuda.amp.GradScaler()

    # ========================== Train =============================
    for epoch in range(configs.TOTAL_EPOCHS):
        if epoch % configs.LEARNING_RATE_UPDATE_EPOCH == configs.LEARNING_RATE_UPDATE_EPOCH - 1:
            # Decay the learning rate, but never below the configured floor.
            configs.LEARNING_RATE = max(
                configs.LEARNING_RATE * configs.LEARNING_RATE_UPDATE_RATE,
                configs.LEARNING_RATE_END,
            )
            print(f"Learning rate updated to {configs.LEARNING_RATE}\n")
            # Update the optimizers that are actually used, in place, so the
            # new rate takes effect and their accumulated state is kept.
            for opt in opts:
                for group in opt.param_groups:
                    group['lr'] = configs.LEARNING_RATE

        # To avoid duplicated data sent to multi-gpu; only the distributed
        # sampler created for DDP runs has set_epoch().
        if configs.DDP_ON:
            trainloader.sampler.set_epoch(epoch)

        # Just for removing worst models
        if epoch % configs.EPOCH_TO_LOAD_BEST == 0:
            remove_bad_models()

        # Choose the optimizer dynamically according to the switching strategy
        optimizer = opts[int(opt_use_adam)]

        # Number of batches between two progress logs (0 disables logging)
        count_log = 0 if configs.N_LOGS_PER_EPOCH == 0 else int(len(trainloader) / configs.N_LOGS_PER_EPOCH)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # Speed up with half precision: forward pass under autocast
            with torch.cuda.amp.autocast():
                outputs = model(inputs.to(device))
                loss = criterion(outputs, labels.to(device))

            # Scale the loss, backpropagate, then step through the scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # print statistics
            running_loss += loss.item() * inputs.shape[0]
            if count_log != 0 and local_rank == 0 and i % count_log == count_log - 1:
                print(f'[{epoch + 1}(Epochs), {i + 1:5d}(batches)] loss: {running_loss / count_log:.3f}')
                running_loss = 0.0

        # Switch to another optimizer after some epochs
        if configs.ADAM_SGD_SWITCH:
            if epoch % configs.EPOCHS_PER_SWITCH == configs.EPOCHS_PER_SWITCH - 1:
                opt_use_adam = not opt_use_adam
                print(f"Epoch {epoch + 1}: Opt switched to {'Adam' if opt_use_adam else 'SGD'}")

        # Evaluate model on main GPU after some epochs
        if local_rank == 0 and epoch % configs.EPOCHS_PER_EVAL == configs.EPOCHS_PER_EVAL - 1:
            eval_total(model, testloader, timer, epoch)

    # timeit() returns (time since last reading, total time); report the total.
    print(f'Training Finished! ({timer.timeit()[1]})')
if __name__ == '__main__':
    # Script entry point: point the config at this file's directory, release
    # cached CUDA memory, then train. Ctrl-C at any stage ends the run with a
    # short message instead of a traceback.
    try:
        configs.reset_working_dir(__file__)
        torch.cuda.empty_cache()
        train()
    except KeyboardInterrupt:
        print("Exit!")
from .load_model import ini_model
\ No newline at end of file
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
class Net(nn.Module):
    """Residual convolutional network producing 10 class scores per image.

    Structure: a 3x3 stem convolution (3 -> 64 channels) with batch norm,
    residual blocks at 64, 128, 256 and 512 channels, then adaptive average
    pooling to 1x1, flatten and a 512 -> 10 linear layer. Each widening step
    halves the spatial size (stride 2) and uses a 1x1 strided convolution on
    the shortcut path.

    NOTE(review): the two 64-channel blocks both run ``seq1_r``, so they share
    one set of weights - confirm this is intended.
    NOTE(review): every BatchNorm2d uses momentum=0.9; in PyTorch that is the
    weight given to the newest batch statistic - confirm this is intended.
    """

    @staticmethod
    def _batch_norm(channels):
        """Batch norm with the eps/momentum used throughout this network."""
        return nn.BatchNorm2d(num_features=channels, eps=1e-6, momentum=0.9)

    @staticmethod
    def _residual_branch(in_channels, increase_channels=False):
        """Conv-BN-ReLU-Conv-BN branch; doubles channels and strides by 2 when widening."""
        out_channels = in_channels * 2 if increase_channels else in_channels
        first_stride = 2 if increase_channels else 1
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=first_stride, padding=1, dilation=1),
            Net._batch_norm(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1, dilation=1),
            Net._batch_norm(out_channels),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _shortcut_branch(in_channels):
        """1x1 stride-2 convolution + BN that matches a widening branch's output shape."""
        doubled = in_channels * 2
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=doubled,
                      kernel_size=1, stride=2, padding=0, dilation=1),
            Net._batch_norm(doubled),
        )

    def __init__(self):
        super().__init__()
        self.activation_func = F.leaky_relu

        # Stem
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=1)
        self.bn1 = self._batch_norm(64)

        # 64-channel stage
        self.seq1_r = self._residual_branch(64)

        # 64 -> 128
        self.seq2_l = self._shortcut_branch(64)
        self.seq2_r = self._residual_branch(64, increase_channels=True)
        self.seq3_r = self._residual_branch(128)

        # 128 -> 256
        self.seq4_l = self._shortcut_branch(128)
        self.seq4_r = self._residual_branch(128, increase_channels=True)
        self.seq5_r = self._residual_branch(256)

        # 256 -> 512
        self.seq6_l = self._shortcut_branch(256)
        self.seq6_r = self._residual_branch(256, increase_channels=True)
        self.seq7_r = self._residual_branch(512)

        # Classifier head
        self.seq_end = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return the 10 class scores for a batch of 3-channel images."""
        act = self.activation_func

        # Before the first block
        x = act(self.bn1(self.conv1(x)))

        # Blocks 1_1 and 1_2: identity shortcut, same branch module both times
        for _ in range(2):
            x = act(x + self.seq1_r(x))

        # Each stage: one widening block (projected shortcut), then one
        # identity-shortcut block at the new width.
        stages = (
            (self.seq2_l, self.seq2_r, self.seq3_r),
            (self.seq4_l, self.seq4_r, self.seq5_r),
            (self.seq6_l, self.seq6_r, self.seq7_r),
        )
        for project, widen, refine in stages:
            x = act(project(x) + widen(x))
            x = act(x + refine(x))

        return self.seq_end(x)
\ No newline at end of file
# OMP_NUM_THREADS=2 python -m torch.distributed.run --nproc_per_node 4 90plus.py
import datetime
import os
import torch
import torch.nn as nn
import torch.optim as optim
import gc
from subprocess import call
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from ..settings import configs, model
from ..utils import Timer, find_best3, eval_total
def ini_model():
    """Prepare the shared module-level ``model`` for this process and return it.

    Reads ``configs.LOCAL_RANK`` to pick the CUDA device and stores that device
    in ``configs.DEVICE``. When ``configs.LOAD_MODEL`` is set, loads weights
    from ``configs.MODEL_DIR + configs.MODEL_NAME`` (the name is first replaced
    by ``find_best3``'s choice when ``configs.LOAD_BEST`` is set). A missing
    checkpoint is reported and the model keeps its current weights. The model
    is then moved to the device and, when ``configs.DDP_ON`` is set, wrapped in
    ``DistributedDataParallel``.

    Returns:
        The prepared model (DDP-wrapped when ``configs.DDP_ON`` is set). The
        module-level ``model`` name is rebound to the same object.
    """
    local_rank = configs.LOCAL_RANK
    global model

    # Load model to gpu
    device = torch.device("cuda", local_rank)
    configs.DEVICE = device

    # Check if load specific model or load best model in model folder
    if configs.LOAD_MODEL:
        if configs.LOAD_BEST:
            configs.MODEL_NAME = find_best3(local_rank)
        # Plain concatenation on purpose: find_best3 returns a name starting
        # with "/", which os.path.join would treat as an absolute path.
        checkpoint_path = configs.MODEL_DIR + configs.MODEL_NAME
        try:
            print(checkpoint_path)
            # Deserialize onto the CPU so every rank does not first restore the
            # tensors onto the device they were saved from.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        except (FileNotFoundError, IsADirectoryError):
            # IsADirectoryError happens when MODEL_NAME is empty and the path
            # is just the model directory.
            print(f"{configs.MODEL_NAME} Model not found!")
        else:
            model.load_state_dict(state_dict)

    # Move loaded model with parameters to gpus
    # Then wrap with DDP, reducer will be constructed too.
    model.to(device)
    if configs.DDP_ON:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    return model
\ No newline at end of file
from .preprocess import Preprocessor
\ No newline at end of file
import random
from typing import Tuple
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import CIFAR10
from ..settings.configs import configs
import math
class Preprocessor:
    """Build CIFAR-10 train/test ``DataLoader`` objects and preview their samples.

    The transforms passed to the constructor are the ones applied to the
    datasets; the class-level ``transform_train`` / ``transform_test`` are only
    the defaults.
    """

    # Default training augmentation: random 32x32 crop after 4-pixel constant
    # padding, random horizontal flip, tensor conversion, normalisation.
    # NOTE(review): the mean/std values look like the ImageNet statistics
    # rather than CIFAR-10's - confirm they are the intended ones.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode="constant"),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Default evaluation transform: tensor conversion and the same normalisation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    def __init__(self,
                 trans_train=transform_train,
                 trans_test=transform_test) -> None:
        """Store the transforms to use; loaders are built lazily by ``get_loader``."""
        self.trans_train = trans_train
        self.trans_test = trans_test
        self.loader = None

    def get_loader(self) -> Tuple[DataLoader, DataLoader]:
        """Return ``(train_loader, test_loader)``, building them on first use.

        Datasets are downloaded into ``configs.DATA_DIR`` if missing. With
        ``configs.DDP_ON`` the training loader uses a ``DistributedSampler``;
        otherwise it shuffles. The test loader always iterates the whole test
        set in order. The pair is cached on the instance, so repeated calls
        return the same loaders.
        """
        if self.loader is not None:
            return self.loader

        data_dir = configs.DATA_DIR
        batch_size = configs.BATCH_SIZE
        n_workers = configs.NUM_WORKERS

        # Use the transforms given to the constructor, not the class defaults.
        train_set = CIFAR10(root=data_dir, train=True,
                            download=True, transform=self.trans_train)
        test_set = CIFAR10(root=data_dir, train=False,
                           download=True, transform=self.trans_test)

        if configs.DDP_ON:
            # The sampler does the shuffling and per-rank partitioning.
            train_sampler = DistributedSampler(train_set)
            train_loader = DataLoader(train_set, batch_size=batch_size,
                                      sampler=train_sampler, num_workers=n_workers)
        else:
            train_loader = DataLoader(train_set, batch_size=batch_size,
                                      shuffle=True, num_workers=n_workers)

        # Test with whole test set, no need for distributed sampler
        test_loader = DataLoader(test_set, batch_size=batch_size,
                                 shuffle=False, num_workers=n_workers)

        # Two iterables yielding batches of at most batch_size samples
        self.loader = (train_loader, test_loader)
        return self.loader

    def visualize_data(self, n=9, train=True, rand=True) -> None:
        """Show ``n`` samples in a square grid, each titled with its class name.

        Args:
            n: Number of images to draw.
            train: Draw from the training set when true, else the test set.
            rand: Pick random indices when true, else the first ``n`` samples.
        """
        dataset = self.get_loader()[int(not train)].dataset

        # Smallest grid width whose square holds n images.
        wid = math.isqrt(n)
        if wid * wid < n:
            wid += 1

        fig = plt.figure(figsize=(2 * wid, 2 * wid))
        for i in range(n):
            index = random.randint(0, len(dataset) - 1) if rand else i
            # Fetch once so the image and its label come from the same read.
            image, label = dataset[index]

            # Add subplot to corresponding position
            fig.add_subplot(wid, wid, i + 1)
            # Tensor is channel-first; imshow expects channel-last.
            plt.imshow(np.transpose(image.numpy(), (1, 2, 0)))
            plt.axis('off')
            plt.title(configs.CLASSES[label])
        fig.show()
from .configs import configs
from .model_initialize import model
import os
import json
# Make device indices 0-7 visible to CUDA for this process.
# NOTE(review): this replaces any CUDA_VISIBLE_DEVICES value already set in
# the environment - confirm that overriding the caller's choice is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in range(8))
class Config:
    """Project settings with optional persistence to a JSON file.

    All settings are plain instance attributes with defaults assigned in
    ``__init__``. ``save`` writes them as a JSON object; ``load`` reads that
    format and also the older format in which the file held a single JSON
    string containing a Python dict literal.
    """

    # Assigned per process at runtime (by the training entry point and the
    # model loader); they are not written by save().
    _RUNTIME_ONLY = ("DEVICE", "LOCAL_RANK")

    def __init__(self, *dict_config) -> None:
        """Set defaults, then apply overrides from the first positional argument.

        Args:
            *dict_config: Optionally one override set, either a mapping or a
                string holding a Python dict literal. Extra arguments are
                ignored.

        Raises:
            ValueError, SyntaxError: If a string override is not a literal.
        """
        # ==============================================
        # GLOBAL SETTINGS
        self.DDP_ON: bool = True
        self.BATCH_SIZE: int = 512
        self.LEARNING_RATE: float = 1e-3
        self.LEARNING_RATE_UPDATE_EPOCH: int = 30
        self.LEARNING_RATE_UPDATE_RATE: float = 0.12
        self.LEARNING_RATE_END: float = 1e-5
        self.TOTAL_EPOCHS: int = 5000
        self.OPT_USE_ADAM: bool = True
        self.LOAD_MODEL: bool = True
        self.MODEL_NAME: str = "10X92.pth"
        self.LOAD_BEST: bool = False
        self.EPOCH_TO_LOAD_BEST: int = 15
        self.MODEL_SAVE_THRESHOLD: float = 0
        self.NUM_WORKERS: int = 4
        self.N_LOGS_PER_EPOCH: int = 0
        # ==============================================
        # SPECIAL SETTINGS
        self.EPOCHS_PER_EVAL: int = 2
        self.ADAM_SGD_SWITCH: bool = True
        self.EPOCHS_PER_SWITCH: int = 30
        # ==============================================
        # NOT SUPPOSED TO BE CHANGED OFTEN
        self.WORKING_DIR: str = os.path.dirname(os.path.realpath(__file__))
        self.MODEL_DIR: str = self.WORKING_DIR + "/models_v100/"
        self.DATA_DIR: str = self.WORKING_DIR + '/data/'
        self.CLASSES: tuple = ('plane', 'car', 'bird', 'cat', 'deer',
                               'dog', 'frog', 'horse', 'ship', 'truck')
        self.DEVICE = None
        self.LOCAL_RANK = None

        if dict_config:
            self._apply(self._parse_overrides(dict_config[0]))

    @staticmethod
    def _parse_overrides(raw):
        """Return overrides as a dict, parsing a dict-literal string if given.

        Strings go through ``ast.literal_eval`` rather than ``eval``: it accepts
        only Python literals, so text read from a config file is never executed.
        """
        import ast

        if isinstance(raw, str):
            raw = ast.literal_eval(raw)
        return dict(raw)

    def _apply(self, overrides) -> None:
        """Assign each override as an attribute, keeping tuple-typed settings tuples."""
        for key, value in overrides.items():
            # JSON has no tuple type; restore it where the current value is one.
            if isinstance(getattr(self, key, None), tuple) and isinstance(value, list):
                value = tuple(value)
            setattr(self, key, value)

    def reset_working_dir(self, main_dir):
        """Re-root the working, model and data directories at ``main_dir``'s folder.

        Args:
            main_dir: Path of a file (typically the entry script) whose
                directory becomes ``WORKING_DIR``.

        The model directory is created if it does not exist.
        """
        self.WORKING_DIR: str = os.path.dirname(os.path.realpath(main_dir))
        self.MODEL_DIR: str = self.WORKING_DIR + "/models_v100/"
        self.DATA_DIR: str = self.WORKING_DIR + '/data/'
        # exist_ok avoids a failure when several processes start together and
        # another one creates the directory between a check and the call.
        os.makedirs(self.MODEL_DIR, exist_ok=True)

    def save(self, fn='/config.json'):
        """Write the settings as a JSON object to ``WORKING_DIR + fn``.

        Raises:
            TypeError: If a setting holds a value JSON cannot represent.
            OSError: If the file cannot be written.
        """
        persisted = {key: value for key, value in self.__dict__.items()
                     if key not in self._RUNTIME_ONLY}
        # Serialize before opening so a failure does not leave a truncated file.
        text = json.dumps(persisted, indent=4)
        with open(self.WORKING_DIR + fn, 'w') as fp:
            fp.write(text)

    def load(self, fn='/config.json'):
        """Apply the settings stored in ``WORKING_DIR + fn``, if readable.

        A missing, unreadable or malformed file leaves the current values in
        place and prints the reason; it does not raise.
        """
        try:
            with open(self.WORKING_DIR + fn, 'r') as fp:
                stored = json.load(fp)
        except FileNotFoundError:
            print("Config file does not exist, use default value instead!")
            return
        except (OSError, ValueError) as err:
            print(f"Config file could not be read ({err}), use default value instead!")
            return

        try:
            # A dict comes from the current format; a string from the legacy one.
            overrides = self._parse_overrides(stored)
        except (ValueError, SyntaxError, TypeError) as err:
            print(f"Config file is malformed ({err}), use default value instead!")
            return

        self._apply(overrides)
        print("Config file loaded successfully!")


# Module-level settings object imported by the rest of the project.
configs = Config()
# configs.load()
# configs.save()
from ..models import aresnet
# Module-level Net instance, constructed when this module is first imported.
model = aresnet.Net()
\ No newline at end of file
from .utils import *
\ No newline at end of file
import time
import torch
from ..settings import configs
from random import randrange
import os
from os import walk
class Timer:
    """Stopwatch reporting time since the previous reading and since creation.

    Attributes:
        ini: Timestamp taken when the timer was constructed.
        last: Timestamp of the previous reading (0 until the first reading).
        curr: Timestamp of the latest reading (0 until the first reading).
    """

    def __init__(self):
        self.ini = time.time()
        self.last = 0
        self.curr = 0

    @staticmethod
    def _format(seconds) -> str:
        """Format a duration as HH:MM:SS; the hour field keeps growing past 24."""
        whole = max(int(round(seconds, 2)), 0)
        hours, remainder = divmod(whole, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def timeit(self) -> tuple:
        """Take a reading and return ``(since_last_reading, since_creation)``.

        Both values are ``HH:MM:SS`` strings. The first call only starts the
        lap clock and reports zero for both values.
        """
        now = time.time()
        if self.last == 0 and self.curr == 0:
            self.last = now
            self.curr = now
            return self._format(0), self._format(0)

        self.last = self.curr
        self.curr = now
        return self._format(self.curr - self.last), self._format(self.curr - self.ini)
def eval_total(model, testloader, timer, epoch=-1):
    """Measure accuracy on ``testloader``, report it, and optionally save the weights.

    Does nothing unless ``configs.LOCAL_RANK`` is 0. The model is switched to
    eval mode for the pass and back to train mode afterwards. When the accuracy
    (in percent) reaches ``configs.MODEL_SAVE_THRESHOLD`` the state dict is
    written to ``configs.MODEL_DIR`` under a name derived from the accuracy,
    with the decimal point replaced by an underscore.

    Args:
        model: Network to evaluate; its ``.module`` is saved when
            ``configs.DDP_ON`` is set.
        testloader: Iterable of ``(images, labels)`` batches.
        timer: Object whose ``timeit()`` returns a (delta, total) pair.
        epoch: Epoch number for the report line; -1 omits the epoch prefix.
    """
    # Only necessary to evaluate the model on one gpu
    if configs.LOCAL_RANK != 0:
        return

    device = configs.DEVICE
    n_correct = 0
    n_seen = 0

    model.eval()
    # No gradients are needed while scoring
    with torch.no_grad():
        for images, labels in testloader:
            logits = model(images.to(device))
            # Index of the largest score per row is the predicted class
            predicted = torch.max(logits.cpu().data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    save_model = accuracy >= configs.MODEL_SAVE_THRESHOLD

    prefix = '' if epoch == -1 else 'Epoch ' + str(epoch) + ': '
    verdict = 'saved' if save_model else 'discarded'
    print(f"{prefix}Accuracy of the network on the {n_seen} test images: {accuracy} % ({verdict})")

    elapsed = timer.timeit()
    print(f"Delta time: {elapsed[0]}, Already: {elapsed[1]}\n")

    model.train()
    if not save_model:
        return

    # Under DDP the underlying network lives in .module
    network = model.module if configs.DDP_ON else model
    target = configs.MODEL_DIR + f"{accuracy}".replace('.', '_') + '.pth'
    torch.save(network.state_dict(), target)
def _ranked_checkpoints(model_dir):
    """Return ``(accuracy, filename)`` for accuracy-named .pth files in ``model_dir``, best first.

    Only files directly inside ``model_dir`` are considered. Files whose stem
    is not a number once underscores are read as decimal points (for example a
    hand-named checkpoint) are skipped instead of raising.
    """
    filenames = next(walk(model_dir), (None, None, []))[2]
    ranked = []
    for filename in filenames:
        stem, extension = os.path.splitext(filename)
        if extension != ".pth":
            continue
        try:
            accuracy = float(stem.replace('_', '.'))
        except ValueError:
            continue
        ranked.append((accuracy, filename))
    ranked.sort(reverse=True)
    return ranked


def _prune_checkpoints(model_dir, ranked, keep=3):
    """Delete every ranked checkpoint after the first ``keep``; failures are ignored."""
    for _, filename in ranked[keep:]:
        try:
            # Remove by the real filename rather than one rebuilt from the float.
            os.remove(os.path.join(model_dir, filename))
        except OSError:
            # Best effort: the file may already be gone (several processes can
            # run this clean-up at the same time).
            continue


def find_best3(local_rank, rand=False):
    """Pick a checkpoint among the three most accurate ones and prune the rest.

    Args:
        local_rank: Rank of the calling process; only rank 0 prints the choice.
        rand: When true and at least three checkpoints exist, pick one of the
            top three at random; otherwise pick the best.

    Returns:
        ``"/" + filename`` of the chosen checkpoint, or ``''`` when
        ``configs.MODEL_DIR`` holds no accuracy-named checkpoint.
    """
    ranked = _ranked_checkpoints(configs.MODEL_DIR)
    if not ranked:
        return ''

    _prune_checkpoints(configs.MODEL_DIR, ranked)

    top = ranked[:3]
    choice = randrange(3) if (rand and len(top) == 3) else 0
    model_name = top[choice][1]
    if local_rank == 0:
        print(f"Loading one of top 3 best model: {model_name}\n")
    return "/" + model_name


def remove_bad_models():
    """Keep the three most accurate checkpoints in ``configs.MODEL_DIR`` and delete the others."""
    _prune_checkpoints(configs.MODEL_DIR, _ranked_checkpoints(configs.MODEL_DIR))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment