Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __init__(
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.daily_np = daily_mt.to_numpy().copy().astype(np.float32) # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy().astype(np.float32) # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4)
self.daily_timef_np = daily_timef.to_numpy().copy().astype(np.float32) # (M,T=31, 4)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
Expand Down Expand Up @@ -148,19 +148,19 @@ def __getitem__(self, idx):

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_tensor = torch.from_numpy(land_patch.copy()).bool()
land_tensor = torch.from_numpy(np.ascontiguousarray(land_patch)).bool()
else:
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)

# Convert to tensors (from_numpy is zero-copy on contiguous arrays)
# (1, M, T, H, W)
daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0)
daily_tensor = torch.from_numpy(daily_patch).unsqueeze(0)
# (M, H, W)
monthly_tensor = torch.from_numpy(monthly_patch).float()
monthly_tensor = torch.from_numpy(monthly_patch)
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()
daily_timef_tensor = torch.from_numpy(self.daily_timef_np)

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
Expand Down
11 changes: 10 additions & 1 deletion climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def predict_monthly_var(
device: str = "cpu",
run_dir: str = ".",
verbose: bool = True,
dataloader_num_workers: int = 2,
predict_threads: int | None = None,
):
"""
Predicts monthly variable values using a trained model and a provided dataset.
Expand All @@ -79,6 +81,8 @@ def predict_monthly_var(
device: The device to run the predictions on (e.g., 'cpu' or 'cuda').
run_dir: Directory to save log files and predictions.
verbose: If True, prints progress information during prediction.
dataloader_num_workers: how many subprocesses to use for data loading.
See torch DataLoader docs for details.
Returns:
A NumPy array, PyTorch tensor, or xarray Dataset containing the predicted values.
If return_loss is True, it also returns the average loss over the dataset.
Expand All @@ -92,7 +96,12 @@ def predict_monthly_var(

use_cuda = device == "cuda"
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda
dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=use_cuda,
num_workers=dataloader_num_workers,# for data loading
persistent_workers=True, # keep workers alive between epochs
)

# Initialize an empty list to store predictions
Expand Down
33 changes: 16 additions & 17 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,32 @@ def forward(self, x, mask):
return x # (B, N_patches, embed_dim)

class CyclicTimeEmbedding(nn.Module):
"""Cyclical Temporal encoding using day-of-year and hour-of-day values in
"""Cyclical Temporal encoding using day-of-year and hour-of-day values in
combination with sine and cosine functions

This module generates fixed (non-learnable) trigonometric temporal encodings
for the temporal dimension using the cyclical phase encoded day-of-year and
This module generates fixed (non-learnable) trigonometric temporal encodings
for the temporal dimension using the cyclical phase encoded day-of-year and
hour-of-day values extracted from the datetime associated with the input.
This represents a natural positional encoding on the temporal cycle related
This represents a natural positional encoding on the temporal cycle related
to the solar (tropical) year and the diurnal cycle.

The module uses fixed Fourier frequencies and mixed doy-hod terms to expand
the cyclic encoding to the embedding dimension and capture time of day and
day of year interactions. The returned encodings are intended to be added to
embeddings of the input data by the caller. The module does not perform the
The module uses fixed Fourier frequencies and mixed doy-hod terms to expand
the cyclic encoding to the embedding dimension and capture time of day and
day of year interactions. The returned encodings are intended to be added to
embeddings of the input data by the caller. The module does not perform the
addition.
"""

def __init__(self, embed_dim=128, include_cross=True):
"""
Initialize temporal encodings

Args:
embed_dim: Dimension of the embedding.The default is 128.
Many vision transformers use embedding dimensions that are multiples
of 64 (e.g., 64, 128, 256). This can be tuned.
include_cross: bool, default True. Also Create phase_doy +/- phase_hod
cross term embeddings
include_cross: bool, default True. Also Create phase_doy +/- phase_hod
cross term embeddings
"""

super().__init__()
Expand All @@ -127,11 +127,11 @@ def __init__(self, embed_dim=128, include_cross=True):
f"embed_dim must be an even multiple of num_phase_terms for fixed encoding."
f"Got embed_dim: {embed_dim} and num_phase_terms: {num_phase_terms}."
)

