# Note: `1-linear_regression.jl`, `2-complex_relationships.jl` and
# `3-bias-variance_tradeoff.jl` need to be executed first and in that order.
#
# 4. GENERALIZATION
#
# 4.1. Withheld (or unseen) data
# 4.2. Validation loss as a proxy for generalization performance
# 4.3. Regularization to prevent overfitting
# 4.4. Programmatic hyperparameter tuning
#
# 4.1. Withheld (or unseen) data

# Reuse the data generated in the previous scripts as the training set and draw
# a new, withheld validation set from the same data-generating process.
x_train, y_train = x, y
target_train = target

ϵ_valid = rand(normal, 1, num_samples)
x_valid = rand(uniform, 1, num_samples)
y_valid = true_function.(x_valid) + σ * ϵ_valid
target_valid = preprocess(x_valid, y_valid)
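# Sanity check (a minimal sketch using only Base Julia): the withheld set has
# the same shape as the training set but shares none of its inputs.
@assert size(x_valid) == size(x_train)
@assert isempty(intersect(vec(x_valid), vec(x_train)))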

fig = scatter(
    vec(x_train),
    vec(y_train);
    axis = (xlabel = "x", ylabel = "y"),
    label = "training data",
);
scatter!(vec(x_valid), vec(y_valid); label = "validation data");
lines!(-1 .. 1, true_function; color = :gray, label = "true");
axislegend();
fig
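# As a reference for the loss curves below: even a perfect model cannot remove
# the observation noise. A minimal sketch (plain Julia, using only variables
# defined above) of the sum-of-squares loss that the true function itself
# incurs on each data set:
@show sum((true_function.(x_train) .- y_train) .^ 2)
@show sum((true_function.(x_valid) .- y_valid) .^ 2)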

# 4.2. Validation loss as a proxy for generalization performance

# Fit the network in blocks of 10 iterations, recording the training and
# validation losses after each block.
loss_train_l, loss_valid_l = [], []
fitted_nn =
    fit(nn, target_train; optim_alg = DeepPumas.BFGS(), optim_options = (; iterations = 10))
push!(loss_train_l, sum((fitted_nn(x_train) .- y_train) .^ 2))
push!(loss_valid_l, sum((fitted_nn(x_valid) .- y_valid) .^ 2))
iteration_blocks = 100
for _ = 2:iteration_blocks
    # Continue the fit from the current coefficients for another 10 iterations.
    global fitted_nn = fit(
        nn,
        target_train,
        coef(fitted_nn);
        optim_alg = DeepPumas.BFGS(),
        optim_options = (; iterations = 10),
    )
    push!(loss_train_l, sum((fitted_nn(x_train) .- y_train) .^ 2))
    push!(loss_valid_l, sum((fitted_nn(x_valid) .- y_valid) .^ 2))
end

iteration = 10 .* (1:iteration_blocks)
fig, ax = scatterlines(
    iteration,
    Float32.(loss_train_l);
    label = "training",
    axis = (; xlabel = "Iteration", ylabel = "Mean squared loss"),
);
scatterlines!(iteration, Float32.(loss_valid_l); label = "validation");
axislegend();
fig
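# The training loss keeps decreasing while the validation loss eventually
# flattens or rises, which signals overfitting. A minimal sketch (plain Julia,
# using only the losses collected above) of locating the block with the lowest
# validation loss, i.e. where early stopping would have kicked in:
best_block = argmin(loss_valid_l)
println("Lowest validation loss after ≈ ", 10 * best_block, " iterations: ", loss_valid_l[best_block])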

# 4.3. Regularization to prevent overfitting

# A multilayer perceptron with an L2 penalty (`reg = L2(0.1)`) on its weights.
reg_nn = MLPDomain(1, (32, tanh), (32, tanh), (1, identity); bias = true, reg = L2(0.1))

reg_loss_train_l, reg_loss_valid_l = [], []
fitted_reg_nn = fit(
    reg_nn,
    target_train;
    optim_alg = DeepPumas.BFGS(),
    optim_options = (; iterations = 10),
)
push!(reg_loss_train_l, sum((fitted_reg_nn(x_train) .- y_train) .^ 2))
push!(reg_loss_valid_l, sum((fitted_reg_nn(x_valid) .- y_valid) .^ 2))

iteration_blocks = 100
for _ = 2:iteration_blocks
    global fitted_reg_nn = fit(
        reg_nn,
        target_train,
        coef(fitted_reg_nn);
        optim_alg = DeepPumas.BFGS(),
        optim_options = (; iterations = 10),
    )
    push!(reg_loss_train_l, sum((fitted_reg_nn(x_train) .- y_train) .^ 2))
    push!(reg_loss_valid_l, sum((fitted_reg_nn(x_valid) .- y_valid) .^ 2))
end

iteration = 10 .* (1:iteration_blocks)
fig, ax = scatterlines(
    iteration,
    Float32.(loss_train_l);
    label = "training",
    axis = (; xlabel = "Iteration", ylabel = "Mean squared loss"),
);
scatterlines!(iteration, Float32.(loss_valid_l); label = "validation");
scatterlines!(iteration, Float32.(reg_loss_train_l); label = "training (L2)");
scatterlines!(iteration, Float32.(reg_loss_valid_l); label = "validation (L2)");
axislegend();
fig
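# A quick numeric complement to the plot above: compare the final losses with
# and without L2 regularization (a minimal sketch; the exact numbers depend on
# the randomly generated data).
@show last(loss_train_l) last(loss_valid_l)
@show last(reg_loss_train_l) last(reg_loss_valid_l)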

# 4.4. Programmatic hyperparameter tuning

nn_ho = hyperopt(reg_nn, target_train)
nn_ho.best_hyperparameters

ŷ_ho = nn_ho(x_valid)
fig = scatter(vec(x_valid), vec(y_valid); label = "validation data");
scatter!(vec(x_valid), vec(ŷ_ho); label = "prediction (hyperparam opt.)");
lines!(-1 .. 1, true_function; color = :gray, label = "true");
axislegend(; position = :ct);
fig
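# For a comparison on the same scale as the loss curves above, a minimal sketch
# of the sum-of-squares validation loss of the hyperparameter-optimized model
# next to the best losses reached by the unregularized and L2-regularized fits:
loss_valid_ho = sum((ŷ_ho .- y_valid) .^ 2)
@show loss_valid_ho minimum(loss_valid_l) minimum(reg_loss_valid_l)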