pytorch | Implementing a multi-layer LSTM from scratch in PyTorch

[pytorch | Implementing a multi-layer LSTM from scratch in PyTorch] The code is as follows:

# Custom multi-layer LSTM implementation
import math

import torch
import torch.nn as nn


class NaiveCustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.param_names = []

        for layer in range(self.num_layers):
            # Layer 0 sees the raw input; deeper layers see the previous layer's hidden states.
            layer_input_size = self.input_size if layer == 0 else self.hidden_size

            # Input gate i_t
            W_i = nn.Parameter(torch.Tensor(layer_input_size, self.hidden_size))
            U_i = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
            b_i = nn.Parameter(torch.Tensor(self.hidden_size))
            # Forget gate f_t
            W_f = nn.Parameter(torch.Tensor(layer_input_size, self.hidden_size))
            U_f = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
            b_f = nn.Parameter(torch.Tensor(self.hidden_size))
            # Cell candidate g_t
            W_c = nn.Parameter(torch.Tensor(layer_input_size, self.hidden_size))
            U_c = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
            b_c = nn.Parameter(torch.Tensor(self.hidden_size))
            # Output gate o_t
            W_o = nn.Parameter(torch.Tensor(layer_input_size, self.hidden_size))
            U_o = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
            b_o = nn.Parameter(torch.Tensor(self.hidden_size))

            layer_params = (W_i, U_i, W_f, U_f, W_c, U_c, W_o, U_o, b_i, b_f, b_c, b_o)
            suffix = ''
            param_name = ['weight_W_i{}{}', 'weight_U_i{}{}',
                          'weight_W_f{}{}', 'weight_U_f{}{}',
                          'weight_W_c{}{}', 'weight_U_c{}{}',
                          'weight_W_o{}{}', 'weight_U_o{}{}',
                          'bias_b_i{}{}', 'bias_b_f{}{}',
                          'bias_b_c{}{}', 'bias_b_o{}{}']
            param_name = [x.format(layer, suffix) for x in param_name]

            # Register every tensor as a module attribute so nn.Module tracks it as a parameter.
            for name, param in zip(param_name, layer_params):
                setattr(self, name, param)
            self.param_names.append(param_name)

        self.init_weights()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        """
        Assumes x.shape is (batch_size, sequence_size, input_size).
        Returns the hidden sequence of the last layer and that layer's final (h_t, c_t).
        """
        bs, seq_sz, _ = x.size()
        hidden_seqs = None

        for layer in range(self.num_layers):
            # Each layer starts from its own initial state (zeros unless init_states is given).
            if init_states is None:
                h_t = torch.zeros(bs, self.hidden_size).to(x.device)
                c_t = torch.zeros(bs, self.hidden_size).to(x.device)
            else:
                h_t, c_t = init_states

            # Look up this layer's registered parameters by name.
            param_name = self.param_names[layer]
            W_i, U_i = getattr(self, param_name[0]), getattr(self, param_name[1])
            W_f, U_f = getattr(self, param_name[2]), getattr(self, param_name[3])
            W_c, U_c = getattr(self, param_name[4]), getattr(self, param_name[5])
            W_o, U_o = getattr(self, param_name[6]), getattr(self, param_name[7])
            b_i, b_f = getattr(self, param_name[8]), getattr(self, param_name[9])
            b_c, b_o = getattr(self, param_name[10]), getattr(self, param_name[11])

            hidden_seq = []
            for t in range(seq_sz):
                x_t = x[:, t, :]
                i_t = torch.sigmoid(x_t @ W_i + h_t @ U_i + b_i)
                f_t = torch.sigmoid(x_t @ W_f + h_t @ U_f + b_f)
                g_t = torch.tanh(x_t @ W_c + h_t @ U_c + b_c)
                o_t = torch.sigmoid(x_t @ W_o + h_t @ U_o + b_o)
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                hidden_seq.append(h_t.unsqueeze(1))

            # Stack the per-timestep outputs back into (batch, seq, hidden)
            # and feed them to the next layer as its input.
            hidden_seqs = torch.cat(hidden_seq, dim=1)
            x = hidden_seqs

        return hidden_seqs, (h_t, c_t)
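A minimal usage sketch (the batch, sequence, and feature sizes below are illustrative, not from the original post): feed a random batch through the custom LSTM and compare the output shapes with the built-in nn.LSTM in batch_first mode.

import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size = 4, 10, 8, 16

lstm = NaiveCustomLSTM(input_size, hidden_size, num_layers=2)
x = torch.randn(batch_size, seq_len, input_size)

hidden_seq, (h_t, c_t) = lstm(x)
print(hidden_seq.shape)        # torch.Size([4, 10, 16]) - last layer's output at every timestep
print(h_t.shape, c_t.shape)    # torch.Size([4, 16]) each - last layer's final hidden/cell state

# The built-in layer with batch_first=True returns the same output shape.
ref = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
out, _ = ref(x)
print(out.shape)               # torch.Size([4, 10, 16])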
