2024-12-28 12:50:41 +08:00
parent 150cd78949
commit 6361618c8e
13 changed files with 289 additions and 5760 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,310 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from matplotlib import pyplot as plt
def run1():
    def compute_error_for_line_given_points(b, w, points):
        totalError = 0
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            totalError += (y - (w * x + b)) ** 2
        return totalError / N

    def step_gradient(b_current, w_current, points, learningRate):
        b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
        w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            b_gradient += -(2 / N) * (y - (w_current * x + b_current))
            w_gradient += -(2 / N) * x * (y - (w_current * x + b_current))
        new_b = b_current - (learningRate * b_gradient)
        new_w = w_current - (learningRate * w_gradient)
        return [new_b, new_w]

    def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
        b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
        w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
        for i in range(num_iterations):
            b, w = step_gradient(b, w, points, learningRate)
        return [b, w]

    def run():
        # Change this to the path of the generated data file
        points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
        points = torch.tensor(points_np, device='mps')
        learning_rate = 0.0001  # use a relatively small learning rate
        initial_b = 0.0
        initial_w = 0.0
        num_iterations = 1000
        print("Starting gradient descent at b={0}, w={1}, error={2}".format(
            initial_b, initial_w,
            compute_error_for_line_given_points(initial_b, initial_w, points)))
        print("running...")
        [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
        print("After gradient descent at b={0}, w={1}, error={2}".format(
            b.item(), w.item(),
            compute_error_for_line_given_points(b, w, points)))

    run()
def run1_cuda():
    def compute_error_for_line_given_points(b, w, points):
        totalError = 0
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            totalError += (y - (w * x + b)) ** 2
        return totalError / N

    def step_gradient(b_current, w_current, points, learningRate):
        b_gradient = torch.tensor(0.0, device=points.device)
        w_gradient = torch.tensor(0.0, device=points.device)
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            b_gradient += -(2 / N) * (y - (w_current * x + b_current))
            w_gradient += -(2 / N) * x * (y - (w_current * x + b_current))
        new_b = b_current - (learningRate * b_gradient)
        new_w = w_current - (learningRate * w_gradient)
        return [new_b, new_w]

    def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
        b = torch.tensor(starting_b, device=points.device)
        w = torch.tensor(starting_w, device=points.device)
        for i in range(num_iterations):
            b, w = step_gradient(b, w, points, learningRate)
            print("round:", i)
        return [b, w]

    def run():
        points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
        points = torch.tensor(points_np, device='cuda')
        learning_rate = 0.0001
        initial_b = 0.0
        initial_w = 0.0
        num_iterations = 100000
        [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
        print("After gradient descent at b={0}, w={1}, error={2}".format(
            b.item(), w.item(),
            compute_error_for_line_given_points(b, w, points)))
        return b.item(), w.item()

    # Run the linear regression
    final_b, final_w = run()

    # Plot the results
    points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
    x = points_np[:, 0]
    y = points_np[:, 1]
    x_range = np.linspace(min(x), max(x), 100)
    y_pred = final_w * x_range + final_b
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, color='blue', label='Original data')
    plt.plot(x_range, y_pred, color='red', label='Fitted line')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('Fitting a line to random data')
    plt.legend()
    plt.grid(True)
    plt.savefig('print1.png')
    plt.show()
def run1x():
    # Linear regression training code
    def compute_error_for_line_given_points(b, w, points):
        totalError = 0
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            totalError += (y - (w * x + b)) ** 2
        return totalError / N

    def step_gradient(b_current, w_current, points, learningRate):
        b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
        w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
        N = float(len(points))
        for i in range(len(points)):
            x = points[i][0]
            y = points[i][1]
            b_gradient += -(2 / N) * (y - (w_current * x + b_current))
            w_gradient += -(2 / N) * x * (y - (w_current * x + b_current))
        new_b = b_current - (learningRate * b_gradient)
        new_w = w_current - (learningRate * w_gradient)
        return [new_b, new_w]

    def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
        b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
        w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
        for i in range(num_iterations):
            b, w = step_gradient(b, w, points, learningRate)
        return [b, w]

    def run():
        points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
        points = torch.tensor(points_np, device='mps')
        learning_rate = 0.0001
        initial_b = 0.0
        initial_w = 0.0
        num_iterations = 5000
        [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
        print("After gradient descent at b={0}, w={1}, error={2}".format(
            b.item(), w.item(),
            compute_error_for_line_given_points(b, w, points)))
        return b.item(), w.item()

    # Run the linear regression
    final_b, final_w = run()

    # Plot the results
    points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
    x = points_np[:, 0]
    y = points_np[:, 1]
    x_range = np.linspace(min(x), max(x), 100)
    y_pred = final_w * x_range + final_b
    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, color='blue', label='Original data')
    plt.plot(x_range, y_pred, color='red', label='Fitted line')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('Fitting a line to random data')
    plt.legend()
    plt.grid(True)
    plt.savefig('print1.png')
    plt.show()
def run_m1():
    # Check whether MPS (Apple Metal Performance Shaders) is available
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # Generate sample data: y = 3x + 2 + noise
    torch.manual_seed(0)
    X = torch.linspace(-10, 10, steps=100).reshape(-1, 1)
    y = 3 * X + 2 + torch.randn(X.size()) * 2

    # Build the dataset and data loader
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

    # Define the linear regression model
    class LinearRegressionModel(nn.Module):
        def __init__(self):
            super(LinearRegressionModel, self).__init__()
            self.linear = nn.Linear(1, 1)  # both input and output are 1-dimensional

        def forward(self, x):
            return self.linear(x)

    # Instantiate the model and move it to the device
    model = LinearRegressionModel().to(device)

    # Define the loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Train the model
    num_epochs = 100
    for epoch in range(num_epochs):
        for batch_X, batch_y in dataloader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            # Forward pass
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Save the model's parameters (state_dict); note this is not the whole model object
    torch.save(model.state_dict(), 'm1.pth')
    print("Model state_dict saved as m1.pth")

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        X_test = torch.linspace(-10, 10, steps=100).reshape(-1, 1).to(device)
        y_pred = model(X_test).cpu()

    plt.scatter(X.numpy(), y.numpy(), label='Ground-truth data')
    plt.plot(X_test.cpu().numpy(), y_pred.numpy(), color='red', label='Predicted line')
    plt.legend()
    plt.xlabel('X')
    plt.ylabel('y')
    plt.title('Linear regression result')
    plt.show()
def run_m1_test():
    # Define the linear regression model structure
    class LinearRegressionModel(nn.Module):
        def __init__(self):
            super(LinearRegressionModel, self).__init__()
            self.linear = nn.Linear(1, 1)  # both input and output are 1-dimensional

        def forward(self, x):
            return self.linear(x)

    def main():
        # Check whether MPS (Apple Metal Performance Shaders) is available
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Using device: {device}")

        # Instantiate the model and load the saved parameters;
        # map_location lets the checkpoint load even on a CPU-only machine
        model = LinearRegressionModel().to(device)
        model.load_state_dict(torch.load('m1.pth', map_location=device))

        with open('m1.pth', 'rb') as f:
            f.seek(0, 2)
            size = f.tell()
            print(f"Model file size: {size} bytes")

        model.eval()

        # Report the model size
        model_size = sum(p.numel() for p in model.parameters())
        print(f"Model size: {model_size} parameters")
        print("Model parameters loaded")

        # Generate test data
        X_test = torch.linspace(-10, 10, steps=100).reshape(-1, 1).to(device)

        # Predict with the loaded model
        with torch.no_grad():
            y_pred = model(X_test).cpu()

        # Move the test data to the CPU and convert to NumPy arrays
        X_test_numpy = X_test.cpu().numpy()
        y_pred_numpy = y_pred.numpy()

        # Visualize the predictions
        plt.scatter(X_test_numpy, 3 * X_test_numpy + 2, label='True linear relation', color='blue')
        plt.plot(X_test_numpy, y_pred_numpy, color='red', label='Model prediction')
        plt.legend()
        plt.xlabel('X')
        plt.ylabel('y')
        plt.title('Linear regression predictions from the loaded model')
        plt.show()

    main()
if __name__ == '__main__':
    print("start")
    # Call one of the runners defined above as needed,
    # e.g. run1(), run1_cuda(), run1x(), run_m1(), or run_m1_test()

View File

@@ -1,100 +0,0 @@
0.0,3.267264598063918
0.10101010101010101,3.3715980311448757
0.20202020202020202,3.624529195538264
0.30303030303030304,2.2426026435865785
0.40404040404040403,2.881354128303834
0.5050505050505051,4.108915299601898
0.6060606060606061,3.2833072841616344
0.7070707070707071,3.401314666490608
0.8080808080808081,3.4471224977820083
0.9090909090909091,4.597483332850038
1.0101010101010102,4.1948230194917615
1.1111111111111112,4.770110614428998
1.2121212121212122,4.3466984672473545
1.3131313131313131,4.085374736788284
1.4141414141414141,4.860667770156386
1.5151515151515151,5.367460741298345
1.6161616161616161,5.1076464505556585
1.7171717171717171,4.517380143483942
1.8181818181818181,6.028333717306668
1.9191919191919191,5.268642728341781
2.0202020202020203,5.2032646463511885
2.121212121212121,5.776924577040542
2.2222222222222223,5.914239664440679
2.323232323232323,6.195294604152318
2.4242424242424243,6.67461745554651
2.525252525252525,6.62895898059055
2.6262626262626263,6.423496434474387
2.727272727272727,6.520626001853953
2.8282828282828283,6.252851138402289
2.929292929292929,7.045662416151556
3.0303030303030303,7.062687254815803
3.131313131313131,6.950015155958233
3.2323232323232323,7.71420449451215
3.3333333333333335,7.536987534120887
3.4343434343434343,8.408446293914908
3.5353535353535355,8.281116903817127
3.6363636363636362,6.862064335470844
3.7373737373737375,8.455114086555362
3.8383838383838382,8.610569256326439
3.9393939393939394,8.695172603505283
4.040404040404041,7.987174933011048
4.141414141414141,8.484042583282307
4.242424242424242,8.152218590549857
4.343434343434343,8.810112829362456
4.444444444444445,9.098520210970904
4.545454545454545,9.315991463976044
4.646464646464646,9.266562387635002
4.747474747474747,8.457655714255173
4.848484848484849,8.577190426286784
4.94949494949495,9.992687218959654
5.05050505050505,9.949888900251127
5.151515151515151,10.112370557219064
5.252525252525253,10.250084050804231
5.353535353535354,9.675169646286898
5.454545454545454,9.790565255890696
5.555555555555555,9.91666488079517
5.656565656565657,10.325538746448835
5.757575757575758,9.77548528051785
5.858585858585858,10.55371401462777
5.959595959595959,10.757696722894282
6.0606060606060606,10.893354131765726
6.161616161616162,12.049342074375708
6.262626262626262,10.936118426966079
6.363636363636363,11.031578580287063
6.4646464646464645,11.713471927909302
6.565656565656566,12.343664117608101
6.666666666666667,12.067856620638729
6.767676767676767,11.814430199711552
6.8686868686868685,11.123516182999314
6.96969696969697,12.496962644202316
7.070707070707071,12.767487737755147
7.171717171717171,12.632934104476211
7.2727272727272725,12.728225932468364
7.373737373737374,12.97630338885533
7.474747474747475,12.896220727223701
7.575757575757575,13.047808359849581
7.6767676767676765,13.110443597152527
7.777777777777778,12.921358752181128
7.878787878787879,13.30038615173782
7.979797979797979,13.836945395153705
8.080808080808081,13.054484897082014
8.181818181818182,14.01038452336861
8.282828282828282,13.643336312636018
8.383838383838384,14.564671365817466
8.484848484848484,14.040540515755758
8.585858585858587,14.57992522742261
8.686868686868687,14.88631275019171
8.787878787878787,14.021963220008606
8.88888888888889,15.068155050128949
8.98989898989899,15.083538874549268
9.09090909090909,15.417748308319911
9.191919191919192,14.89714205401168
9.292929292929292,14.534676091762206
9.393939393939394,15.556467883324295
9.494949494949495,15.525938847099864
9.595959595959595,15.560767751324764
9.696969696969697,15.982790914773943
9.797979797979798,16.062079721169738
9.8989898989899,16.232818049890696
10.0,17.053472736980353

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -1,30 +0,0 @@
import numpy as np
import csv

# Regression line parameters
w = 1.35
b = 2.89

# x value range
x_min = 0
x_max = 10

# Generate 100 evenly spaced points along the x-axis
x = np.linspace(x_min, x_max, 100)

# Compute y from the regression equation
y = w * x + b

# Add some noise to make the data more realistic
y += np.random.normal(scale=0.5, size=y.shape)

# Stack x and y into a single 2-D array
data = np.column_stack((x, y))

# Save the data to a CSV file
with open('data1.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    # Header row (disabled so the training code can parse pure numbers)
    # writer.writerow(['x', 'y'])
    # Write the data rows
    writer.writerows(data)

View File

@@ -1,23 +0,0 @@
import numpy as np
import matplotlib.pyplot as plt

# Original data
points = np.genfromtxt("data1.csv", delimiter=',')
x = points[:, 0]
y = points[:, 1]

# Fitted line (slope and intercept hard-coded from a previous training run)
x_range = np.linspace(min(x), max(x), 100)
y_pred = 0.3880246877670288 * x_range + 1.7214288711547852

# Plot
plt.figure(figsize=(8, 6))
plt.scatter(x, y, color='blue', label='Original data')
plt.plot(x_range, y_pred, color='red', label='Fitted line')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Fitting a line to random data')
plt.legend()
plt.grid(True)
plt.savefig('print1.png')
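
The slope and intercept above are hard-coded from an earlier training run. As a sanity check (a sketch, not part of the original file), NumPy's closed-form least-squares fit on the same data should give similar values:

import numpy as np

points = np.genfromtxt("data1.csv", delimiter=',')
# np.polyfit returns coefficients highest-degree first: [slope, intercept]
w_fit, b_fit = np.polyfit(points[:, 0], points[:, 1], deg=1)
print(f"closed-form fit: w={w_fit:.4f}, b={b_fit:.4f}")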

Binary file not shown.

test/outputtest.py Normal file
View File

@@ -0,0 +1,17 @@
import numpy as np

# Create a 2-D array
array = np.array([[1, 2, 3], [4, 5, 6]])

# Sum over the whole array
total_sum = np.sum(array)           # 21

# Sum over each column (axis=0 collapses the rows)
column_sum = np.sum(array, axis=0)  # [5 7 9]

# Sum over each row (axis=1 collapses the columns)
row_sum = np.sum(array, axis=1)     # [ 6 15]

print("Total:", total_sum)
print("Column sums:", column_sum)
print("Row sums:", row_sum)

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.