当前位置: 代码网 > 科技>人工智能>机器学习 > 2024-2-24简易回归代码总结

2024-2-24简易回归代码总结

2024年08月01日 机器学习 我要评论
5. 数据提供函数,传入的参数即真实的x 和y,和一次取的数据量,随机取数据进行分析。2. create_data()创造批量数据,即真实的数据,torch.normal()返回一个随机的正态分布的tensor,接受三个参数,分别为mean均值,std方差,size张量大小,需注意的是此处表示张量大小为(data_num,len(w)),我的理解:这是由众多单笔数据(行向量)拼成的矩阵,其行数就是数据量,行向量中每个分量就是Xi的值,而具体有几个Xi,需要由len(w)告诉我们,有几个权重,就有几个Xi。
import torch
import matplotlib.pyplot as plt  #画图
import random

1. 导入三个库:torch:即pytorch,提供torch.tensor类,和对张量的相关操作,比如计算微分等, matplotlib:matplotlib是个用于绘制图表和数据可视化的库,而pyplot是其最常用的一个子模块,提供函数接口,用于绘制简单图形        random:提供各种处理随即问题的函数

def create_data(w,b,data_num):    #w b都为矩阵
    x=torch.normal(0,1,(data_num,len(w)))    #data_num为数据量,个体数,w为权重,矩阵相乘
    y=torch.matmul(x,w)+b     #matual返回两个矩阵乘积

    noise=torch.normal(0,0.01,y.shape)     #噪声加在y身上
    y+=noise

    return x,y

2. create_data()创造批量数据,即真实的数据,torch.normal()返回一个随机的正态分布的tensor,接受三个参数,分别为mean均值,std方差,size张量大小,需注意的是此处表示张量大小为(data_num,len(w)),我的理解:这是由众多单笔数据(行向量)拼成的矩阵,其行数就是数据量,行向量中每个分量就是xi的值,而具体有几个xi,需要由len(w)告诉我们,有几个权重,就有几个xi。torch.matmul()计算张量乘积,此为标准结果,再生成一个noise加在y身上,注意size部分可以直接用y.size表示

num=500
true_w=torch.tensor([8.1,2,2,4])
true_b=torch.tensor([1.1])
x,y=create_data(true_w,true_b,num)

3. 给出w和b 同时利用wb生成x y的真实值

plt.scatter(x[:,3],y,1)
plt.show() 

4. 画图部分:第一句即将x矩阵的第四列(从0开始计数)和y画在同一张图上,第一个参数为x轴,第二个参数为y轴,散点大小为1,考虑:两个向量是怎么映射成散点图的?:x的每个分量都代表了一个点的横坐标,y的每个分量就代表了每个散点的纵坐标,由此可以体现出x和y的一一对应的关系,至于为什么是x矩阵的一列,则是我们要求每个属性(xi)对于结果y的影响/关系。加上show才会将图像显示出来

def data_provider(data,label,batch_size):        #每次访问就提供一次数据
    length = len(label)
    indices=list(range(length))    #将长度转化为列表
    random.shuffle(indices)  #random.shuffle()是random库中的函数,将输入打乱,输入通常是列表
    for each in range(0,length,batch_size):
        get_indices=indices[each:each+batch_size]#取出
        get_data=data[get_indices]
        get_label=label[get_indices]

        yield get_data,get_label   #yield-->有存档点的return 再次访问时从此处开始
                                    #get_data是x的值,get_label是y的值

5. 数据提供函数,传入的参数即真实的x 和y,和一次取的数据量,随机取数据进行分析。前三句:得到label张量的长度,并且得到一个这么长的列表,用random里的库函数将indices列表打乱,后面:遍历这个下标列表,步进一个batch_size,得到一组16个下标,并取出下标对应的data和label,此时的get_data和get_label分别是16x4和16x1的张量    注:get_indices是切片操作,得到的是indices的一部分的列表,注意get_data=indices[get_indices]的这种操作,可以将切片来的列表直接用于再次切片       yield的作用如上

batch_size=16
for batch_x,batch_y in data_provider(x,y,batch_size):
    print(batch_x,batch_y)
    break

6. 查看所取的切片

def fun(x,w,b):
    pred_y=torch.matmul(x,w)+b
    return pred_y

7. 计算预测的y,x,w,b均为形参,实际传入的是batch_x,即取的样品,w_0,b_0,即自己选取的超参数

def maeloss(pred_y,y):
    loss=torch.sum(abs(pred_y-y))/len(y)
    return loss

8. 定义loss函数,此处我们选取的是maeloss,即差的绝对值相加再取平均,由于对张量进行操作,所有要用到库函数torch.sum(),将两张量的各分量相减之后相加再取平均值,其个数是len(y)

def sgd(paras,lr):      #梯度下降算法
    with torch.no_grad():   #接下来的计算不算梯度
        for para in paras:
            para-=lr*para.grad   #此处不能展开写
            para.grad.zero_()    #把梯度归零

9. 定义梯度下降算法:line48:任何参数的计算都会积累没用的梯度,所以我们要torch.no_grad(),保证在下面的操作不会计算梯度,这是个全局的设置,line49:注意写法,即使paras是张量,也可以用for直接遍历,line50:这里不能展开写,展开写相当于新建了一个para,但是这个para没有梯度属性。line51:将para的梯度归零,由于para计算梯度都是积累在具体参数里,所以要对已经计算出来的,正确的梯度进行归零

lr=0.03
w_0=torch.normal(0,0.01,true_w.shape,requires_grad=true)
b_0=torch.tensor(0.01,requires_grad=true)
print(w_0,b_0)
epochs=50

10. 定义学习率learning rate,同时通过随机生成,选取w_0,b_0也是随机选取的。requires_grad=true表示将张量放在张量网上,pytorch开始跟踪对该张量的操作,以便计算梯度。定义epoch的大小,表示训练多少轮

for epoch in range(epochs):
    data_loss=0
    for batch_x,batch_y in data_provider(x,y,batch_size):
        pred=fun(batch_x,w_0,b_0)
        loss=maeloss(pred,batch_y)
        loss.backward()
        sgd([w_0,b_0], lr)

        data_loss+=loss
        print("epoch %03d: loss=%.6f"%(epoch,data_loss))

print("原来的函数值:",true_w,true_b)
print(w_0,b_0)

 11. 第一个for循环:用data_loss来记录loss有没有变小,第二个for循环:用data_provider()函数选出批量资料batch_x,batch_y,用batch_x,w_0,b_0计算出预测值pred,用预测值和真实值batch_y定义出loss函数,loss.backward()的调用触发自动微分系统,计算了loss对各个参数的梯度,sgd()函数用于利用计算出的梯度更新参数,之后将每轮计算出的loss累加

idx=0
plt.plot(x[:,idx].detach().numpy(),x[:,idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy(),label="pred")
plt.scatter(x[:,idx],y,1)
plt.show()

12. 画图: idx表示要打印的列,在画图之前,要将w_0和b_0从张量网上取下,此处原代码里也将x取下,但是笔者认为x本来没有规定他的require g。rad,不必x.detach().numpy(),经实践后没报错。plot()用于绘制连续的曲线,在这里表示x和其预测值的直线,scatter()用于绘制散点图,这里表示真实值,show()将绘制的曲线显示出来。

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com