class Model(nn.Module):
    """Language model: embedding -> stacked RNN (LSTM or SRU) -> tied output layer.

    ``args`` must expose ``d``, ``depth``, ``dropout``, ``rnn_dropout``,
    ``lstm`` and (when ``lstm`` is false) ``bias``.
    """

    def __init__(self, words, args):
        super(Model, self).__init__()
        self.args = args
        self.n_d, self.depth = args.d, args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = EmbeddingLayer(self.n_d, words)
        self.n_V = self.embedding_layer.n_V

        # Input size and hidden size are both n_d; ``depth`` is the number
        # of stacked recurrent layers.
        if args.lstm:
            self.rnn = nn.LSTM(
                self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
            )
        else:
            self.rnn = MF.SRU(
                self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
                rnn_dropout=args.rnn_dropout,
                use_tanh=0,
            )

        self.output_layer = nn.Linear(self.n_d, self.n_V)
        # Tie weights: the output projection reuses the embedding matrix.
        # (Original author's note: after running it, this appears to be the
        # vector associated with each word.)
        self.output_layer.weight = self.embedding_layer.embedding.weight

        self.init_weights()
        # The SRU gate bias is set after init_weights, which zeroes every
        # 1-D parameter.
        if not args.lstm:
            self.rnn.set_bias(args.bias)

    def init_weights(self):
        """Re-initialise every parameter in place.

        Parameters with more than one dimension are drawn uniformly from
        [-sqrt(3 / n_d), sqrt(3 / n_d)]; all others are set to zero.
        """
        bound = (3.0 / self.n_d) ** 0.5
        for param in self.parameters():
            if param.dim() <= 1:
                param.data.zero_()
                continue
            # matrix
            param.data.uniform_(-bound, bound)

    def forward(self, x, hidden):
        """Run one step of the model and return ``(output, hidden)``.

        The RNN output is 3-D; it is flattened to 2-D (keeping the last
        dimension) before the output layer is applied.
        """
        embedded = self.drop(self.embedding_layer(x))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)
        flat = rnn_out.view(-1, rnn_out.size(2))
        return self.output_layer(flat), hidden

    def init_hidden(self, batch_size):
        """Return a zero initial hidden state of size (depth, batch_size, n_d).

        For the LSTM a ``(h, c)`` tuple is returned (both entries are the
        same zero tensor); otherwise the single tensor is returned.
        """
        # ``new`` allocates a tensor of the same type as an existing parameter.
        reference = next(self.parameters()).data
        state = Variable(reference.new(self.depth, batch_size, self.n_d).zero_())
        return (state, state) if self.args.lstm else state

    def print_pnorm(self):
        """Write the norm of every parameter (rounded to 0 decimals) to stdout."""
        norms = []
        for param in self.parameters():
            norms.append("{:.0f}".format(param.norm().data[0]))
        sys.stdout.write("\tp_norm: {}\n".format(norms))
这个问题源于我对Model类中init_weights方法的理解:我一直读不懂这个方法是做什么的,尤其是self.parameters()这个迭代器给出的参数到底是什么。我假设其中应该是每一层需要更新的权重,所以我将sru源码的一部分取了出来,让其输出Model里的parameters,代码如下:
#coding:UTF-8
'''
Build the SRU/LSTM language model and print every parameter tensor while
it is being initialised.

Created on 2017-12-4
@author: lai
'''
import time
import random
import math
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys
import cuda_functional as MF


def read_corpus(path, eos=""):
    """Read a whitespace-tokenised text file into one flat list of tokens.

    ``eos`` is appended after the tokens of every line.
    """
    # NOTE(review): the default end-of-sentence marker is an empty string
    # here; it may originally have been an angle-bracket token that was lost
    # when this listing was extracted -- confirm against the SRU repository.
    tokens = []
    with open(path) as fin:
        for line in fin:
            tokens.extend(line.split())
            tokens.append(eos)
    return tokens


