回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出

回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
目录

    • 回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
      • 基本介绍
      • 模型背景
        • LSTM模型
        • Attention-LSTM 模型
      • 数据下载
      • 程序设计
      • 参考资料
      • 致谢

基本介绍
本次运行测试环境MATLAB2020b;
文章针对LSTM 存在的局限性,提出了将Attention机制结合LSTM 神经网络的预测模型。采用多输入单输出回归预测,再将attention 机制与LSTM 结合作为预测模型,使预测模型增强了对关键时间序列的注意力。
模型背景
  • 由于LSTM 神经网络具有保存历史信息的功能,在处理长时间序列输入时相较于传统神经网络更为有效,于近几年取得了广泛的应用。长短时记忆神经网络最早是由Hochreite 和Schmidhuber 提出。
  • LSTM 神经网络克服了传统循环神经网络中存在的难以解决长期依赖性以及梯度消失和爆炸的问题。
  • 然而,采用传统的编码-解码器的LSTM模型在对输入序列学习时,模型会先将所有的输入序列编码成一个固定长度的向量,而解码过程则受限于该向量的表示,这也限制了LSTM 模型的性能。
  • 文章针对LSTM 存在的局限性,提出了将Attention机制结合LSTM 神经网络的预测模型,将attention 机制与LSTM 结合作为预测模型,使预测模型增强了对关键时间序列的注意力。
LSTM模型
  • 传统循环神经网络( RNN) 对短时间序列输入比较敏感,处理短时间序列的表现较好。但是当存在长期输入时,传统RNN 会出现在某一时刻之前的所有隐藏层状态在训练中都不会影响到权重数组W 的更新的情况。这就是所谓的梯度消失问题。
  • 由于传统RNN 神经网络在实际的应用中,存在着梯度爆炸或梯度消失的问题,因此传统RNN 不适合于解决长序列问题。为了解决RNN 存在的缺陷而出现的LSTM 神经网络近年来在语音识别、语言翻译以及图像处理等方面取得了广泛的应用,相较于传统循环神经网络( RNN) 有着其特别的优势。长短时记忆网络在原有RNN 的基础上,在隐藏层中额外加入了一个可以保存长期状态的单元C。LSTM单元内部结构如图1 所示。
    回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
    文章图片
  • 与前馈神经网络类似,LSTM 网络的训练同样采用的是误差的反向传播算法( Back-propagation) ,因为LSTM 处理的是序列数据,所以在使用误差反向传播算法的时候需要将整个时间序列上的误差传播回来。当前LSTM 单元的状态会受到前一时刻LSTM 单元状态的影响。
  • 同时在误差反向传播计算时隐含层ht的误差不仅仅包含当前时刻t 的误差,也包括t 时刻之后所有时刻的误差,这就是误差基于时间反向传播算法的含义。
Attention-LSTM 模型
  • 传统的编码- 解码器( Encoder-Decoder) 模型在处理输入序列时,编码器Encoder 将输入序列Xt编码成固定长度的隐向量h,对隐向量赋予相同的权重。
  • 解码器Decoder 基于隐向量h 解码输出。当输入序列的长度增加时,分量的权重相同,模型对于输入序列Xt没有区分度,造成模型性能下降。
  • Attention 机制解决了此问题,Attention 是一种用于提升编码-解码模型效果的机制,其本质是模仿人在观察东西时大脑的思维活动。当某个场景经常在其中一部分出现重要的东西时,人脑就会进行学习,之后看到类似场景时注意力就会集中到该部分上。
  • 使模型对输入序列的不同时刻隐向量h 赋予了相对应的权重,按重要程度将隐向量合并为新的隐向量并输入到解码器Decoder。加入Attention 机制的Encoder-Decoder 模型如图2 所示。
    回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
    文章图片
  • 标准的LSTM 采用的是传统编码- 解码器结构。输入LSTM 的数据序列无论长短都被编码成固定长度的向量表示。虽然LSTM 的记忆功能可以保存长期状态,但是在实际应用过程中,面对庞大的多维度,多变量数据集时不能很好地加以处理,在训练时模型可能会忽略某些重要的时序信息,造成模型的性能变差,影响预测精度。
  • 针对LSTM 自身存在的缺陷,文章在LSTM 的基础上引入了Attention 机制,目的是为了打破传统编码-解码器在编码过程中使用固定长度向量的限制,保留LSTM 编码器的中间状态,通过训练模型来对这些中间
    状态进行选择性学习。
  • 结合了Attention 机制的LSTM功率预测模型能够判断各输入时刻信息的重要程度,模型的训练效率得以提高。Attention 机制通过对LSTM 的输入特征赋予了不同的权重,突出了关键的影响因素,帮助LSTM 做出准确的判断,而且不会增加模型的计算和存储开销。
    回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
    文章图片
数据下载
  • 下载地址:
  • https://mianbaoduo.com/o/bread/mbd-YZ2clJ5x
程序设计
  • 主程序