def forward(self, time_features):
"""
Create encodings of size embedding dimension

Args:
time_features: (B, M, T, D) ; D is base_dim

Expand Down Expand Up @@ -170,7 +170,7 @@ def forward(self, time_features):
emb_encode = emb_encode.view(B,M,T,-1) # flatten

return emb_encode



class TemporalPositionalEncoding(nn.Module):
Expand Down Expand Up @@ -290,7 +290,7 @@ def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
W: spatial width after spatial patching
time_features: (B,M,T,2) containing cyclically phase encoded DOY and HOD
time_features: (B,M,T,2) containing cyclically phase encoded DOY and HOD
padded_days_mask: Optional boolean tensor of shape (B, M, T), bool,
True indicating which day tokens are padded (because some months
have fewer days). This is used to mask out padded tokens in attention computation.
Expand All @@ -305,7 +305,6 @@ def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim)
#expand spatially
temp_emb = temp_emb[:, None, :, :, :] #[B, 1, M, T, C]
temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1)
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)

seq = seq + temp_emb # add temporal embeddings
Expand Down Expand Up @@ -689,7 +688,7 @@ def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_d
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
daily_timef: Tensor of shape (B, M, T, 2) containing the cyclically phase encoded day-of-year
daily_timef: Tensor of shape (B, M, T, 2) containing the cyclically phase encoded day-of-year
and hour-of-day information for the daily data
land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
Expand Down
67 changes: 43 additions & 24 deletions climanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,39 @@
from climanet.utils import setup_logging, compute_masked_loss, save_model


def _run_one_batch(model: torch.nn.Module, batch: dict):
    """Run the model on one batch and return its masked loss.

    Args:
        model: Callable model; it is invoked with five positional inputs
            taken from ``batch`` (see ``input_keys`` below for the order).
        batch: Mapping that provides the keys ``"daily_patch"``,
            ``"daily_mask_patch"``, ``"daily_timef_patch"``,
            ``"land_mask_patch"``, ``"padded_days_mask"`` and
            ``"monthly_patch"``.

    Returns:
        Whatever ``compute_masked_loss`` returns for the prediction, the
        ``"monthly_patch"`` target and the ``"land_mask_patch"`` mask.
    """
    # Positional order expected by the model's forward call.
    input_keys = (
        "daily_patch",
        "daily_mask_patch",
        "daily_timef_patch",
        "land_mask_patch",
        "padded_days_mask",
    )
    model_inputs = [batch[key] for key in input_keys]

    # Forward pass; prediction presumably has shape (B, M, H, W).
    prediction = model(*model_inputs)

    # Loss against the monthly target, with the land mask passed through.
    target = batch["monthly_patch"]
    land_mask = batch["land_mask_patch"]
    return compute_masked_loss(prediction, target, land_mask)


def _compute_stats(dataset: Dataset):
    """Return the ``(mean, std)`` pair reported by the dataset's ``compute_stats``.

    When ``dataset`` carries a ``dataset`` attribute (as a subset-style
    wrapper would), the statistics are requested from that wrapped object;
    otherwise from ``dataset`` itself. An ``indices`` attribute, if present,
    is forwarded to ``compute_stats``; otherwise ``None`` is passed.

    Args:
        dataset: Object that either implements ``compute_stats(indices)``
            or wraps (via ``.dataset``) an object that does.

    Returns:
        Tuple ``(mean, std)`` unpacked from the ``compute_stats`` result.
    """
    # Fall back to the object itself / to None when the wrapper
    # attributes are absent.
    stats_source = getattr(dataset, "dataset", dataset)
    subset_indices = getattr(dataset, "indices", None)

    stats_mean, stats_std = stats_source.compute_stats(subset_indices)
    return stats_mean, stats_std


def _initialize_decoder(model: torch.nn.Module, dataset: Dataset):
    """Seed the decoder's ``bias`` and ``scale`` from dataset statistics.

    The mean and standard deviation obtained from ``_compute_stats(dataset)``
    are copied in place into ``decoder.bias`` and ``decoder.scale``; the
    scale additionally receives a ``1e-6`` offset.

    Args:
        model: Model exposing ``decoder`` directly, or through a ``module``
            attribute (presumably a DataParallel-style wrapper; confirm
            against callers).
        dataset: Dataset passed to ``_compute_stats``; the returned mean and
            std must be NumPy arrays, since they go through
            ``torch.from_numpy``.

    Returns:
        The same ``model`` object, modified in place.
    """
    mean, std = _compute_stats(dataset)

    # Unwrap a ``module``-style wrapper when present.
    inner_model = getattr(model, "module", model)
    target_decoder = inner_model.decoder

    # In-place parameter writes must not be tracked by autograd.
    with torch.no_grad():
        bias_init = torch.from_numpy(mean)
        target_decoder.bias.copy_(bias_init)
        # Small epsilon on std, presumably to keep the scale away from zero.
        scale_init = torch.from_numpy(std) + 1e-6
        target_decoder.scale.copy_(scale_init)

    return model


def train_monthly_model(
model: torch.nn.Module,
dataset: Dataset,
Expand All @@ -22,6 +55,7 @@ def train_monthly_model(
store_model: bool = True,
device: str = "cpu",
verbose: bool = True,
dataloader_num_workers: int = 2,
):
"""Train the model to predict monthly data from daily data.
Args:
Expand All @@ -37,26 +71,22 @@ def train_monthly_model(
store_model: whether to save the best model to disk
device: device to run training on ("cpu" or "cuda")
verbose: whether to print training progress
dataloader_num_workers: how many subprocesses to use for data loading.
See torch DataLoader docs for details.
"""

# check if dataset has indices attribute for stats calculation
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset
indices = dataset.indices if hasattr(dataset, "indices") else None
mean, std = base_dataset.compute_stats(indices)

# Initialize the model
model = model.to(device)
decoder = model.decoder
with torch.no_grad():
decoder.bias.copy_(torch.from_numpy(mean))
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)
model = _initialize_decoder(model, dataset)

# Create data loader
use_cuda = device == "cuda"
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=False,
pin_memory=use_cuda,
num_workers=dataloader_num_workers, # for data loading
persistent_workers=True, # keep workers alive between epochs
)

# Set up logging
Expand Down Expand Up @@ -86,19 +116,7 @@ def train_monthly_model(
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
# Batch prediction
pred = model(
batch["daily_patch"],
batch["daily_mask_patch"],
batch["daily_timef_patch"],
batch["land_mask_patch"],
batch["padded_days_mask"],
) # (B, M, H, W)

# Compute masked loss
loss = compute_masked_loss(
pred, batch["monthly_patch"], batch["land_mask_patch"]
)
loss = _run_one_batch(model, batch)

# Scale loss for gradient accumulation
scaled_loss = loss / accumulation_steps
Expand Down Expand Up @@ -136,6 +154,7 @@ def train_monthly_model(
return_loss=True,
verbose=False,
run_dir=run_dir,
dataloader_num_workers=dataloader_num_workers,
)
writer.add_scalar("Loss/validation", avg_epoch_loss, epoch)

Expand Down
Loading