使用pytorch实现线性回归(很基础模型搭建详解)
                    
                 
                
使用pytorch实现线性回归步骤:1.prepare dataset2.design model using Class 目的是为了前向传播forward,即计算y hat(预测值)3.Construct loss and optimizer (using pytorch API) 其中计算loss是为了进行反向传播,optimizer是为了更新梯度4.Train Cycleimport torchx_da                
                
                
                
                
                
                    
                     
import torch
 
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])
 
 
# 实例化模型
class linearmodel(torch.nn.module):
    def __init__(self):
        super(linearmodel, self).__init__() # 调用父类的构造
        self.linear = torch.nn.linear(1,1) # 实例化类,构造对象,包含了权重和偏置
        # linear也是继承自module的,也能进行反向传播
        # nn:neural network
    def forward(self,x):
        y_pred = self.linear(x)
        return y_pred
    
model = linearmodel()
 
 
# 定义损失函数
criterion = torch.nn.mseloss(size_average=false)
# 定义优化器
optimizer = torch.optim.sgd(model.parameters(), lr = 0.01)
 
# 训练
for epoch in range(100):
    y_pred = model(x_data) # 计算y hat
    loss = criterion(y_pred,y_data) # 计算loss
    print(epoch,loss.item())
    
    # optimizer.zero_grad() # 梯度归零
    loss.backward() # 反向传播,计算梯度
    optimizer.step() # update 参数,即更新w和b的值
    optimizer.zero_grad() # 梯度归零
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred',y_test.data)
 
                
                    
                 
                    
                
                
                
                
                
                    相关文章:
                    
                        
                        
                                
                                
                                - 
                                    
                                
- 
                                    
                                
- 
                                    
                                        
                                        
                                            
                                            
                                            
                                            
                                        
                                        
                                            
PyTorch是一个基于python的科学计算包,主要针对两类人群:作为NumPy的替代品,可以利用GPU的性能进行计算作为一个高灵活性、速度快的深度学习平台在PyTorch中搭建…
                                             
 
 
- 
                                    
                                        
                                        
                                            
                                            
                                            
                                            
                                        
                                        
                                            
地理数据的统计处理,双变量相关分析,主成分分析,因子分析,多元线性回归,聚类分析,时间序列分析,地统计分析,趋势面分析,马尔可夫分析…
                                             
 
 
- 
                                    
                                        
                                        
                                            
                                            
                                            
                                            
                                        
                                        
                                            
本文对国内外部分交通数据集进行了介绍、对相关参数的进行了说明。…
                                             
 
 
- 
                                    
                                
 
                
                
                
                
                
                    
                        版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。
                        如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。
                    
                 
                
             
        
发表评论