25 normalize data on gpu #39

Open · wants to merge 6 commits into base: main
Changes from 2 commits
README.md — 3 changes: 2 additions & 1 deletion

```diff
@@ -218,7 +218,8 @@ data
 │ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
 │ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
 │ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
+│ ├── flux_mean.pt - Mean of solar flux forcing (create_parameter_weights.py)
+│ ├── flux_std.pt - Std.-dev. of solar flux forcing (create_parameter_weights.py)
 │ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
 ├── dataset2
 ├── ...
```
create_parameter_weights.py — 24 changes: 9 additions & 15 deletions

```diff
@@ -74,7 +74,6 @@ def main():
         split="train",
         subsample_step=1,
         pred_length=63,
-        standardize=False,
     )  # Without standardization
     loader = torch.utils.data.DataLoader(
         ds, args.batch_size, shuffle=False, num_workers=args.n_workers
@@ -107,30 +106,25 @@ def main():
     flux_mean = torch.mean(torch.stack(flux_means))  # (,)
     flux_second_moment = torch.mean(torch.stack(flux_squares))  # (,)
     flux_std = torch.sqrt(flux_second_moment - flux_mean**2)  # (,)
-    flux_stats = torch.stack((flux_mean, flux_std))
 
-    print("Saving mean, std.-dev, flux_stats...")
+    print("Saving mean, std.-dev, flux_mean, flux_std...")
     torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
     torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
-    torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
+    torch.save(flux_mean, os.path.join(static_dir_path, "flux_mean.pt"))
+    torch.save(flux_std, os.path.join(static_dir_path, "flux_std.pt"))
 
     # Compute mean and std.-dev. of one-step differences across the dataset
     print("Computing mean and std.-dev. for one-step differences...")
-    ds_standard = WeatherDataset(
-        config_loader.dataset.name,
-        split="train",
-        subsample_step=1,
-        pred_length=63,
-        standardize=True,
-    )  # Re-load with standardization
-    loader_standard = torch.utils.data.DataLoader(
-        ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
-    )
     used_subsample_len = (65 // args.step_length) * args.step_length
 
     diff_means = []
     diff_squares = []
-    for init_batch, target_batch, _ in tqdm(loader_standard):
+    for init_batch, target_batch, _ in tqdm(loader):
+        # normalize the batch
+        init_batch = (init_batch - mean) / std
+        target_batch = (target_batch - mean) / std
 
-        batch = torch.cat((init_batch, target_batch), dim=1)
+        batch = torch.cat(
+            (init_batch, target_batch), dim=1
+        )  # (N_batch, N_t', N_grid, d_features)
```
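The statistics above are accumulated in a single pass from per-batch means and second moments, with the standard deviation recovered via Var(x) = E[x²] − E[x]². A minimal sketch of that identity with toy data; all names here are illustrative, not code from this PR:

```python
# Sketch: accumulate per-batch means of x and x**2, then recover the
# std.-dev. from Var(x) = E[x^2] - E[x]^2. With equally sized batches,
# the mean of batch means equals the global mean.
import torch

batches = [torch.randn(16, 100) for _ in range(10)]  # toy stand-in for flux batches

means = [b.mean() for b in batches]
squares = [(b**2).mean() for b in batches]

mean = torch.mean(torch.stack(means))
second_moment = torch.mean(torch.stack(squares))
std = torch.sqrt(second_moment - mean**2)

# Compare against the direct (population) statistics on the full data
full = torch.cat(batches)
ref_std = torch.sqrt(((full - full.mean()) ** 2).mean())
assert torch.allclose(mean, full.mean(), atol=1e-6)
assert torch.allclose(std, ref_std, atol=1e-5)
```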
neural_lam/models/ar_model.py — 13 changes: 13 additions & 0 deletions

```diff
@@ -197,6 +197,19 @@ def common_step(self, batch):
 
         return prediction, target_states, pred_std
 
+    def on_after_batch_transfer(self, batch, dataloader_idx):
+        """Normalize Batch data after transferring to the device."""
+        init_states, target_states, forcing_features = batch
+        init_states = (init_states - self.data_mean) / self.data_std
+        target_states = (target_states - self.data_mean) / self.data_std
+        forcing_features = (forcing_features - self.flux_mean) / self.flux_std
+
+        batch = (
+            init_states,
+            target_states,
+            forcing_features,
+        )
+        return batch
 
     def training_step(self, batch):
         """
         Train on single batch
```
Review comment (Collaborator), on the forcing_features normalization line:

Now all forcing seems to be normalized with the flux statistics, but there is more forcing than flux. Note how this was only applied to the flux in WeatherDataset before.

Have you tested that this gives exactly the same tensors as before? (E.g., save the first batch to disk on main, check this branch out, save the first batch again, and compare.)

Reply (sadamov, Collaborator Author, Jun 9, 2024):

True, the forcings are now handled differently. I suggest implementing new logic for handling forcings in #54, where the user can define combined_vars that share statistics and also define vars that should not be normalized.
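The regression check the reviewer suggests can be scripted. A minimal sketch, assuming the first batch from each branch has been dumped to the hypothetical files first_batch_main.pt and first_batch_pr.pt:

```python
# Sketch of the suggested check; file names are hypothetical.
# On each branch, after fetching the first (standardized) batch:
#     torch.save(batch, "first_batch_<branch>.pt")
import torch

batch_main = torch.load("first_batch_main.pt")
batch_pr = torch.load("first_batch_pr.pt")

names = ("init_states", "target_states", "forcing_features")
for name, a, b in zip(names, batch_main, batch_pr):
    # Shapes must agree before an element-wise comparison is meaningful
    match = a.shape == b.shape and torch.allclose(a, b, rtol=1e-5, atol=1e-6)
    print(f"{name}: match={match}")
```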

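For this hook to normalize on the GPU, data_mean, data_std, flux_mean, and flux_std must already live on the batch's device. One way to arrange that — a sketch only, since the PR's actual wiring is outside this diff excerpt — is to register the statistics as non-persistent module buffers so Lightning moves them with the model:

```python
# Sketch: registering normalization statistics as buffers so they follow
# the LightningModule to the accelerator. Mirrors the attributes used in
# on_after_batch_transfer above; the exact setup in the PR is not shown here.
import torch
import pytorch_lightning as pl


class ARModelSketch(pl.LightningModule):
    def __init__(self, static_data):
        super().__init__()
        # Buffers move together with the module (e.g. on .to("cuda"));
        # persistent=False keeps them out of the checkpoint state dict
        self.register_buffer("data_mean", static_data["data_mean"], persistent=False)
        self.register_buffer("data_std", static_data["data_std"], persistent=False)
        self.register_buffer("flux_mean", static_data["flux_mean"], persistent=False)
        self.register_buffer("flux_std", static_data["flux_std"], persistent=False)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        init_states, target_states, forcing_features = batch
        init_states = (init_states - self.data_mean) / self.data_std
        target_states = (target_states - self.data_mean) / self.data_std
        forcing_features = (forcing_features - self.flux_mean) / self.flux_std
        return init_states, target_states, forcing_features
```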
neural_lam/utils.py — 30 changes: 5 additions & 25 deletions

```diff
@@ -8,31 +8,6 @@
 from tueplots import bundles, figsizes
 
 
-def load_dataset_stats(dataset_name, device="cpu"):
-    """
-    Load arrays with stored dataset statistics from pre-processing
-    """
-    static_dir_path = os.path.join("data", dataset_name, "static")
-
-    def loads_file(fn):
-        return torch.load(
-            os.path.join(static_dir_path, fn), map_location=device
-        )
-
-    data_mean = loads_file("parameter_mean.pt")  # (d_features,)
-    data_std = loads_file("parameter_std.pt")  # (d_features,)
-
-    flux_stats = loads_file("flux_stats.pt")  # (2,)
-    flux_mean, flux_std = flux_stats
-
-    return {
-        "data_mean": data_mean,
-        "data_std": data_std,
-        "flux_mean": flux_mean,
-        "flux_std": flux_std,
-    }
-
-
 def load_static_data(dataset_name, device="cpu"):
     """
     Load static files related to dataset
@@ -64,6 +39,9 @@ def loads_file(fn):
     data_mean = loads_file("parameter_mean.pt")  # (d_features,)
     data_std = loads_file("parameter_std.pt")  # (d_features,)
 
+    flux_mean = loads_file("flux_mean.pt")  # (,)
+    flux_std = loads_file("flux_std.pt")  # (,)
+
     # Load loss weighting vectors
     param_weights = torch.tensor(
         np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
@@ -78,6 +56,8 @@
         "step_diff_std": step_diff_std,
         "data_mean": data_mean,
         "data_std": data_std,
+        "flux_mean": flux_mean,
+        "flux_std": flux_std,
         "param_weights": param_weights,
     }
```
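With load_dataset_stats removed, the flux statistics are served by load_static_data alongside the other static tensors. A usage sketch; the dataset name is illustrative:

```python
# Sketch of consuming the consolidated statistics dict returned by
# load_static_data; "meps_example" is an illustrative dataset name.
from neural_lam import utils

static = utils.load_static_data("meps_example", device="cpu")
data_mean, data_std = static["data_mean"], static["data_std"]  # (d_features,)
flux_mean, flux_std = static["flux_mean"], static["flux_std"]  # scalars
```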
neural_lam/vis.py — 4 changes: 2 additions & 2 deletions

```diff
@@ -87,7 +87,7 @@ def plot_prediction(
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     # Plot pred and target
@@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
 
     fig, ax = plt.subplots(
         figsize=(5, 4.8),
-        subplot_kw={"projection": data_config.coords_projection()},
+        subplot_kw={"projection": data_config.coords_projection},
     )
 
     ax.coastlines()  # Add coastline outlines
```
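Dropping the call parentheses implies coords_projection is now exposed as a property on the config loader rather than a method. A minimal sketch of that pattern under that assumption — the real config class is not part of this diff, and ProjectionConfigSketch and the PlateCarree choice are purely illustrative:

```python
# Sketch of the attribute-style access the change implies: exposing the
# projection via @property. ProjectionConfigSketch is hypothetical.
import cartopy.crs as ccrs


class ProjectionConfigSketch:
    @property
    def coords_projection(self):
        # Computed on access, but read like a plain attribute
        return ccrs.PlateCarree()


data_config = ProjectionConfigSketch()
projection = data_config.coords_projection  # no call parentheses
```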
neural_lam/weather_dataset.py — 22 changes: 0 additions & 22 deletions

```diff
@@ -7,9 +7,6 @@
 import numpy as np
 import torch
 
-# First-party
-from neural_lam import utils
-
 
 class WeatherDataset(torch.utils.data.Dataset):
     """
@@ -29,7 +26,6 @@ def __init__(
         pred_length=19,
         split="train",
         subsample_step=3,
-        standardize=True,
         subset=False,
         control_only=False,
     ):
@@ -61,17 +57,6 @@ def __init__(
             self.sample_length <= self.original_sample_length
         ), "Requesting too long time series samples"
 
-        # Set up for standardization
-        self.standardize = standardize
-        if standardize:
-            ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
-            self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
-                ds_stats["data_mean"],
-                ds_stats["data_std"],
-                ds_stats["flux_mean"],
-                ds_stats["flux_std"],
-            )
-
         # If subsample index should be sampled (only duing training)
         self.random_subsample = split == "train"
@@ -148,10 +133,6 @@ def __getitem__(self, idx):
         sample = sample[init_id : (init_id + self.sample_length)]
         # (sample_length, N_grid, d_features)
 
-        if self.standardize:
-            # Standardize sample
-            sample = (sample - self.data_mean) / self.data_std
-
         # Split up sample in init. states and target states
         init_states = sample[:2]  # (2, N_grid, d_features)
         target_states = sample[2:]  # (sample_length-2, N_grid, d_features)
@@ -185,9 +166,6 @@ def __getitem__(self, idx):
             -1
         )  # (N_t', dim_x, dim_y, 1)
 
-        if self.standardize:
-            flux = (flux - self.flux_mean) / self.flux_std
-
         # Flatten and subsample flux forcing
         flux = flux.flatten(1, 2)  # (N_t, N_grid, 1)
         flux = flux[subsample_index :: self.subsample_step]  # (N_t, N_grid, 1)
```
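After this change the dataset always yields raw, unstandardized tensors; standardization is applied later in on_after_batch_transfer. A construction sketch using the simplified signature from the diff, with an illustrative dataset name:

```python
# Sketch: constructing the dataset after the standardize flag is removed.
# Samples come back unstandardized; the model normalizes them on-device.
# "meps_example" is an illustrative dataset name.
from neural_lam.weather_dataset import WeatherDataset

ds = WeatherDataset(
    "meps_example",
    pred_length=19,
    split="train",
    subsample_step=3,
    subset=False,
    control_only=False,
)
init_states, target_states, forcing = ds[0]  # raw, unstandardized tensors
```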