本文概述
- LSTM层
- 损失函数, 优化器和准确性
- 建立图表和训练
- 测试
例如, LSTM是诸如未分段的, 连接的手写识别或语音识别之类的任务的应用程序。
一般的LSTM单元由一个单元, 一个输入门, 一个输出门和一个忘记门组成。单元会记住任意时间间隔内的值, 并且三个门控制着进出单元的信息流。 LSTM非常适合对未知持续时间给出的时间序列进行分类, 处理和预测。
长短期记忆(LSTM)网络是递归神经网络的修改版本, 可以更轻松地记住记忆中的过去数据。
文章图片
2.忘记门-从区块中发现要丢弃的细节。乙状结肠功能决定。它查看先前的状态(ht-1)和内容输入(Xt), 并为单元格状态Ct-1中的每个数字输出介于0(忽略此)和1(保留此)之间的数字。
文章图片
3.输出门-块的输入和存储器用于确定输出。 Sigmoid函数确定要让0或1允许的值。tanh函数决定要让0或1允许的值。tanh函数对所传递的值进行加权, 确定其重要性级别, 范围从-1到1并乘以输出为S形。
文章图片
文章图片
它代表一个完整的RNN单元, 该单元采用序列xi的当前输入, 并输出当前隐藏状态, 嗨, 将其传递给我们输入序列的下一个RNN单元。 LSTM单元的内部要比传统RNN单元复杂得多, 而常规RNN单元只有一个作用于当前状态(ht-1)和输入(xt)的” 内部层” 。
文章图片
在上图中, 我们看到了一个” 未展开的” LSTM网络, 该网络具有一个嵌入层, 一个后续的LSTM层和一个S型激活函数。我们认识到我们的输入(在这种情况下, 即电影评论中的单词)是按顺序输入的。
单词被输入到嵌入查找中。在大多数情况下, 使用文本数据集时, 词汇量异常大。
这是向量空间中单词的多维分布表示。可以使用其他深度学习技术(例如word2vec)来学习这些嵌入, 我们可以以端到端的方式训练模型, 以根据我们的教学确定嵌入。
然后将这些嵌入内容输入到我们的LSTM层中, 在该层中, 输出将被馈送到S型输出层和LSTM单元中, 用于序列中的下一个单词。
LSTM层 我们将建立一个函数来构建LSTM层, 以动态处理层数和大小。该服务将获取一个LSTM大小列表, 该列表可根据列表的长度指示LSTM层的数量(例如, 我们的示例将使用长度为2的列表, 其中包含大小为128和64的LSTM列表, 表示两层LSTM网络其中第一层尺寸128和第二层具有隐藏层尺寸64)。
def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):
"""
Create the LSTM layers
"""
lstms = [tf.contrib.rnn.BasicLSTMCell(size) for size in lstm_sizes]
# Add dropout to the cell
drops = [tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob_) for lstm in lstms]
# Stacking up multiple LSTM layers, for deep learning
cell = tf.contrib.rnn.MultiRNNCell(drops)
# Getting an initial state of all zeros
initial_state = cell.zero_state(batch_size, tf.float32)
lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed, initial_state=initial_state)
然后, 将退出包装的LSTM列表传递到TensorFlow MultiRNN单元以将各层堆叠在一起。
损失函数, 优化器和准确性 最后, 我们创建函数来定义模型损失函数, 优化器和准确性。即使只是根据结果计算损失和准确性, 但在TensorFlow中, 所有内容都是计算图的一部分。
def build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate):
"""
Creating the Loss function and Optimizer
"""
predictions = tf.contrib.layers.fully_connected(lstm_outputs[:, -1], 1, activation_fn=tf.sigmoid)
loss = tf.losses.mean_squared_error(labels_, predictions)
optimzer = tf.train.AdadeltaOptimizer (learning_rate).minimize(loss)
def build_accuracy(predictions, labels_):
"""
Create accuracy
"""
correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
建立图表和训练 首先, 我们调用已定义的用于构建网络的每个函数, 并调用TensorFlow会话以使用迷你批处理在预定义数量的纪元上训练模型。在每个阶段结束时, 我们将打印损失, 训练准确性和验证准确性, 以在训练模型时监视结果。
def build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size, learning_rate, keep_prob, train_x, val_x, train_y, val_y):# Build Graphwith tf.Session() as sess:# Train Network
# Save Network
接下来, 我们定义模型超参数, 然后我们将构建一个两层LSTM网络, 其隐藏层大小分别为128和64。
完成模型训练后, 我们使用TensorFlow保护程序保存模型参数以供以后使用。
Epoch: 1/50 Batch: 303/303 Train Loss: 0.247 Train Accuracy: 0.562 Val Accuracy: 0.578
Epoch: 2/50 Batch: 303/303 Train Loss: 0.245 Train Accuracy: 0.583 Val Accuracy: 0.596
Epoch: 3/50 Batch: 303/303 Train Loss: 0.247 Train Accuracy: 0.597 Val Accuracy: 0.617
Epoch: 4/50 Batch: 303/303 Train Loss: 0.240 Train Accuracy: 0.610 Val Accuracy: 0.627
Epoch: 5/50 Batch: 303/303 Train Loss: 0.238 Train Accuracy: 0.620 Val Accuracy: 0.632
Epoch: 6/50 Batch: 303/303 Train Loss: 0.234 Train Accuracy: 0.632 Val Accuracy: 0.642
Epoch: 7/50 Batch: 303/303 Train Loss: 0.230 Train Accuracy: 0.636 Val Accuracy: 0.648
Epoch: 8/50 Batch: 303/303 Train Loss: 0.227 Train Accuracy: 0.641 Val Accuracy: 0.653
Epoch: 9/50 Batch: 303/303 Train Loss: 0.223 Train Accuracy: 0.646 Val Accuracy: 0.656
Epoch: 10/50 Batch: 303/303 Train Loss: 0.221 Train Accuracy: 0.652 Val Accuracy: 0.659
测试 最后, 我们在测试集中检查模型结果, 以确保它们与我们在训练中观察到的结果一致。
def test_network(model_dir, batch_size, test_x, test_y):# Build Networkwith tf.Session() as sess:# Restore Model
# Test Model
【Tensorflow中的长短期记忆(LSTM)RNN介绍和使用】测试精度为72%。这完全符合我们的验证准确性, 并表明我们在数据拆分期间以适当的数据分布捕获了数据。
INFO:tensorflow:Restoring parameters from checkpoints/sentiment.ckpt
Test Accuracy: 0.717
推荐阅读
- TensorFlow Gram矩阵原理介绍和用法示例
- TensorFlow中的样式传输解释和实例
- 加快android studio 编译速度
- 6.简单提取小红书app数据保存txt-2
- 品质创新,江铃控股携手华天软件CAPP系统决战SUV中高端市场
- Appium之编写H5应用测试脚本(切换到Webview)
- 解决android中EditText导致的内存泄漏问题
- BAT大厂APP架构演进实践与优化之路
- Android中的数据结构