Skip to content

Commit

Permalink
Update config file for training on mnist 14*14.
Browse files Browse the repository at this point in the history
  • Loading branch information
YalcinerMustafa committed Oct 17, 2024
1 parent 3f857a9 commit 3a8fdb2
Showing 1 changed file with 6 additions and 57 deletions.
63 changes: 6 additions & 57 deletions experiments/mnist/mnist_0_scaled_14_linf_lognormal_gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ experiments:
trial_config:
logging:
images: true
"image_shape": [10, 10]
"image_shape": [14, 14]
dataset: &dataset
__object__: src.explib.datasets.MnistSplit
scale: true
digit: 0
device: *device
scale_factor: 2
epochs: &epochs 200000
patience: &patience 40
batch_size: &batch_size
Expand All @@ -35,7 +36,7 @@ experiments:
__class__: torch.optim.Adam
params:
lr:
__eval__: tune.loguniform(1e-7, 1e-4)
__eval__: tune.loguniform(1e-6, 1e-4)
weight_decay: 0.0
model_cfg:
type:
Expand All @@ -52,73 +53,21 @@ experiments:
coupling_layers: &coupling_layers
__eval__: tune.choice([i for i in range(3, 4)])
coupling_nn_layers: &coupling_nn_layers
__eval__: "tune.choice([[w] * l for l in [1, 2, 3] for w in [100, 200, 300]])" # tune.choice([[c*32, c*16, c*8, c*16, c*32] for c in [1, 2, 3, 4]] + [[c*64, c*32, c*64] for c in range(1,5)] + [[c*128] * 2 for c in range(1,5)] + [[c*256] for c in range(1,5)])
__eval__: "tune.choice([[w] * l for l in [1, 2, 3] for w in [196, 392]])" # tune.choice([[c*32, c*16, c*8, c*16, c*32] for c in [1, 2, 3, 4]] + [[c*64, c*32, c*64] for c in range(1,5)] + [[c*128] * 2 for c in range(1,5)] + [[c*256] for c in range(1,5)])
nonlinearity: &nonlinearity
__eval__: tune.choice([torch.nn.ReLU()])
split_dim: 50
split_dim: 98
base_distribution:
__object__: src.veriflow.distributions.RadialDistribution
device: *device
p:
__eval__: math.inf
loc:
__eval__: torch.zeros(100).to("cuda")
__eval__: torch.zeros(196).to("cuda")
norm_distribution:
__object__: pyro.distributions.LogNormal
loc:
__eval__: (1.2 * torch.ones(1)).to("cuda")
scale:
__eval__: (0.5 * torch.ones(1)).to("cuda")
use_lu: false
- &mnist_logNormal_linf_loc_12_scale_05_medium_sized
__overwrites__: *mnist_logNormal_linf_loc_1_scale_05_medium_sized
name: mnist_logNormal_linf_loc_12_scale_05_medium_sized
trial_config:
model_cfg:
params:
base_distribution:
norm_distribution:
__object__: pyro.distributions.LogNormal
loc:
__eval__: (1.2 * torch.ones(1)).to("cuda")
scale:
__eval__: (0.5 * torch.ones(1)).to("cuda")
- &mnist_logNormal_linf_loc_08_scale_05_medium_sized
__overwrites__: *mnist_logNormal_linf_loc_1_scale_05_medium_sized
name: mnist_logNormal_linf_loc_08_scale_05_medium_sized
trial_config:
model_cfg:
params:
base_distribution:
norm_distribution:
__object__: pyro.distributions.LogNormal
loc:
__eval__: (0.8 * torch.ones(1)).to("cuda")
scale:
__eval__: (0.5 * torch.ones(1)).to("cuda")
- &mnist_logNormal_linf_loc_1_scale_03_medium_sized
__overwrites__: *mnist_logNormal_linf_loc_1_scale_05_medium_sized
name: mnist_logNormal_linf_loc_1_scale_03_medium_sized
trial_config:
model_cfg:
params:
base_distribution:
norm_distribution:
__object__: pyro.distributions.LogNormal
loc:
__eval__: (1 * torch.ones(1)).to("cuda")
scale:
__eval__: (0.3 * torch.ones(1)).to("cuda")
- &mnist_logNormal_linf_loc_1_scale_07_medium_sized
__overwrites__: *mnist_logNormal_linf_loc_1_scale_05_medium_sized
name: mnist_logNormal_linf_loc_1_scale_07_medium_sized
trial_config:
model_cfg:
params:
base_distribution:
norm_distribution:
__object__: pyro.distributions.LogNormal
loc:
__eval__: (1 * torch.ones(1)).to("cuda")
scale:
__eval__: (0.7 * torch.ones(1)).to("cuda")

0 comments on commit 3a8fdb2

Please sign in to comment.