def create_batches(data_text, map_to_ids, batch_size):
    """Turn a token list into input/target id tensors.

    The ids are truncated to a multiple of ``batch_size``, reshaped to
    ``batch_size`` rows and transposed, so each column holds one contiguous
    stream of the corpus.  The targets are the inputs shifted by one token.
    """
    ids = map_to_ids(data_text)
    total = len(ids)
    usable = ((total - 1) // batch_size) * batch_size
    inputs = ids[:usable].reshape(batch_size, -1).T
    targets = ids[1:usable + 1].reshape(batch_size, -1).T
    x = torch.from_numpy(np.copy(inputs)).contiguous()
    y = torch.from_numpy(np.copy(targets)).contiguous()
    return x, y


class EmbeddingLayer(nn.Module):
    """Vocabulary built from the corpus plus one embedding vector per word."""

    def __init__(self, n_d, words, fix_emb=False):
        # NOTE: ``fix_emb`` is accepted but not used in this listing.
        super(EmbeddingLayer, self).__init__()
        # Map each distinct word to a consecutive integer id, in order of
        # first appearance.
        vocab = {}
        for word in words:
            vocab.setdefault(word, len(vocab))
        self.word2id = vocab
        # n_V is the vocabulary size; n_d is the embedding dimension (the
        # model also uses it as the hidden state size).
        self.n_V = len(vocab)
        self.n_d = n_d
        self.embedding = nn.Embedding(self.n_V, n_d)

    def forward(self, x):
        """Look up the embedding vectors for the ids in ``x``."""
        return self.embedding(x)

    def map_to_ids(self, text):
        """Return the int64 numpy array of ids for the tokens in ``text``."""
        ids = [self.word2id[token] for token in text]
        return np.asarray(ids, dtype='int64')


class Model(nn.Module):
    """Embedding -> stacked RNN (LSTM or SRU) -> output layer with tied weights.

    Trimmed copy of the language model that keeps only construction and
    initialisation, so the parameters can be inspected.
    """

    def __init__(self, words, args):
        super(Model, self).__init__()
        self.args = args
        self.n_d, self.depth = args.d, args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = EmbeddingLayer(self.n_d, words)
        self.n_V = self.embedding_layer.n_V

        if args.lstm:
            self.rnn = nn.LSTM(
                self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
            )
        else:
            self.rnn = MF.SRU(
                self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
                rnn_dropout=args.rnn_dropout,
                use_tanh=0,
            )

        self.output_layer = nn.Linear(self.n_d, self.n_V)
        # Tie weights: the output projection reuses the embedding matrix.
        # (Original author's note: after running it, this appears to be the
        # vector associated with each word.)
        self.output_layer.weight = self.embedding_layer.embedding.weight

        self.init_weights()
        if not args.lstm:
            self.rnn.set_bias(args.bias)

    def init_weights(self):
        """Re-initialise every parameter in place and print it.

        Parameters with more than one dimension are drawn uniformly from
        [-sqrt(3 / n_d), sqrt(3 / n_d)] and printed with the tag '222222';
        all others are zeroed and printed with the tag '0000'.
        """
        bound = (3.0 / self.n_d) ** 0.5
        for param in self.parameters():
            if param.dim() > 1:
                # matrix
                param.data.uniform_(-bound, bound)
                print('222222', param.data)
            else:
                param.data.zero_()
                print('0000', param.data)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve')
    argparser.add_argument("--lstm", action="store_true")
    argparser.add_argument("--train", type=str, required=True,
                           help="training file")
    argparser.add_argument("--batch_size", "--batch", type=int, default=32)
    argparser.add_argument("--unroll_size", type=int, default=35)
    argparser.add_argument("--max_epoch", type=int, default=300)
    argparser.add_argument("--d", type=int, default=910)
    argparser.add_argument("--dropout", type=float, default=0.7,
                           help="dropout of word embeddings and softmax output")
    argparser.add_argument("--rnn_dropout", type=float, default=0.2,
                           help="dropout of RNN layers")
    argparser.add_argument("--bias", type=float, default=-3,
                           help="intial bias of highway gates")
    argparser.add_argument("--depth", type=int, default=6)
    argparser.add_argument("--lr", type=float, default=1.0)
    argparser.add_argument("--lr_decay", type=float, default=0.98)
    argparser.add_argument("--lr_decay_epoch", type=int, default=175)
    argparser.add_argument("--weight_decay", type=float, default=1e-5)
    argparser.add_argument("--clip_grad", type=float, default=5)

    args = argparser.parse_args()
    print(args)

    # Everything below depends on ``args``, so it is kept inside the
    # __main__ guard; at module level an import would fail with NameError.
    train = read_corpus(args.train)
    model = Model(train, args)
    model.cuda()
    map_to_ids = model.embedding_layer.map_to_ids
    train = create_batches(train, map_to_ids, args.batch_size)
    # model.parameters() returns an iterator, so this prints the iterator
    # object itself; the tensors are printed by Model.init_weights above.
    print('111', model.parameters())
在终端中输入运行命令:
python 2.py --train train.txt
输出:
Namespace(batch_size=32, bias=-3, clip_grad=5, d=910, depth=6, dropout=0.7, lr=1.0, lr_decay=0.98, lr_decay_epoch=175, lstm=False, max_epoch=300, rnn_dropout=0.2, train='train.txt', unroll_size=35, weight_decay=1e-05)222222 4.8794e-02 5.0702e-02 -3.2630e-02 ... -5.3750e-02 4.2253e-02 1.6446e-02-5.1652e-02 -2.3051e-02 4.3890e-02 ... 1.8805e-02 1.6605e-02 2.6666e-02 2.5273e-02 -5.1426e-03 5.3130e-02 ... -4.8786e-02 4.0186e-02 -4.3724e-02 ... ⋱ ... -3.3133e-02 3.3400e-02 3.2185e-02 ... -5.0593e-02 -2.3048e-02 -2.1572e-02 2.9908e-03 -2.1938e-02 -2.1926e-02 ... -4.5163e-02 -4.1678e-02 -5.2639e-02-2.2036e-02 2.3908e-04 1.9383e-02 ... -1.0341e-02 4.7491e-02 -5.0599e-02[torch.FloatTensor of size 10000x910]222222 -6.1627e-03 1.9962e-02 5.6098e-02 ... 5.2324e-02 -1.0912e-02 1.7969e-02 1.1683e-02 1.4485e-02 3.7155e-02 ... -4.6458e-02 -2.8750e-02 -1.7442e-02 5.3697e-02 3.4534e-02 -2.5292e-02 ... -3.9264e-02 -2.8864e-02 2.3790e-02 ... ⋱ ... 7.6450e-03 -2.1589e-02 -7.6684e-03 ... -5.6521e-02 -5.5103e-02 -3.8065e-02 4.7252e-02 5.7209e-02 -4.9279e-02 ... -2.0944e-02 -4.3891e-03 1.8820e-02 2.7026e-02 3.5590e-02 1.3660e-02 ... -1.6219e-02 -2.1856e-02 3.2678e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]222222 -1.2439e-02 -5.5866e-02 -3.5799e-02 ... -4.9976e-02 7.3134e-03 4.5684e-03-4.6130e-02 -4.7773e-02 -4.3640e-02 ... -3.2027e-02 -8.8562e-03 4.3218e-02-3.5260e-02 3.1456e-02 1.3324e-02 ... 3.4487e-02 -7.7102e-03 2.9963e-02 ... ⋱ ... -1.6921e-02 -1.5771e-02 5.3847e-02 ... 4.6351e-02 4.9333e-02 -1.1978e-02-1.8770e-02 -1.5817e-02 -7.6655e-05 ... -8.4615e-03 1.4490e-02 -5.6743e-02 4.1060e-03 -2.4452e-02 2.5512e-02 ... -2.3961e-02 -5.2609e-02 3.3445e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]222222 -3.6535e-02 -2.4697e-02 3.2514e-02 ... 3.0889e-02 -4.7916e-03 9.5873e-03 4.5222e-02 -5.7333e-02 5.4079e-02 ... 1.7790e-02 3.5510e-02 -1.2171e-02 7.5279e-03 -2.7133e-02 -5.1036e-02 ... 
5.6305e-02 -2.0042e-02 -2.8884e-02 ... ⋱ ... -4.5409e-02 -1.6207e-02 3.4128e-02 ... -5.6980e-02 1.6646e-02 -2.0662e-02 2.8941e-02 3.1405e-02 5.7100e-02 ... 3.9499e-03 9.5197e-03 -2.3475e-02-5.1939e-02 -9.6567e-03 3.1139e-02 ... -1.0642e-02 -4.8837e-02 2.7009e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]222222 1.4545e-02 -1.7484e-02 -1.3450e-02 ... 4.9990e-02 3.6013e-03 -2.5272e-02 4.6915e-02 2.4484e-02 -2.6583e-02 ... 3.4737e-02 3.9499e-02 -2.8632e-02 1.8722e-02 -2.1864e-02 2.4649e-02 ... 4.9049e-02 4.8219e-02 3.7317e-02 ... ⋱ ... -2.6708e-02 4.2176e-02 3.8287e-02 ... 3.3608e-02 -2.7229e-02 9.4752e-03 1.2404e-02 1.7356e-02 7.0494e-03 ... 1.5802e-02 -7.5168e-03 -4.1576e-02-3.1050e-02 3.5632e-02 2.2318e-03 ... -1.9828e-02 4.4247e-02 -2.3669e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]222222 -8.6860e-03 2.4917e-02 -4.8584e-02 ... -1.1277e-02 -1.2668e-02 -1.6445e-02-2.5161e-02 -4.4705e-03 -4.5265e-02 ... -3.1264e-02 -4.2164e-02 -2.4916e-02-1.8575e-02 -1.8767e-02 -5.2647e-02 ... 5.4461e-02 -5.0726e-02 -3.1518e-03 ... ⋱ ... -3.1745e-02 -3.8159e-02 1.7577e-02 ... -5.6739e-02 1.9196e-02 1.6574e-02-5.5951e-02 -6.2410e-03 -5.6714e-02 ... 2.8419e-02 5.7141e-02 2.3431e-02-1.7646e-02 8.7587e-04 -2.3462e-02 ... -4.9807e-04 4.2565e-02 -4.5738e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]222222 -8.5008e-03 4.9589e-02 4.8005e-02 ... 5.2643e-03 1.4385e-02 -1.8161e-02 3.0520e-03 5.5756e-02 3.9487e-02 ... -2.9614e-03 -5.1740e-02 -4.8080e-02 1.8335e-02 -5.5416e-02 -1.0836e-02 ... 2.8635e-02 -8.8250e-03 -1.4533e-02 ... ⋱ ... 5.2809e-02 -3.2417e-02 3.9305e-02 ... 2.2464e-02 -4.7438e-02 5.1094e-02-5.5829e-02 -4.9564e-02 1.3892e-02 ... -3.4778e-02 4.3359e-02 8.6556e-03-2.1687e-03 -3.7360e-03 4.2217e-03 ... 
3.9019e-02 -4.2598e-02 1.6985e-02[torch.FloatTensor of size 910x2730]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 1820]0000 0 0 0⋮ 0 0 0[torch.FloatTensor of size 10000]111
下面是方法init_weights的代码:
def init_weights(self):
    """Re-initialise every parameter of ``self`` in place and print it.

    Parameters with more than one dimension are drawn uniformly from
    [-sqrt(3 / n_d), sqrt(3 / n_d)] and printed with the tag '222222';
    all other parameters are set to zero and printed with the tag '0000'.
    """
    bound = (3.0 / self.n_d) ** 0.5
    for param in self.parameters():
        values = param.data
        if param.dim() <= 1:
            values.zero_()
            print('0000', values)
            continue
        # matrix
        values.uniform_(-bound, bound)
        print('222222', values)
