train.py

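"""Training entry point: Hydra-configured PyTorch Lightning training.

Usage sketch (no config_name is set in the @hydra.main decorator, so one must
be chosen on the command line; the name below is illustrative):

    python train.py --config-name=my_experiment

The LOGDIR environment variable, if set, is used as the root for the data and
log directories (defaults to the current directory).
"""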
import os

import hydra
import omegaconf
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from src.callbacks.callbacks import get_callbacks
from src.data.utils import get_data_loaders
from src.utils import get_wandb_logger


@hydra.main(version_base=None, config_path="configs")
def main(config: omegaconf.DictConfig):
    # Logs and data directories (LOGDIR falls back to the current directory)
    log_dir = os.environ.get('LOGDIR', './')
    config.data.root = os.path.join(log_dir, 'data', config.data.root)
    config.train.log_dir = os.path.join(log_dir, 'logs', 'EBM_Hackathon')
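
    # Keys the rest of this script reads from the config (for reference):
    #   data.root, data.test_dataset (optional)
    #   model (an instantiable _target_ spec)
    #   train.epochs, train.accelerator, train.devices, train.strategy,
    #   train.precision, train.gradient_clip_val, train.resume, train.ckpt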

    # Model: instantiated from the config; the full config is passed as `cfg`
    model = instantiate(config.model, cfg=config)

    # Data loaders: [train_loader, val_loader] and, optionally, a test_loader
    loaders = get_data_loaders(config)

    # Weights & Biases logger, named after the active Hydra config
    config_name = HydraConfig.get().job.config_name
    wandb_logger, log_dir = get_wandb_logger(config, config_name)

    # Callbacks
    callbacks = get_callbacks(config, loaders)

    # Keep the best checkpoint by validation loss, plus the most recent one
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, "checkpoints"),
        filename="{epoch}-{val_loss:.2f}",
        save_top_k=1,
        save_last=True,
        monitor="val_loss",
        mode="min",
    )
    callbacks.append(checkpoint_callback)

    # With 16-bit training on CUDA, also relax float32 matmul precision so
    # Tensor Cores can be used for the remaining float32 operations
    if torch.cuda.is_available() and config.train.precision == 16:
        torch.set_float32_matmul_precision('medium')

    # Resume from train.ckpt (last.ckpt by default) when train.resume is set
    ckpt_path = None
    if config.train.get("resume", False):
        ckpt_path = os.path.join(log_dir, "checkpoints", config.train.get("ckpt", "last.ckpt"))
        print(f"Resuming from checkpoint: {ckpt_path}")
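
    # Illustrative resume invocation (the config name is hypothetical):
    #   python train.py --config-name=my_experiment train.resume=true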

    # Trainer
    trainer = Trainer(
        max_epochs=config.train.epochs,
        logger=wandb_logger,
        accelerator=config.train.accelerator,
        devices=config.train.devices if torch.cuda.is_available() else 1,
        strategy=instantiate(config.train.strategy) if '_target_' in config.train.strategy else config.train.strategy,
        precision="16-mixed" if torch.cuda.is_available() and config.train.precision == 16 else 32,
        default_root_dir=log_dir,
        callbacks=callbacks,
        gradient_clip_val=config.train.gradient_clip_val,
    )

    # Fit on the train and validation loaders, resuming if a checkpoint was set
    trainer.fit(model, *loaders[:2], ckpt_path=ckpt_path)

    # Optional test pass when a test dataset is configured
    if config.data.get("test_dataset", False):
        trainer.test(model, dataloaders=loaders[-1])


if __name__ == "__main__":
    main()
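
# For reference, an illustrative minimal config (configs/my_experiment.yaml).
# Every value below is an assumption for demonstration, not a project default:
#
#   model:
#     _target_: src.models.model.Model   # hypothetical module path
#   data:
#     root: my_dataset
#   train:
#     epochs: 100
#     accelerator: gpu
#     devices: 1
#     strategy: auto
#     precision: 16
#     gradient_clip_val: 1.0
#     resume: false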