代码更多细节待更新。
目标: 掌握最小二乘法求解(无惩罚项的损失函数)、掌握加惩罚项(2范数)的损失函数优化、梯度下降法、共轭梯度法、理解过拟合、克服过拟合的方法(如加惩罚项、增加样本)
【数据结构与算法|哈工大《机器学习》最小二乘法曲线拟合——实验一】已完成的要求:
- 生成数据,加入噪声;
- 用高阶多项式函数拟合曲线;
- 用解析解求解两种loss的最优解(无正则项和有正则项)
4. 优化方法求解最优解(梯度下降,共轭梯度);
5. 用你得到的实验数据,解释过拟合。
6. 用不同数据量,不同超参数,不同的多项式阶数,比较实验效果。
7. 语言不限,可以用matlab,python。求解解析解时可以利用现成的矩阵求逆。梯度下降,共轭梯度要求自己求梯度,迭代优化自己写。不许用现成的平台,例如pytorch,tensorflow的自动微分工具。
这里的4,5功能等到9.30ddl过后放在最后的下载网址里,打包下载。
需要积分2分。
此文不会再更新
%author:hitwtj
clear all;
n = 2;
%取样频率
T = 2*pi;
%周期
step = (T / n)*0.1;
%采样步长
t = (0 : step : 2*T);
%取样t的函数值采样频率
y = sin(pi/2*t);
%取样y的值,产生频率 5Hz 的 sin 函数
figure(1);
plot(t,y,'b');
z1=0.3*randn(1,41);
%产生方差 N(0,0.12)高斯白噪声 (b=0.01/0.1/1) plot(x,z1,'b');
y2=y+z1;
%叠加高斯白噪声的正弦波
figure(2);
plot(t,y2,'ro');
%展示加过噪声后的图像。
figure(3);
%其中t时自变量值,y2是生成的采集点。[~,k]=size(t);
for n=1:9
X0=zeros(n+1,k);
for k0=1:k%构造矩阵X0
for n0=1:n+1
X0(n0,k0)=t(k0)^(n+1-n0);
end
end
X=X0';
ANSS=(X'*X)\X'*y2';
%自动解方程求逆了……
for i=1:n+1%answer矩阵存储每次求得的方程系数,按列存储
answer(i,n)=ANSS(i);
end
x0=0 : step : 2*T;
y0=ANSS(1)*x0.^n;
%根据求得的系数初始化并构造多项式方程
for num=2:1:n+1
y0=y0+ANSS(num)*x0.^(n+1-num);
end
subplot(3,3,n)
plot(t,y2,'*')
hold on
plot(x0,y0)
end
效果:
文章图片
文章图片
最小二乘法2.0版本 加入惩罚项后的代码:
%author:hitwtj
clear all;
n = 1;
%取样频率
T = 2*pi;
%周期
step = (T / n)*0.1;
%采样步长
t = (0 : step : 2*T);
%取样t的函数值采样频率
y = sin(pi/2*t);
%取样y的值,产生频率 5Hz 的 sin 函数
figure(1);
plot(t,y,'b');
z1=0.35*randn(1,21);
%产生方差 N(0,0.12)高斯白噪声 (b=0.01/0.1/1) plot(x,z1,'b');
y2=y+z1;
%叠加高斯白噪声的正弦波
figure(2);
plot(t,y2,'ro');
%展示加过噪声后的图像。
figure(3);
%其中t时自变量值,y2是生成的采集点。
%取测试样本%这里的机器学习结果存在x0,y0里面
%加过噪声的存在t,y2里
%原来的结果存在t,y里也就是测试样本
[~,k]=size(t);
w=zeros(1,50);
for n=1:9
X0=zeros(n+1,k);
for k0=1:k%构造矩阵X0
for n0=1:n+1
X0(n0,k0)=t(k0)^(n+1-n0);
end
end
X=X0';
ANSS=(X'*X)\X'*y2';
%自动解方程求逆了……
for i=1:n+1%answer矩阵存储每次求得的方程系数,按列存储
answer(i,n)=ANSS(i);
end
x0=0 : step : 2*T;
y0=ANSS(1)*x0.^n;
%根据求得的系数初始化并构造多项式方程
for num=2:1:n+1
y0=y0+ANSS(num)*x0.^(n+1-num);
end
subplot(3,3,n)
plot(t,y2,'*')
hold on
plot(x0,y0)
plot(t,y,'b');
%下面计算误差,存在w里
W0=1/2.*(y-y0).^2;
w(1,n)=sum(W0);
w1(1,n)=2*w(1,n)/21;
endfigure(4);
plot(x0,y0)
hold on
plot(t,y2,'*')
plot(t,y,'r');
%加入正则项
figure(5);
[~,k]=size(t);
wz=zeros(1,50);
p=2;
%λ参数
for n=1:9
X0z=zeros(n+1,k);
for k0=1:k%构造矩阵X0
for n0=1:n+1
X0z(n0,k0)=t(k0)^(n+1-n0);
end
end
Xz=X0z';
ANSSZ=(Xz'*Xz+p.*eye(n+1))\Xz'*y2';
%自动解方程求逆了……
for i=1:n+1%answer矩阵存储每次求得的方程系数,按列存储
answerz(i,n)=ANSSZ(i);
end
x0z=0 : step : 2*T;
y0z=ANSSZ(1)*x0z.^n;
%根据求得的系数初始化并构造多项式方程
for num=2:1:n+1
y0z=y0z+ANSSZ(num)*x0z.^(n+1-num);
end
subplot(3,3,n)
plot(t,y2,'*')
hold on
plot(x0z,y0z)
plot(t,y,'b');
%下面计算误差,存在w里
W0z=1/2.*(y-y0z).^2;
wz(1,n)=sum(W0z);
w1z(1,n)=2*wz(1,n)/21;
end
figure(4);
plot(x0z,y0z,'g');
红色线是原函数取样点的连线。
这里蓝色的线,无惩罚项;
绿色的,加了lambda为2的惩罚项,可以看到,
因为惩罚项加的太大,导致了退化。
这里需要调参,lambda需要修改,取0.01-0.5左右比较理想。
文章图片
可以清楚的看到学习函数的方均根值变大了。
文章图片
文章图片
方均根误差的计算方法:
文章图片
转载于:https://www.cnblogs.com/hitWTJ/p/9865414.html
推荐阅读
- 科技|更深的技术探索,更多的场景实践M etaCon元宇宙技术大会圆满召开
- 队列|数字人技术在直播场景下的应用
- 大数据|一文读懂元宇宙,AI、灵境计算...核心技术到人文生态
- 编程语言|Node 之父斥责 Oracle(你们也不用,那请交出 JavaScript 商标!)
- 分布式|随笔(分布式锁的一点思想)
- 错误|记录(There is no getter for property named ‘null‘ in ‘class)
- SQL|MVCC在重复读和读已提交场景以及幻读的解决
- 汽车|长城新能源汽车,战力已蓄满
- #|「SpringCloud」08 Config分布式配置中心