train.py (forked from akaraspt/deepsleepnet)
#! /usr/bin/python
# -*- coding: utf8 -*-
import os
import numpy as np
import tensorflow as tf
from deepsleep.trainer import DeepFeatureNetTrainer, DeepSleepNetTrainer
from deepsleep.sleep_stage import (NUM_CLASSES,
                                   EPOCH_SEC_LEN,
                                   SAMPLING_RATE)
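
# NOTE: this script uses TF1-style APIs (tf.app.flags, tf.gfile) and will
# not run as-is under TensorFlow 2.x. A common workaround (an assumption
# here, not part of the original repo) is to import the v1 compatibility
# layer instead of the line above:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()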
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', 'data',
                           """Directory where to load training data.""")
tf.app.flags.DEFINE_string('output_dir', 'output',
                           """Directory where to save trained models """
                           """and outputs.""")
tf.app.flags.DEFINE_integer('n_folds', 20,
                            """Number of cross-validation folds.""")
tf.app.flags.DEFINE_integer('fold_idx', 0,
                            """Index of cross-validation fold to train.""")
tf.app.flags.DEFINE_integer('pretrain_epochs', 100,
                            """Number of epochs for pretraining DeepFeatureNet.""")
tf.app.flags.DEFINE_integer('finetune_epochs', 200,
                            """Number of epochs for fine-tuning DeepSleepNet.""")
tf.app.flags.DEFINE_boolean('resume', False,
                            """Whether to resume the training process.""")

def pretrain(n_epochs):
    """Pretrain DeepFeatureNet on the selected cross-validation fold."""
    trainer = DeepFeatureNetTrainer(
        data_dir=FLAGS.data_dir,
        output_dir=FLAGS.output_dir,
        n_folds=FLAGS.n_folds,
        fold_idx=FLAGS.fold_idx,
        batch_size=100,
        input_dims=EPOCH_SEC_LEN*100,  # samples per sleep epoch (EPOCH_SEC_LEN seconds at 100 Hz)
        n_classes=NUM_CLASSES,
        interval_plot_filter=50,
        interval_save_model=100,
        interval_print_cm=10
    )
    pretrained_model_path = trainer.train(
        n_epochs=n_epochs,
        resume=FLAGS.resume
    )
    return pretrained_model_path

def finetune(model_path, n_epochs):
    """Fine-tune the full DeepSleepNet, initialized from the pretrained model."""
    trainer = DeepSleepNetTrainer(
        data_dir=FLAGS.data_dir,
        output_dir=FLAGS.output_dir,
        n_folds=FLAGS.n_folds,
        fold_idx=FLAGS.fold_idx,
        batch_size=10,
        input_dims=EPOCH_SEC_LEN*100,
        n_classes=NUM_CLASSES,
        seq_length=25,  # each training sequence covers 25 consecutive sleep epochs
        n_rnn_layers=2,
        return_last=False,
        interval_plot_filter=50,
        interval_save_model=100,
        interval_print_cm=10
    )
    finetuned_model_path = trainer.finetune(
        pretrained_model_path=model_path,
        n_epochs=n_epochs,
        resume=FLAGS.resume
    )
    return finetuned_model_path

def main(argv=None):
    # Output dir for this fold; cleared unless resuming
    output_dir = os.path.join(FLAGS.output_dir, "fold{}".format(FLAGS.fold_idx))
    if not FLAGS.resume:
        if tf.gfile.Exists(output_dir):
            tf.gfile.DeleteRecursively(output_dir)
        tf.gfile.MakeDirs(output_dir)

    # Two-stage training: pretrain the representation network, then
    # fine-tune the full sequence model starting from its weights
    pretrained_model_path = pretrain(
        n_epochs=FLAGS.pretrain_epochs
    )
    finetuned_model_path = finetune(
        model_path=pretrained_model_path,
        n_epochs=FLAGS.finetune_epochs
    )


if __name__ == "__main__":
    tf.app.run()
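
# Each run trains only the fold selected by --fold_idx. To cover the full
# cross-validation, invoke the script once per fold, e.g. with a shell
# loop (a sketch, not part of this file):
#   for i in $(seq 0 19); do
#       python train.py --data_dir=data --output_dir=output --fold_idx=$i
#   done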