-
Notifications
You must be signed in to change notification settings - Fork 317
/
Copy pathpreprocess_data.py
67 lines (55 loc) · 2.39 KB
/
preprocess_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import torch
from torch import device
import glob
import datetime
import numpy as np
import shutil
from pathlib import Path
import pickle
def normalization(seqData, max, min):
    """Min-max scale ``seqData`` so that ``min`` maps to 0 and ``max`` maps to 1.

    Args:
        seqData: Value(s) to rescale.
        max: Upper reference value of the source range.
        min: Lower reference value of the source range.

    Returns:
        ``(seqData - min) / (max - min)``.
    """
    offset_from_min = seqData - min
    value_range = max - min
    return offset_from_min / value_range
def standardization(seqData, mean, std):
    """Center ``seqData`` on ``mean`` and scale it by ``std`` (z-score).

    Args:
        seqData: Value(s) to standardize.
        mean: Value subtracted from ``seqData``.
        std: Value the centered data is divided by.

    Returns:
        ``(seqData - mean) / std``.
    """
    centered = seqData - mean
    return centered / std
def reconstruct(seqData, mean, std):
    """Undo a z-score transform: scale by ``std`` and shift by ``mean``.

    Args:
        seqData: Standardized value(s).
        mean: Value added back after scaling.
        std: Value ``seqData`` is multiplied by.

    Returns:
        ``seqData * std + mean``.
    """
    rescaled = seqData * std
    return rescaled + mean
class PickleDataLoad(object):
    """Load a pickled, labeled dataset and prepare it as standardized tensors.

    The train and test splits are read from
    ``dataset/<data_type>/labeled/train/<filename>`` and
    ``dataset/<data_type>/labeled/test/<filename>``. Each pickle is converted
    to a float tensor that is indexed as 2-D: the last column is taken as the
    label and the remaining columns as the features. Both splits are
    standardized with the statistics of the raw *training* features.

    Attributes set during construction:
        augment_test_data: Whether the test split is noise-augmented too.
        mean, std: Per-column statistics of the raw training features.
        length: Number of rows in the raw (un-augmented) training split.
        trainData, trainLabel: Preprocessed training features and labels.
        testData, testLabel: Preprocessed test features and labels.
    """

    def __init__(self, data_type, filename, augment_test_data=True):
        """Read and preprocess the train and test splits.

        Args:
            data_type: Name of the dataset directory under ``dataset``.
            filename: Pickle file name, looked up in both the ``train`` and
                the ``test`` directory.
            augment_test_data: If True, the test split is augmented with
                noisy copies in the same way as the training split.
        """
        self.augment_test_data = augment_test_data
        # The training split has to be processed first: it sets self.mean and
        # self.std, which preprocessing of the test split reuses.
        self.trainData, self.trainLabel = self.preprocessing(
            Path('dataset', data_type, 'labeled', 'train', filename), train=True)
        self.testData, self.testLabel = self.preprocessing(
            Path('dataset', data_type, 'labeled', 'test', filename), train=False)

    def augmentation(self, data, label, noise_ratio=0.05, noise_interval=0.0005, max_length=100000):
        """Extend ``data`` with progressively noisier copies of itself.

        A single Gaussian noise tensor is drawn and scaled per column by
        ``self.std``. For every ratio ``r`` in
        ``np.arange(0, noise_ratio, noise_interval)`` the copy
        ``data + r * std * noise`` is appended, and ``label`` is appended
        unchanged alongside it. Appending stops as soon as the result is
        longer than ``max_length``, at which point it is cut to exactly
        ``max_length`` rows.

        Requires ``self.std`` to be set (done by ``preprocessing`` on the
        training split).

        Args:
            data: Feature tensor; copies are concatenated along dim 0.
            label: Label tensor matching ``data`` along dim 0.
            noise_ratio: Exclusive upper bound of the swept noise ratio.
            noise_interval: Step between consecutive noise ratios.
            max_length: Row limit applied once a noisy copy has been appended.

        Returns:
            Tuple ``(augmentedData, augmentedLabel)``; the original rows come
            first, followed by the noisy copies in order of increasing ratio.
        """
        noiseSeq = torch.randn(data.size())
        # Loop-invariant part of the perturbation: noise scaled per column.
        unit_noise = self.std.expand_as(data) * noiseSeq
        augmentedData = data.clone()
        augmentedLabel = label.clone()
        for ratio in np.arange(0, noise_ratio, noise_interval):
            # Scale by the swept ratio. The loop variable used to be ignored
            # and the constant ``noise_ratio`` applied on every pass, which
            # appended the very same noisy copy over and over.
            # float() turns the NumPy scalar into a plain Python number so the
            # product keeps the tensor's dtype.
            noisy_copy = data + float(ratio) * unit_noise
            augmentedData = torch.cat([augmentedData, noisy_copy], dim=0)
            augmentedLabel = torch.cat([augmentedLabel, label])
            if len(augmentedData) > max_length:
                augmentedData = augmentedData[:max_length]
                augmentedLabel = augmentedLabel[:max_length]
                break
        return augmentedData, augmentedLabel

    def preprocessing(self, path, train=True):
        """Read, augment and standardize one split.

        Args:
            path: Location of the pickle file to load.
            train: If True, ``self.mean``, ``self.std`` and ``self.length``
                are computed from this split and the split is always
                augmented. If False, the stored statistics are reused and the
                split is augmented only when ``self.augment_test_data`` is set.

        Returns:
            Tuple ``(data, label)``: the standardized features and the labels.
        """
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at dataset files from a trusted source.
        with open(str(path), 'rb') as f:
            data = torch.FloatTensor(pickle.load(f))
        # The last column holds the label, everything before it the features.
        label = data[:, -1]
        data = data[:, :-1]
        if train:
            # Statistics come from the raw training features, before any
            # noisy copies are added.
            self.mean = data.mean(dim=0)
            self.std = data.std(dim=0)
            self.length = len(data)
            data, label = self.augmentation(data, label)
        elif self.augment_test_data:
            data, label = self.augmentation(data, label)
        # NOTE(review): a feature column with zero standard deviation makes
        # the division inside standardization() produce NaN/inf for that
        # column - confirm the datasets never contain constant features.
        data = standardization(data, self.mean, self.std)
        return data, label

    def batchify(self, args, data, bsz):
        """Cut ``data`` into ``bsz`` equally long, contiguous segments.

        Trailing rows that do not fill ``bsz`` equal segments are dropped.
        For a 2-D ``data`` of shape ``(N, F)`` the result has shape
        ``(N // bsz, bsz, F)``: index ``[:, b, :]`` is the b-th contiguous
        segment of the input.

        Args:
            args: Object whose ``device`` attribute names the target device.
            data: Tensor to split along dim 0.
            bsz: Number of segments (batch size).

        Returns:
            The batched tensor, moved to ``args.device``.
        """
        nbatch = data.size(0) // bsz
        # Keep only the rows that divide evenly into bsz segments.
        trimmed_data = data.narrow(0, 0, nbatch * bsz)
        # view() lays the segments out batch-first; transpose makes the
        # position within a segment the leading dimension.
        batched_data = trimmed_data.contiguous().view(bsz, -1, trimmed_data.size(-1)).transpose(0, 1)
        batched_data = batched_data.to(device(args.device))
        return batched_data