上面运行输出的结果就是执行p.data.uniform_(-val_range, val_range)以及p.data.zero_()之后各个参数的值。我猜测这些参数一类是sru中的权重(w),另一类是偏置(b)。其中输出的第一个大小为10000*910的tensor是词向量化得到的10000个单词的词向量,而最后一个大小为10000的tensor是最后线性分类全连接层的偏置(该层的权重与词向量矩阵共享,所以没有再单独输出一次),所以剩下的是六对w和b。但是这样的话就有一个疑问:因为循环神经网络的参数在时间上是共享的,按我原来的理解应该只有一对才对。为了解决这个疑问,
我将用lstm做mnist分类的代码拿了出来,并将它的model的参数打印了出来,代码和结果如下所示
代码:
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Fix the random seed so the printed parameter values are reproducible.
torch.manual_seed(1)

# Hyper-parameters
EPOCH = 1               # number of passes over the training data (1 to save time)
BATCH_SIZE = 64
TIME_STEP = 28          # RNN time steps / image height
INPUT_SIZE = 28         # RNN input size / image width
LR = 0.01               # learning rate
DOWNLOAD_MNIST = True   # set to True if the data has not been downloaded yet

# MNIST training set; ToTensor converts each image to a FloatTensor of
# shape (C x H x W) with values normalised to [0.0, 1.0].
train_data = dsets.MNIST(
    root='./mnist/',
    train=True,
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

# Show the dataset sizes and plot one example.
print(train_data.train_data.size())     # (60000, 28, 28)
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show()

# Data loader that yields shuffled mini-batches for training.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test data: keep only the first 2000 samples to speed up testing and scale
# the pixel values by 1/255.
test_data = dsets.MNIST(
    root='./mnist/',
    train=False,
    transform=transforms.ToTensor(),
)
test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000] / 255.
test_y = test_data.test_labels.numpy().squeeze()[:2000]    # numpy array of labels


