# distributed_training_system.py

import os
import logging
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
from dataclasses import dataclass
from torch.utils.tensorboard import SummaryWriter

# Configure logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Training configuration parameters"""
    batch_size: int = 32
    learning_rate: float = 1e-3
    num_epochs: int = 10
    save_every: int = 5
    grad_accumulation_steps: int = 1
    mixed_precision: bool = True
    checkpoint_dir: str = "checkpoints"
    log_dir: str = "logs"
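
# Note: the effective global batch size is
# batch_size * grad_accumulation_steps * world_size. For example,
# batch_size=32 with grad_accumulation_steps=4 on 2 GPUs gives an
# effective batch of 256 samples per optimizer step.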

class DistributedTrainer:
    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig,
        rank: int,
        world_size: int
    ):
        self.model = model
        self.config = config
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')

        # Initialize distributed training
        self._setup_distributed()

        # Wrap model in DDP
        self.model = self._prepare_model()

        # Initialize optimizer, scheduler and scaler
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs
        )
        self.scaler = GradScaler(enabled=config.mixed_precision)

        # Set up TensorBoard and the checkpoint directory on the main process
        if self.rank == 0:
            self.writer = SummaryWriter(log_dir=config.log_dir)
            Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    def _setup_distributed(self):
        """Initialize the distributed environment"""
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # Initialize process group
        dist.init_process_group(
            "nccl",
            rank=self.rank,
            world_size=self.world_size
        )

        # Set device
        torch.cuda.set_device(self.rank)
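
    # Note: MASTER_ADDR/MASTER_PORT are hard-coded for a single machine. A
    # sketch of the torchrun-friendly alternative (torchrun exports RANK,
    # WORLD_SIZE and LOCAL_RANK for each process):
    #
    #   dist.init_process_group("nccl", init_method="env://")
    #   torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))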

    def _prepare_model(self) -> nn.Module:
        """Prepare model for distributed training"""
        self.model = self.model.to(self.device)
        return DDP(
            self.model,
            device_ids=[self.rank],
            output_device=self.rank,
            find_unused_parameters=False
        )
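
    # find_unused_parameters=False is the fast path: DDP assumes every
    # parameter receives a gradient each step. Flip it to True only if parts
    # of the model are skipped during forward, at some communication cost.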

    def _save_checkpoint(self, epoch: int, loss: float):
        """Save training checkpoint (main process only)"""
        if self.rank == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.module.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': loss,
            }
            path = Path(self.config.checkpoint_dir) / f'checkpoint_epoch_{epoch}.pt'
            torch.save(checkpoint, path)
            logger.info(f'Saved checkpoint: {path}')

    def _load_checkpoint(self, path: str) -> Dict[str, Any]:
        """Load training checkpoint"""
        # Remap tensors saved from rank 0's GPU onto this process's GPU
        map_location = {'cuda:0': f'cuda:{self.rank}'}
        checkpoint = torch.load(path, map_location=map_location)
        self.model.module.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return checkpoint

    def train_epoch(
        self,
        train_loader: DataLoader,
        epoch: int
    ) -> float:
        """Train for one epoch and return the loss averaged across processes"""
        self.model.train()
        total_loss = 0.0
        num_batches = len(train_loader)

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass with mixed precision; divide the loss so that
            # accumulated gradients match a single large-batch step
            with autocast(enabled=self.config.mixed_precision):
                output = self.model(data)
                loss = F.cross_entropy(output, target)
                loss = loss / self.config.grad_accumulation_steps

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()

            # Update weights once enough gradients have accumulated
            if (batch_idx + 1) % self.config.grad_accumulation_steps == 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

            # Undo the accumulation scaling when reporting per-batch loss
            batch_loss = loss.item() * self.config.grad_accumulation_steps
            total_loss += batch_loss

            # Log metrics
            if batch_idx % 50 == 0:
                logger.info(
                    f'Rank {self.rank}, Epoch {epoch}, '
                    f'Batch {batch_idx}/{num_batches}, '
                    f'Loss: {batch_loss:.4f}'
                )

            # Log to tensorboard
            if self.rank == 0 and batch_idx % 10 == 0:
                step = epoch * num_batches + batch_idx
                self.writer.add_scalar('training_loss', batch_loss, step)
                self.writer.add_scalar(
                    'learning_rate',
                    self.scheduler.get_last_lr()[0],
                    step
                )

        # Average loss across all processes; all_reduce modifies the tensor
        # in place, so read the result back from it
        avg_loss = torch.tensor(total_loss / num_batches, device=self.device)
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)

        return avg_loss.item()
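
    # Caveat: if the number of batches is not a multiple of
    # grad_accumulation_steps, the final partial window's gradients are never
    # stepped or zeroed here, so they carry over into the next epoch.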

    def validate(
        self,
        val_loader: DataLoader
    ) -> Tuple[float, float]:
        """Validate the model and return (average loss, accuracy)"""
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)

                with autocast(enabled=self.config.mixed_precision):
                    output = self.model(data)
                    loss = F.cross_entropy(output, target)

                val_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        # Sum metrics across all processes
        metrics = torch.tensor([val_loss, correct, total], device=self.device)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        val_loss, correct, total = metrics.tolist()

        # val_loss is summed over every batch on every rank, so divide by
        # the global batch count, not the per-rank one
        return val_loss / (len(val_loader) * self.world_size), correct / total

    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        resume_from: Optional[str] = None
    ):
        """Main training loop"""
        start_epoch = 0

        # Resume training if a checkpoint is provided
        if resume_from:
            checkpoint = self._load_checkpoint(resume_from)
            start_epoch = checkpoint['epoch'] + 1
            logger.info(f'Resuming training from epoch {start_epoch}')

        for epoch in range(start_epoch, self.config.num_epochs):
            # Reshuffle shards so each rank sees a different split per epoch
            train_loader.sampler.set_epoch(epoch)

            # Train one epoch
            train_loss = self.train_epoch(train_loader, epoch)

            # Validate if a validation loader is provided
            if val_loader:
                val_loss, accuracy = self.validate(val_loader)
                if self.rank == 0:
                    logger.info(
                        f'Epoch {epoch}: '
                        f'Train Loss: {train_loss:.4f}, '
                        f'Val Loss: {val_loss:.4f}, '
                        f'Accuracy: {accuracy:.4f}'
                    )
                    self.writer.add_scalar('validation_loss', val_loss, epoch)
                    self.writer.add_scalar('validation_accuracy', accuracy, epoch)

            # Save checkpoint
            if epoch % self.config.save_every == 0:
                self._save_checkpoint(epoch, train_loss)

            # Update learning rate
            self.scheduler.step()

        # Final cleanup
        if self.rank == 0:
            self.writer.close()
        dist.destroy_process_group()
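
# To resume, pass a saved checkpoint path into train(); the file name below
# is an example following the pattern used by _save_checkpoint:
#   trainer.train(train_loader, resume_from='checkpoints/checkpoint_epoch_4.pt')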

# Define your model (replace with your actual architecture). It lives at
# module level so spawned worker processes can import it.
class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.layers(x)


def train_worker(rank: int, world_size: int, config: TrainingConfig):
    """Per-process entry point; must be picklable, so it cannot be nested in main()"""
    # Create trainer
    trainer = DistributedTrainer(
        model=YourModel(),
        config=config,
        rank=rank,
        world_size=world_size
    )

    # Create dataset (replace with your actual dataset)
    train_dataset = TensorDataset(
        torch.randn(1000, 784),
        torch.randint(0, 10, (1000,))
    )

    # Create distributed sampler so each rank sees a disjoint shard
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    # Create data loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    # Train the model
    trainer.train(train_loader)


def main():
    # Training configuration
    config = TrainingConfig(
        batch_size=32,
        learning_rate=1e-3,
        num_epochs=10,
        save_every=2,
        grad_accumulation_steps=4,
        mixed_precision=True
    )

    # Launch one training process per available GPU
    world_size = torch.cuda.device_count()
    logger.info(f'Starting training with {world_size} GPUs')
    mp.spawn(
        train_worker,
        args=(world_size, config),
        nprocs=world_size,
        join=True
    )

if __name__ == "__main__":
    main()
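
# Run with: python distributed_training_system.py
# mp.spawn launches one worker per visible GPU; restrict the GPU set with
# CUDA_VISIBLE_DEVICES, e.g.:
#   CUDA_VISIBLE_DEVICES=0,1 python distributed_training_system.py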