"""distributed_training_system.py

Single-node, multi-GPU training with PyTorch DistributedDataParallel (DDP),
gradient accumulation, and optional mixed precision.
"""
import os
import time
import logging
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from torch.utils.tensorboard import SummaryWriter
# Configure logging
logging.basicConfig(
level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class TrainingConfig:
    """Training configuration parameters."""
batch_size: int = 32
learning_rate: float = 1e-3
num_epochs: int = 10
save_every: int = 5
grad_accumulation_steps: int = 1
mixed_precision: bool = True
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "logs"
class DistributedTrainer:
    def __init__(
self,
model: nn.Module,
config: TrainingConfig,
rank: int,
world_size: int
):
self.model = model
self.config = config
self.rank = rank
self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')
# Initialize distributed training
self._setup_distributed()
# Wrap model in DDP
self.model = self._prepare_model()
# Initialize optimizer, scheduler and scaler
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=config.learning_rate
)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=config.num_epochs
)
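        # GradScaler(enabled=False) makes scale()/step()/update() pass-throughs,
        # so the same training loop handles both mixed- and full-precision runs.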
self.scaler = GradScaler(enabled=config.mixed_precision)
        # Set up TensorBoard and the checkpoint directory on the main process only
        if self.rank == 0:
            self.writer = SummaryWriter(log_dir=config.log_dir)
            Path(config.checkpoint_dir).mkdir(exist_ok=True)
def _setup_distributed(self):
"""Initialize the distributed environment"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# Initialize process group
dist.init_process_group(
"nccl",
rank=self.rank,
world_size=self.world_size
)
# Set device
torch.cuda.set_device(self.rank)
def _prepare_model(self) -> nn.Module:
"""Prepare model for distributed training"""
self.model = self.model.to(self.device)
return DDP(
self.model,
device_ids=[self.rank],
output_device=self.rank,
find_unused_parameters=False
)
def _save_checkpoint(self, epoch: int, loss: float):
"""Save training checkpoint"""
if self.rank == 0:
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.module.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'loss': loss,
}
path = Path(self.config.checkpoint_dir) / f'checkpoint_epoch_{epoch}.pt'
torch.save(checkpoint, path)
logger.info(f'Saved checkpoint: {path}')
def _load_checkpoint(self, path: str) -> Dict[str, Any]:
"""Load training checkpoint"""
map_location = {'cuda:%d' % 0: 'cuda:%d' % self.rank}
checkpoint = torch.load(path, map_location=map_location)
self.model.module.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
return checkpoint
def train_epoch(
self,
train_loader: DataLoader,
epoch: int
) -> float:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
num_batches = len(train_loader)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(self.device), target.to(self.device)
            # Forward pass with mixed precision
            with autocast(enabled=self.config.mixed_precision):
                output = self.model(data)
                loss = F.cross_entropy(output, target)
                # Scale the loss so gradients accumulated over several
                # micro-batches average to the full-batch gradient
                scaled_loss = loss / self.config.grad_accumulation_steps
            # Backward pass with gradient scaling
            self.scaler.scale(scaled_loss).backward()
# Update weights if we've accumulated enough gradients
if (batch_idx + 1) % self.config.grad_accumulation_steps == 0:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
# Log metrics
if batch_idx % 50 == 0:
logger.info(
f'Rank {self.rank}, Epoch {epoch}, '
f'Batch {batch_idx}/{num_batches}, '
f'Loss: {loss.item():.4f}'
)
total_loss += loss.item()
# Log to tensorboard
if self.rank == 0 and batch_idx % 10 == 0:
step = epoch * num_batches + batch_idx
self.writer.add_scalar(
'training_loss',
loss.item(),
step
)
self.writer.add_scalar(
'learning_rate',
self.scheduler.get_last_lr()[0],
step
)
        # Average the epoch loss across all processes
        avg_loss = torch.tensor(total_loss / num_batches, device=self.device)
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        return avg_loss.item()
def validate(
self,
val_loader: DataLoader
) -> Tuple[float, float]:
"""Validate the model"""
self.model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(self.device), target.to(self.device)
with autocast(enabled=self.config.mixed_precision):
output = self.model(data)
loss = F.cross_entropy(output, target)
val_loss += loss.item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
        # Sum metrics across all processes, then normalize globally
        metrics = torch.tensor(
            [val_loss, correct, total, len(val_loader)],
            dtype=torch.float64,
            device=self.device
        )
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        val_loss, correct, total, num_batches = metrics.tolist()
        return val_loss / num_batches, correct / total
def train(
self,
train_loader: DataLoader,
val_loader: Optional[DataLoader] = None,
resume_from: Optional[str] = None
):
"""Main training loop"""
start_epoch = 0
# Resume training if checkpoint provided
if resume_from:
checkpoint = self._load_checkpoint(resume_from)
start_epoch = checkpoint['epoch'] + 1
logger.info(f'Resuming training from epoch {start_epoch}')
for epoch in range(start_epoch, self.config.num_epochs):
# Set epoch for distributed sampler
train_loader.sampler.set_epoch(epoch)
# Train one epoch
train_loss = self.train_epoch(train_loader, epoch)
# Validate if validation loader provided
if val_loader:
val_loss, accuracy = self.validate(val_loader)
if self.rank == 0:
logger.info(
f'Epoch {epoch}: '
f'Train Loss: {train_loss:.4f}, '
f'Val Loss: {val_loss:.4f}, '
f'Accuracy: {accuracy:.4f}'
)
self.writer.add_scalar('validation_loss', val_loss, epoch)
self.writer.add_scalar('validation_accuracy', accuracy, epoch)
# Save checkpoint
if epoch % self.config.save_every == 0:
self._save_checkpoint(epoch, train_loss)
# Update learning rate
self.scheduler.step()
# Final cleanup
if self.rank == 0:
self.writer.close()
dist.destroy_process_group()


# Example model (replace with your actual architecture)
class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.layers(x)


def train_worker(rank: int, world_size: int, config: TrainingConfig):
    """Per-process entry point (module-level so mp.spawn can pickle it)."""
    # Create trainer
    trainer = DistributedTrainer(
        model=YourModel(),
        config=config,
        rank=rank,
        world_size=world_size
    )
    # Create a toy dataset (replace with your actual dataset)
    train_dataset = TensorDataset(
        torch.randn(1000, 784),
        torch.randint(0, 10, (1000,))
    )
    # Create distributed sampler
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    # Create data loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )
    # Train the model
    trainer.train(train_loader)


def main():
    # Training configuration
    config = TrainingConfig(
        batch_size=32,
        learning_rate=1e-3,
        num_epochs=10,
        save_every=2,
        grad_accumulation_steps=4,
        mixed_precision=True
    )
    # Launch one training process per available GPU
    world_size = torch.cuda.device_count()
    logger.info(f'Starting training with {world_size} GPUs')
    mp.spawn(
        train_worker,
        args=(world_size, config),
        nprocs=world_size,
        join=True
    )


if __name__ == "__main__":
    main()
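
# Usage note (a single-node sketch, not part of an external API): running
# `python distributed_training_system.py` spawns one worker per visible CUDA GPU
# via mp.spawn. MASTER_ADDR/MASTER_PORT are fixed to localhost:12355 in
# _setup_distributed(), so no external launcher (e.g. torchrun) is needed here;
# a multi-node variant would read rank and world size from the environment instead.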