250314
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -10,4 +10,5 @@
 | 
			
		||||
 | 
			
		||||
.vscode
 | 
			
		||||
 | 
			
		||||
lab/data
 | 
			
		||||
lab/data
 | 
			
		||||
lab/models
 | 
			
		||||
@@ -1,40 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
# 检查MPS可用性(需要PyTorch 1.12+和macOS 12.3+)
 | 
			
		||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 | 
			
		||||
 | 
			
		||||
# 生成训练数据(移动到MPS设备)
 | 
			
		||||
X = torch.randn(1000, 2).to(device)  # 1000个样本,2个特征
 | 
			
		||||
y = X @ torch.tensor([2.0, -3.4], device=device) + 4  # 真实关系式
 | 
			
		||||
y += 0.01 * torch.randn(y.shape, device=device)  # 添加噪声
 | 
			
		||||
 | 
			
		||||
# 定义模型(必须继承nn.Module)
 | 
			
		||||
class LinearRegression(torch.nn.Module):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.linear = torch.nn.Linear(2, 1)  # 输入2维,输出1维
 | 
			
		||||
        
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        return self.linear(x)
 | 
			
		||||
 | 
			
		||||
model = LinearRegression().to(device)  # 将模型移至MPS设备
 | 
			
		||||
criterion = torch.nn.MSELoss()
 | 
			
		||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 | 
			
		||||
 | 
			
		||||
# 训练循环
 | 
			
		||||
for epoch in range(500):
 | 
			
		||||
    # 前向传播
 | 
			
		||||
    outputs = model(X)
 | 
			
		||||
    loss = criterion(outputs, y.unsqueeze(1))
 | 
			
		||||
    
 | 
			
		||||
    # 反向传播
 | 
			
		||||
    optimizer.zero_grad()
 | 
			
		||||
    loss.backward()
 | 
			
		||||
    optimizer.step()
 | 
			
		||||
    
 | 
			
		||||
    if epoch % 50 == 0:
 | 
			
		||||
        print(f'Epoch {epoch}, loss: {loss.item():.4f}')
 | 
			
		||||
 | 
			
		||||
# 输出最终参数
 | 
			
		||||
print("Learned weights:", model.linear.weight.data)
 | 
			
		||||
print("Learned bias:", model.linear.bias.data)
 | 
			
		||||
							
								
								
									
										819
									
								
								lab/8_FC-MNIST.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										819
									
								
								lab/8_FC-MNIST.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								lab/test/0.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								lab/test/0.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 28 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								lab/test/2.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								lab/test/2.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 28 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								lab/test/3.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								lab/test/3.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 27 KiB  | 
@@ -6,4 +6,5 @@ ipywidgets
 | 
			
		||||
jupyter
 | 
			
		||||
scikit-learn
 | 
			
		||||
mnist
 | 
			
		||||
graphviz
 | 
			
		||||
graphviz
 | 
			
		||||
tqdm
 | 
			
		||||
		Reference in New Issue
	
	Block a user