1
0
Files
AI-learning/Users/username/code/linear_regression_m1.py
2025-03-13 18:14:01 +08:00

41 lines
1.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
# Prefer Apple's MPS (Metal) backend when this PyTorch build and OS support it
# (requires PyTorch 1.12+ and macOS 12.3+); otherwise run on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Synthetic regression data: 1000 samples x 2 features, drawn on the CPU and
# then moved to `device`.
X = torch.randn(1000, 2).to(device)
# Ground truth: y = 2.0 * x0 - 3.4 * x1 + 4, so y has shape (1000,).
y = X.matmul(torch.tensor([2.0, -3.4], device=device)).add(4)
# Perturb the targets in place with small Gaussian noise (std 0.01).
y.add_(0.01 * torch.randn(y.shape, device=device))
# Model definition: it must subclass nn.Module so its parameters are
# registered (needed for .to(device) and .parameters()).
class LinearRegression(torch.nn.Module):
    """Single affine layer used as a linear-regression model.

    Args:
        in_features: Number of input features. Defaults to 2, matching the
            synthetic data generated in this script.
        out_features: Number of outputs. Defaults to 1.
    """

    def __init__(self, in_features: int = 2, out_features: int = 1):
        super().__init__()
        # Stored as ``linear`` because callers read
        # ``model.linear.weight`` / ``model.linear.bias``.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the affine layer to ``x`` of shape (..., in_features).

        Returns:
            Tensor of shape (..., out_features).
        """
        return self.linear(x)
# Instantiate the model and move its parameters to the selected device.
model = LinearRegression().to(device)
# Mean-squared-error loss with plain SGD.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The model outputs shape (N, 1) while y has shape (N,). Reshape the target
# once, outside the loop, so MSELoss compares matching shapes.
targets = y.unsqueeze(1)

# Training loop: full-batch gradient descent over all samples each epoch.
for epoch in range(500):
    # Forward pass.
    outputs = model(X)
    loss = criterion(outputs, targets)
    # Backward pass: clear stale gradients, backpropagate, update parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        # .item() copies the scalar back to the host; only do it when logging.
        print(f'Epoch {epoch}, loss: {loss.item():.4f}')
# Print the final learned parameters. The data were generated with
# weights [2.0, -3.4] and bias 4.
learned_layer = model.linear
for caption, tensor_value in (
    ("Learned weights:", learned_layer.weight.data),
    ("Learned bias:", learned_layer.bias.data),
):
    print(caption, tensor_value)