class RNN(nn.Module):
    """Two-layer LSTM over image rows followed by a 10-way linear classifier."""

    def __init__(self):
        super(RNN, self).__init__()
        # nn.LSTM is used because, per the original note, nn.RNN hardly learns.
        # batch_first=True puts the batch dimension first:
        # (batch, time_step, input_size).
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,     # number of hidden units
            num_layers=2,       # number of stacked LSTM layers
            batch_first=True,
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        """Classify a batch from the LSTM output at the last time step.

        x:     (batch, time_step, input_size)
        r_out: (batch, time_step, output_size)
        h_n, h_c: (n_layers, batch, hidden_size)
        """
        # Passing None means a zero initial hidden state.
        r_out, (h_n, h_c) = self.rnn(x, None)
        last_step = r_out[:, -1, :]
        return self.out(last_step)

    def init_weights(self):
        """Print every parameter tensor of the model (nothing is modified)."""
        for param in self.parameters():
            print('PPP', param.data)


rnn = RNN()
# init_weights returns None, so the final line of the output is "None".
print(rnn.init_weights())
输出:
torch.Size([60000, 28, 28])torch.Size([60000])PPP -2.0745e-02 1.2430e-01 5.5081e-02 ... -1.4137e-02 9.4529e-02 -6.7606e-02-1.1815e-01 8.6035e-03 4.2617e-02 ... 8.2401e-02 -1.1524e-01 -5.6738e-02-8.2542e-02 -1.1019e-01 9.4536e-02 ... 4.0159e-02 6.2041e-02 -5.0376e-02 ... ⋱ ... 1.0238e-01 5.3194e-02 5.3342e-02 ... -1.5019e-02 -1.0299e-01 2.3091e-02 4.5909e-02 -5.0352e-02 -2.5497e-02 ... 1.1765e-01 -1.1448e-01 -3.1609e-02 3.1011e-06 -1.0142e-01 1.2229e-01 ... 3.1813e-02 7.6921e-02 4.4233e-03[torch.FloatTensor of size 256x28]PPP -2.4325e-03 1.1478e-02 9.3458e-02 ... -1.1657e-01 -3.6968e-03 1.2013e-01 1.2265e-01 -2.3560e-02 -5.3951e-02 ... 4.1457e-02 -6.7170e-02 6.1414e-02 1.2334e-01 -6.3188e-02 3.9050e-02 ... 8.4631e-02 4.0930e-04 8.3604e-03 ... ⋱ ... 5.6417e-02 3.7298e-02 5.7616e-02 ... 2.9125e-02 -6.6484e-02 -4.2838e-02-6.0267e-02 8.6004e-02 4.4727e-02 ... -4.9643e-02 -3.5065e-03 -2.5401e-02 8.1001e-02 5.8518e-02 -9.0292e-02 ... -1.5258e-02 5.6519e-02 6.1370e-02[torch.FloatTensor of size 256x64]PPP 0.0282-0.0362 0.0864 0.0677 0.0012 0.0699 0.0850-0.0927 0.0074-0.0183 0.0679 0.1177 0.0255 0.1012 0.1248-0.0625 0.0023-0.0255 0.0870-0.0900 0.1057 0.1233 0.0982 0.0475-0.0387-0.0267-0.0964-0.0153 0.0004-0.0410 0.0771-0.0399 0.0746-0.0210-0.0396 0.1108 0.0347 0.0263 0.0244 0.1113-0.1071 0.1036 0.0478 0.0217 0.0314 0.0138-0.1113-0.1192-0.0286-0.0674-0.0165-0.0097 0.0663-0.1072 0.0048-0.1062 0.0677-0.0028 0.0809 0.0119 0.1111 0.0363 0.0877 0.0189 0.0396 0.0358-0.0257 0.0966 0.0951-0.1179-0.0906-0.0619-0.0229-0.1193 0.0254 0.0110 0.0400 0.0655 0.1200-0.0940 0.0728 0.0882-0.1049 0.0939 0.0041-0.0711 0.0914-0.0461 0.0109-0.0800-0.0766-0.0265-0.0381-0.0433 0.0193 0.0812 0.0163 0.0358-0.0053-0.0900-0.0037 0.1009 0.1084 0.1006-0.1237-0.1227 0.0808-0.0083 0.0376 0.0424-0.1121 0.0379 0.0457 0.0443-0.0528 0.0220-0.0690 0.0620-0.0660-0.1124 0.1238 0.1188 0.0121 0.0574 0.1246 0.1000-0.1034 0.0387 0.0307-0.0669-0.0619-0.0819 0.0566 0.0150 0.0271-0.0843-0.0209-0.0957-0.1174 0.1031-0.1250 
0.0180-0.0449 0.0920 0.1114 0.0604-0.0987 0.0378-0.0088-0.0471 0.0549-0.1234 0.1069-0.0567 0.0241-0.0163 0.0585 0.0199-0.0188 0.0265-0.0673 0.0697-0.1224 0.1042-0.0697 0.0695 0.0575-0.1156 0.0663 0.1177 0.0562-0.0417-0.0054 0.0045 0.0614-0.0089 0.0203-0.1049-0.1201-0.0638 0.0728 0.0208-0.1018-0.0363 0.1128-0.0524 0.0992 0.0937-0.0378-0.0195-0.0188-0.0483 0.0779-0.0754 0.0148-0.0060 0.0743-0.0820-0.0673-0.1153-0.1039 0.1002 0.1217-0.0797 0.0217 0.1129 0.0951 0.0616-0.1183-0.0252-0.0304 0.1234-0.0538 0.0367 0.0407 0.1176-0.0902-0.0805 0.0111-0.0863-0.1222-0.0678-0.0044-0.1218 0.0300 0.0739-0.1152 0.1235-0.0317 0.0685 0.0598 0.1120-0.0902 0.1143 0.0801 0.0399 0.0360-0.1152-0.1007-0.1126 0.0860-0.0592 0.0955 0.0719-0.1118 0.0839-0.1176 0.0537 0.0078 0.1173 0.0129-0.0301 0.0105 0.0961 0.1167-0.0015[torch.FloatTensor of size 256]PPP -0.0896-0.0394 0.0575 0.0898-0.0369-0.0604-0.1172-0.0549-0.0869 0.0679 0.0554 0.0323 0.1063 0.0728 0.0056-0.0021-0.0868-0.0736-0.1204-0.0460-0.0145-0.0992 0.0601 0.0738 0.0064-0.0570-0.0947 0.0027 0.0669 0.0408-0.0228 0.0554 0.0698 0.0994 0.0893 0.1066 0.1231-0.0688 0.0152-0.0445-0.0341-0.0329 0.1052-0.0456-0.0409 0.0484 0.0768 0.0061 0.0429-0.0186 0.0379-0.0657-0.0839 0.0442-0.0539-0.0483 0.0572-0.0753-0.0779-0.1166 0.0279-0.0066 0.0854 0.0428 0.0903-0.0658 0.1244-0.0133 0.0524 0.0666-0.0662 0.1046-0.0649 0.1223 0.0819-0.0074 0.0782-0.0263-0.0057-0.0470 0.1029 0.1156 0.0884 0.0517 0.0135 0.0975 0.0406 0.0615-0.1222 0.0127 0.0202 0.0154-0.0490 0.0423-0.0904 0.0034 0.0662-0.0574 0.1162-0.0481-0.0147 0.0243 0.0805 0.0352 0.1058 0.0748-0.0551-0.0796-0.1161-0.0610-0.0102 0.0143 0.0791 0.0752 0.0099 0.1133-0.0766 0.0520 0.0810 0.1068-0.0541 0.0390 0.1153 0.0095 0.0118-0.0185-0.1179 0.0452 0.0302-0.0776 0.0909-0.0086 0.0527 0.0133 0.1130-0.0909 0.1160 0.1218 0.0347-0.0277 0.0401 0.1104-0.0635-0.0656-0.0928-0.0365 0.0579 0.1197-0.0098-0.0489-0.1086 0.0579 0.0282-0.0649 0.0929 0.0039 0.0507 0.1174 0.0951-0.0533 0.0641 0.0185 0.0011-0.0621 
0.0776-0.0298-0.1170 0.0693 0.0740-0.0802 0.0799-0.0972-0.0010 0.0589-0.0510-0.0292-0.0500 0.0838-0.0176 0.0527-0.0037 0.0092 0.0478 0.0512-0.1239 0.0042-0.0440-0.0278-0.0434 0.0052 0.0466-0.0746-0.1143-0.0694 0.0201 0.0768-0.0924 0.0589-0.0591-0.1036 0.0529 0.0197-0.1067-0.0165-0.0370 0.0374-0.0818-0.0040 0.0659 0.1040-0.0619-0.1208-0.1066 0.1142 0.0920 0.0833 0.0214 0.1020-0.0266-0.0508 0.0550-0.0452-0.0696 0.0879 0.0680 0.1009-0.0232 0.0159-0.1064-0.0839 0.1089-0.0473-0.0158 0.0185-0.1224 0.1131 0.1089 0.1030-0.0451-0.0555-0.0767-0.0546 0.0403-0.1247-0.0622-0.0063-0.0933 0.0445 0.0727 0.0664-0.0864-0.0978 0.0016-0.1126 0.0716 0.0169[torch.FloatTensor of size 256]PPP -6.6907e-02 -1.1469e-01 6.4129e-02 ... 3.8876e-02 -4.4813e-02 4.7873e-02 1.0064e-01 -1.2048e-01 7.3207e-02 ... -1.2326e-02 -1.1054e-01 -1.1371e-01-9.9514e-02 -4.0268e-04 7.1349e-03 ... -1.0321e-01 -1.2389e-01 -4.2875e-03 ... ⋱ ... 6.1065e-02 -5.2070e-02 -7.4900e-02 ... 3.0900e-02 5.6731e-02 1.0931e-01-4.2554e-03 1.2137e-01 -1.0776e-02 ... -9.8254e-03 -3.8701e-02 -2.6478e-02-6.6246e-02 4.3564e-02 4.7540e-02 ... -8.6700e-02 -6.5478e-03 -7.8267e-02[torch.FloatTensor of size 256x64]PPP -9.3750e-02 -8.5315e-02 -3.2224e-02 ... 4.6174e-02 1.2341e-01 7.0605e-02-1.0107e-01 -1.1443e-01 -1.2133e-01 ... -1.1138e-01 7.7709e-02 4.1309e-02-1.0675e-01 -9.5286e-02 8.1566e-02 ... -5.4656e-02 -2.9437e-02 -3.4233e-02 ... ⋱ ... 1.0409e-01 6.9673e-02 6.2664e-02 ... -3.2450e-02 -7.9281e-02 1.1497e-01-2.8081e-02 -1.2337e-01 6.9056e-02 ... -1.0816e-01 -8.9076e-02 5.8901e-02 6.1354e-02 -2.9104e-02 -5.5389e-02 ... 
-3.9486e-02 -2.9318e-02 1.1121e-01[torch.FloatTensor of size 256x64]PPP -0.0661 0.0039 0.0343-0.0428-0.0931 0.0150 0.0667-0.0503 0.1009 0.0786 0.0435-0.0952 0.0759-0.0155-0.0651-0.0916 0.1066 0.0204-0.0731 0.1241 0.0861-0.0129-0.0326-0.0626-0.1194 0.0683-0.0699-0.0822 0.0856-0.0142-0.0683-0.1223-0.0443-0.1215 0.0422 0.0083 0.0220-0.1037 0.0534 0.0914-0.0479-0.0273 0.0670-0.0777 0.0030 0.0343-0.1053-0.0880-0.0184 0.0800-0.0517-0.0596-0.0919 0.0129 0.0592 0.0903 0.0144-0.0522-0.0801-0.0489 0.0093-0.0173-0.0433-0.0887 0.1231-0.0524-0.0295-0.0432-0.0109-0.0625 0.0006-0.0658 0.0526-0.0297 0.0765-0.0805 0.0268-0.0250-0.0652-0.1201-0.1215-0.0732 0.0856-0.0101-0.1052-0.0456-0.0750-0.1149 0.0586 0.0594 0.1186 0.0742 0.0826 0.0612 0.0535 0.0827 0.1247-0.0917 0.0162 0.0731-0.0980-0.0508 0.1217-0.0242 0.0939 0.0172 0.1151 0.0706-0.1080-0.1144-0.0062 0.1227 0.0040 0.0451 0.0370 0.0963-0.0548 0.0073 0.0590-0.0860 0.0873 0.0123 0.0907-0.0206 0.0959 0.1026 0.0361 0.0632-0.0422 0.0934-0.1055-0.1022 0.0365-0.0169-0.0298 0.0096 0.0932-0.0130-0.0151 0.0693-0.0823-0.0176 0.0714-0.0319 0.0251 0.0878-0.0841-0.0804 0.0915 0.0282 0.0470-0.0592-0.0913-0.1234 0.0315 0.0182-0.0110 0.0275-0.0983 0.0250-0.0442-0.0113-0.0569 0.0902 0.0690 0.0543-0.0904 0.0373 0.0728-0.1175-0.0886-0.0702-0.0567-0.0740 0.1204-0.0247-0.0659 0.0075 0.0327 0.0215 0.0539-0.1142-0.0042 0.0156-0.1102 0.0036 0.0363-0.0509-0.0219-0.0764 0.1240-0.0074 0.0395 0.0058-0.0012 0.0614 0.0985 0.0915-0.0060-0.0268 0.1034 0.1116 0.0221 0.1064-0.0271 0.0554 0.0099-0.0627-0.0422 0.0102-0.0310 0.0050-0.0806 0.1235-0.0786-0.1168-0.1148 0.0717-0.1048 0.0509 0.0219 0.0902-0.0821-0.0005 0.0549-0.0563-0.0460-0.0904-0.0209 0.0030-0.1225-0.1071-0.0584-0.0711-0.0749-0.1088-0.0597-0.0829 0.0858-0.0987-0.0564-0.0063 0.0432-0.1095-0.0563 0.0691-0.0815-0.0858 0.1200 0.0459 0.0008 0.0818-0.0996-0.0737-0.0613-0.0190[torch.FloatTensor of size 256]PPP 0.0130-0.0655 0.0321-0.0441 0.0407 0.0434-0.0885 0.1136-0.0390 0.0391-0.0185 0.1143 0.0910 0.0787 
0.1237 0.0194 0.1165 0.0155-0.0504 0.0776-0.0269 0.0218-0.0945-0.0426 0.0947-0.0057 0.1128 0.0760-0.0732-0.0685-0.0252 0.0184 0.0505 0.0759 0.0615-0.0737 0.0955-0.0121-0.0377-0.0322-0.1096 0.0560-0.0542 0.0561 0.0817-0.1046-0.1038 0.0840 0.0799-0.0957-0.0016 0.0730 0.0618 0.0825 0.0690-0.0078-0.1246 0.0268-0.0774 0.0724-0.0090 0.0527 0.0685 0.0065 0.1016 0.0774-0.0896-0.1083-0.0638 0.0117 0.0420-0.0266-0.1220 0.0789 0.1214-0.1015-0.0909-0.0033 0.0222 0.0632-0.0497 0.1060-0.0510-0.0921 0.0712 0.0647 0.0967 0.0060-0.0525 0.1039 0.0658-0.0608 0.0169 0.0928-0.0088-0.0515 0.1121 0.0269-0.0597 0.0628-0.0472-0.1149 0.0278-0.0011-0.1209-0.0417-0.0575-0.1082-0.0024-0.0415 0.0768-0.0113-0.0656-0.1064 0.0836-0.0422 0.0870-0.1213-0.1221-0.0013-0.0250 0.0287 0.0259 0.1054-0.0570 0.0618-0.0923-0.0611 0.0055 0.0844 0.0405 0.1082-0.0302-0.1106-0.0838 0.0420 0.0394 0.1039 0.0928-0.1081 0.1234-0.0382-0.0146 0.0087-0.1011-0.0149 0.0597 0.0590-0.0194-0.0813-0.0690 0.0264-0.1082-0.0783 0.0951 0.1159-0.0691 0.0259-0.0214 0.1139-0.0472 0.0963 0.0718 0.1083-0.1242 0.0716-0.0109 0.0272 0.1071-0.1237 0.0692-0.0022 0.0654 0.1097 0.0385 0.0353-0.0804 0.0428 0.0702-0.1195 0.0169-0.0206 0.1065 0.0441 0.0651-0.0746 0.0194-0.0477 0.0950-0.0569-0.0991 0.0898-0.0652 0.0683 0.1220-0.0222-0.0751 0.0174 0.0994 0.0596-0.1138 0.0801-0.0527 0.0947 0.0996 0.0951-0.0851-0.0969-0.0364-0.0450-0.0039 0.0870-0.1237-0.1074 0.0992 0.0800-0.0711 0.0041 0.0270-0.0486-0.0652-0.0523-0.0862-0.0883-0.1182-0.0350-0.1132 0.0665-0.0439 0.0392 0.0400 0.0344-0.1176-0.0682-0.1236 0.0208-0.1139 0.0633-0.1106 0.0126 0.0185-0.0219 0.1117 0.0977 0.0860 0.0608 0.0103 0.0771-0.0751 0.0909 0.0020-0.0930 0.0830-0.0403-0.0516 0.0852[torch.FloatTensor of size 256]PPP Columns 0 to 9 0.0991 0.1218 -0.0816 0.0220 0.1029 0.0342 -0.0448 -0.0178 -0.0067 0.0853 0.1030 -0.0817 0.0258 0.0233 0.0885 -0.1076 0.0526 0.0402 0.0480 -0.1025 0.0224 -0.1067 0.0508 -0.0831 -0.0963 0.1152 -0.0994 -0.0305 -0.1041 -0.0282-0.0365 -0.0857 -0.0107 0.0929 
-0.0940 -0.0774 -0.0135 -0.0096 0.1087 0.1086 0.0340 -0.0464 -0.1135 0.0084 -0.0820 -0.0957 0.0070 0.0113 0.0882 0.1237 0.0658 -0.1047 -0.1228 -0.0985 0.0482 0.1177 -0.0759 -0.0205 0.0492 -0.0698-0.0384 0.0334 0.0953 0.1019 -0.1207 -0.0936 -0.0745 -0.0863 0.0533 0.0637-0.0595 0.0473 -0.0147 0.0062 -0.0191 -0.1011 -0.0289 -0.0175 -0.0966 -0.0236 0.0033 0.0701 0.0546 0.0245 -0.0388 -0.0780 0.1232 0.0122 -0.0397 -0.0912-0.1052 -0.0875 -0.0197 0.0015 0.1021 -0.0661 -0.0445 0.0846 -0.0606 -0.0982Columns 10 to 19 0.1033 -0.0640 0.0401 0.0702 -0.0747 -0.0222 -0.0202 -0.1072 0.0767 0.0377 0.0887 0.1194 0.1097 0.0148 -0.0138 0.0688 0.0077 0.1012 0.0860 0.0938-0.0802 -0.0107 0.1062 -0.0412 -0.0003 -0.0302 0.0076 -0.0905 0.0395 0.0955-0.0888 -0.1035 0.0805 0.0047 -0.0107 0.1076 0.0193 -0.0615 -0.0366 0.0952-0.0148 0.1075 -0.0537 -0.0461 -0.0562 0.0190 -0.1205 -0.0974 -0.1083 -0.0353-0.0527 0.1049 -0.0480 0.0007 0.0755 -0.0399 0.0567 0.0688 0.0719 -0.0474 0.0052 -0.0320 0.0903 -0.0895 0.0861 -0.1100 -0.0788 -0.0094 -0.0595 0.0111 0.0535 -0.0790 -0.0736 -0.0512 0.0414 0.0372 -0.0638 -0.1041 -0.0484 -0.0755 0.1205 -0.0672 0.1016 0.0827 0.0972 -0.0551 -0.0410 -0.0551 -0.1206 -0.0395-0.0214 0.0026 -0.0185 0.0001 0.0064 0.0982 0.0946 0.0116 -0.0024 -0.1074Columns 20 to 29 0.0014 -0.0417 0.0009 0.0854 0.0269 -0.0232 0.0012 0.0069 0.1210 -0.0919-0.0958 -0.1185 -0.1184 0.0191 0.0536 -0.0257 0.0315 -0.0092 0.1055 -0.1166 0.0894 -0.0709 0.0922 -0.0424 0.0420 -0.0950 -0.0118 -0.0910 -0.1123 0.0984-0.0553 0.0978 0.0158 -0.0619 0.0885 -0.0976 0.1039 -0.0054 -0.0926 0.0064 0.1147 -0.0009 -0.0362 -0.0879 -0.0277 -0.1015 -0.1144 -0.0243 -0.1179 0.0933-0.0904 -0.1183 0.0636 -0.0606 0.0001 -0.0374 -0.0823 -0.0881 -0.0811 -0.0672 0.0241 -0.0959 0.0423 -0.0978 -0.0285 0.0123 0.0488 0.0487 0.0176 0.0173 0.1008 0.0326 -0.0710 -0.1112 -0.0287 -0.0300 -0.0440 -0.0343 -0.0450 -0.1118 0.1113 -0.0555 0.0969 -0.0204 -0.0316 -0.0028 -0.0019 0.0290 -0.0231 0.0070-0.0039 -0.0672 -0.0438 0.0368 0.0553 
-0.0499 0.0267 -0.0649 0.0019 0.0879Columns 30 to 39 0.1117 -0.0552 0.0605 0.0743 0.0197 -0.0904 0.0005 0.0353 -0.0751 -0.0130 0.0750 -0.1095 0.0277 0.1156 0.0949 -0.0796 0.1044 0.0500 0.1119 0.0033-0.1121 0.0314 0.0501 0.0035 -0.1149 0.0623 0.0100 -0.0163 0.1058 0.0865 0.0800 -0.0530 -0.0353 0.0779 0.1238 -0.0200 -0.0272 0.0986 0.0196 -0.0383-0.0122 -0.1203 0.0466 -0.0569 -0.1043 -0.0704 0.1004 0.0055 0.0543 -0.0131-0.0977 -0.0751 0.0328 0.0662 -0.0501 0.1024 0.1224 -0.0401 0.0107 0.0433 0.0638 -0.1180 -0.0250 -0.1239 0.0566 0.0193 -0.0407 -0.0628 0.0466 -0.0568 0.0265 -0.1144 -0.0753 0.1054 -0.0994 0.1162 0.0292 0.0838 -0.0420 -0.0506-0.0177 0.0262 -0.0189 -0.0819 -0.0847 -0.0090 -0.0930 0.1133 0.0611 -0.0546 0.0987 -0.0040 -0.0567 -0.0284 0.0951 -0.0739 0.0193 -0.0317 -0.0896 0.0663Columns 40 to 49 0.0285 0.0341 0.1245 -0.0614 -0.0078 -0.0584 -0.0105 0.0094 0.0422 -0.0227 0.0398 0.1004 -0.0884 0.0318 -0.0911 -0.1213 -0.0907 -0.0738 -0.0523 -0.0317-0.1230 0.0846 -0.0740 -0.0878 0.0250 0.0375 -0.0831 0.1182 -0.0754 -0.0871-0.0256 0.0675 -0.0249 0.0952 -0.1188 -0.0273 0.0934 0.1209 0.0765 0.0063 0.0708 0.0393 0.0189 0.0350 -0.0329 0.1113 0.0110 -0.0083 -0.1152 -0.0735 0.0585 0.0925 0.0616 0.0478 0.0957 0.1038 0.0545 -0.0227 -0.1126 0.0958 0.1080 -0.1215 0.0274 0.0803 -0.1214 0.0364 0.0985 -0.0505 0.0941 -0.0675-0.0153 0.1246 -0.0902 0.0092 0.1193 -0.1020 -0.0869 0.0396 0.1078 0.0155 0.1243 0.0651 -0.0685 -0.0275 -0.0058 0.0416 -0.0851 0.0398 0.0317 -0.0656-0.0128 0.0311 -0.0837 -0.0885 -0.0965 0.0931 -0.0942 -0.0342 0.0851 0.0435Columns 50 to 59 -0.0706 0.0740 0.0403 0.0486 0.0804 0.1016 0.0948 0.0042 -0.0204 -0.1151 0.1095 0.0921 -0.1028 0.0282 0.0878 0.0996 0.1205 -0.0796 -0.0634 -0.1172 0.1047 -0.0863 0.0562 0.0295 0.0177 -0.0250 0.0261 0.1133 0.0844 0.0866-0.0407 0.0486 -0.1202 -0.1043 0.0989 0.0932 0.0133 0.0651 -0.1158 -0.0456-0.1219 0.0920 0.0697 0.0927 0.1020 0.0391 0.0309 0.0199 0.0844 0.0428-0.0501 0.0589 0.0111 -0.0826 0.0056 -0.0369 -0.0911 0.1175 
-0.0292 0.0318 0.0445 0.1137 0.1123 -0.0716 0.0885 -0.0383 0.0276 0.0571 0.0976 0.0298-0.1082 -0.1132 -0.0977 -0.0630 0.1066 0.0418 0.0862 -0.0329 -0.0949 -0.1048 0.0947 0.0587 -0.0304 0.0770 -0.0187 0.0003 -0.0628 -0.1068 0.1023 0.0669-0.0424 -0.0686 -0.0745 -0.0949 -0.0700 0.1227 -0.0021 -0.1125 -0.1001 0.0545Columns 60 to 63 0.0592 -0.0805 -0.0735 -0.0953 0.0493 -0.0285 0.0179 0.0019 0.0548 0.0819 -0.1057 0.0855 0.0880 -0.0224 0.0091 0.0845 0.0501 -0.0397 -0.0922 0.1050 0.0109 -0.1045 0.0098 -0.0755 0.1079 0.0461 0.0320 -0.0830 0.0902 0.0743 -0.0809 -0.0330-0.0153 0.0420 0.0624 -0.1119-0.0138 -0.0618 0.1001 0.0437[torch.FloatTensor of size 10x64]PPP 0.0109-0.0778-0.0501 0.0163 0.0763-0.0792 0.1141-0.0127 0.0162 0.0808[torch.FloatTensor of size 10]None
关于pytorch中LSTM的说明可以在这里查看。
我打印出LSTM的参数,并将它们结合pytorch的官方文档,发现其实LSTM的这些参数都是Variable。注意到这个例子里的w和b也不只有一组,而是有两组(每一组包含输入到隐层、隐层到隐层的两个权重矩阵以及对应的两个偏置),因为LSTM的num_layers=2,当这个值为3时就会有3组。由这里我受到启发,在改变sru的层数(depth)后,参数的对数也随之发生了变化。由此我得出结论:循环神经网络的参数虽然在时间步上是共享的,但网络并不是只有一个循环单元(层),而是可以堆叠多层,每一层都有自己的一组参数,之前我一直以为只有一个。
而sru中的参数也是以Variable的形式存在于整个模型中,可以被更新。