%% Attention_LSTM % 数据集,列为特征,行为样本数目 clc clear close all % 导入数据 load('./data.mat') data(1,:) =[]; % 训练集 y = data.demand(1:1000); x = data{1:1000,3:end}; [xnorm,xopt] = mapminmax(x',0,1); [ynorm,yopt] = mapminmax(y',0,1); x = x'; xnorm = xnorm(:,1:1000); ynorm = ynorm(1:1000); % 滞后长度 k = 24; % 转换成2-D image for i = 1:length(ynorm)-k Train_xNorm(:,i,:) = xnorm(:,i:i+k-1); Train_yNorm(i) = ynorm(i+k-1); end Train_yNorm= Train_yNorm'; % 测试集 ytest = data.demand(1001:1170); xtest = data{1001:1170,3:end}; [xtestnorm] = mapminmax('apply', xtest',xopt); [ytestnorm] = mapminmax('apply',ytest',yopt); xtest = xtest'; for i = 1:length(ytestnorm)-k Test_xNorm(:,i,:) = xtestnorm(:,i:i+k-1); Test_yNorm(i) = ytestnorm(i+k-1); Test_y(i) = ytest(i+k-1); end Test_yNorm = Test_yNorm'; clear k i x y % 自定义训练循环的深度学习数组 Train_xNorm = dlarray(Train_xNorm,'CBT'); Train_yNorm = dlarray(Train_yNorm,'BC'); Test_xNorm = dlarray(Test_xNorm,'CBT'); Test_yNorm = dlarray(Test_yNorm,'BC'); % 训练集和验证集划分 TrainSampleLength = length(Train_yNorm); validatasize = floor(TrainSampleLength * 0.1); Validata_xNorm = Train_xNorm(:,end - validatasize:end,:); Validata_yNorm = Train_yNorm(:,end-validatasize:end,:); Train_xNorm = Train_xNorm(:,1:end-validatasize,:); Train_yNorm = Train_yNorm(:,1:end-validatasize,:); %% 参数设定 %数据输入x的特征维度 inputSize = size(Train_xNorm,1); %数据输出y的维度 outputSize = 1; numhidden_units=50; % 导入初始化参数 [params,~] = paramsInit(numhidden_units,inputSize,outputSize); [~,validatastate] = paramsInit(numhidden_units,inputSize,outputSize); [~,TestState] = paramsInit(numhidden_units,inputSize,outputSize); % 训练相关参数 TrainOptions; numIterationsPerEpoch = floor((TrainSampleLength-validatasize)/minibatchsize); LearnRate = 0.01; %% 迭代更新 figure start = tic; lineLossTrain = animatedline('color','b'); validationLoss = animatedline('color','r','Marker','o'); xlabel("Iteration") ylabel("Loss") % epoch 更新 iteration = 0; for epoch = 1 : numEpochs [~,state] = paramsInit(numhidden_units,inputSize,outputSize); % 每轮epoch,state初始化 disp(['Epoch: ', int2str(epoch)]) % batch 更新 for i = 1 : numIterationsPerEpoch iteration = iteration + 1; disp(['Iteration: ', int2str(iteration)]) idx = (i-1)*minibatchsize+1:i*minibatchsize; dlX = gpuArray(Train_xNorm(:,idx,:)); dlY = gpuArray(Train_yNorm(idx)); [gradients,loss,state] = dlfeval(@ModelD,dlX,dlY,params,state); % L2正则化 L2regulationFactor = 0.001; [params,averageGrad,averageSqGrad] = adamupdate(params,gradients,averageGrad,averageSqGrad,iteration,LearnRate); % 验证集测试 if iteration == 1 || mod(iteration,validationFrequency) == 0 output_Ynorm = ModelPredict(gpuArray(Validata_xNorm),params,validatastate); lossValidation = mse(output_Ynorm, gpuArray(Validata_yNorm)); end % 作图(训练过程损失图) D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) if iteration == 1 || mod(iteration,validationFrequency) == 0 addpoints(validationLoss,iteration,double(gather(extractdata(lossValidation)))) end title("Epoch: " + epoch + ", Elapsed: " + string(D)) legend('训练集','验证集') drawnow end % 每轮epoch 更新学习率 if mod(epoch,10) == 0 LearnRate = LearnRate * LearnRateDropFactor; end end

  • 子函数下载地址:
  • https://mianbaoduo.com/o/bread/mbd-YZ2clJ5x
  • 预测效果:
    回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
    文章图片

    回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出
    文章图片
参考资料
【回归预测|回归预测 | MATLAB实现Attention-LSTM(注意力机制长短期记忆神经网络)多输入单输出】[1] https://mianbaoduo.com/o/bread/mbd-YZ2clJ5x
[2] https://blog.csdn.net/kjm13182345320/article/details/120406657?spm=1001.2014.3001.5501
[3] https://blog.csdn.net/kjm13182345320/article/details/120377303?spm=1001.2014.3001.5501
致谢
  • 大家的支持是我写作的动力!
  • 感谢大家订阅,感谢,需要加Q-【1153460737】,记得备注!

    推荐阅读