diff --git a/beginner_source/basics/quickstart_tutorial.py b/beginner_source/basics/quickstart_tutorial.py index 0cf469f31f..b3f3123887 100644 --- a/beginner_source/basics/quickstart_tutorial.py +++ b/beginner_source/basics/quickstart_tutorial.py @@ -140,15 +140,16 @@ def train(dataloader, model, loss_fn, optimizer): model.train() for batch, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) - + + # Zero the gradients for batch + optimizer.zero_grad() # Compute prediction error pred = model(X) loss = loss_fn(pred, y) - # Backpropagation loss.backward() + # Optimizer step optimizer.step() - optimizer.zero_grad() if batch % 100 == 0: loss, current = loss.item(), (batch + 1) * len(X)