230523
This commit is contained in:
17
README.md
17
README.md
@@ -1,8 +1,23 @@
|
||||
# pytorch study
|
||||
|
||||
## ENV
|
||||
## BASE ENV
|
||||
```shell
|
||||
conda create -n pt python=3.10 -y
|
||||
|
||||
conda activate pt
|
||||
```
|
||||
|
||||
## MAC
|
||||
```shell
|
||||
# 安装 pytorch v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端
|
||||
conda install pytorch::pytorch torchvision torchaudio -c pytorch -y
|
||||
|
||||
pip install numpy
|
||||
pip install pandas
|
||||
pip install matplotlib
|
||||
```
|
||||
|
||||
## gpt4free
|
||||
```
|
||||
pip install -U g4f[all]
|
||||
```
|
||||
15
autograd.py
Normal file
15
autograd.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
from torch import autograd
|
||||
|
||||
device = torch.device('mps')
|
||||
|
||||
x = torch.tensor(1.)
|
||||
a = torch.tensor(2., requires_grad=True)
|
||||
b = torch.tensor(2., requires_grad=True)
|
||||
c = torch.tensor(3., requires_grad=True)
|
||||
|
||||
y = a ** 2 * x + b * x + c ** 3
|
||||
|
||||
print('before:', a.grad, b.grad, c.grad)
|
||||
grads = autograd.grad(y, [a, b, c])
|
||||
print('after:', grads[0], grads[1], grads[2])
|
||||
11
gpt.py
Normal file
11
gpt.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from g4f.client import Client
|
||||
|
||||
content = "张量在机器学习中的主要用途"
|
||||
|
||||
|
||||
client = Client()
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": content}],
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
56
linear regression/1.py
Normal file
56
linear regression/1.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def compute_error_for_line_given_points(b, w, points):
|
||||
totalError = 0
|
||||
N = float(len(points))
|
||||
for i in range(len(points)):
|
||||
x = points[i][0]
|
||||
y = points[i][1]
|
||||
totalError += (y - (w * x + b)) ** 2
|
||||
return totalError / N
|
||||
|
||||
|
||||
def step_gradient(b_current, w_current, points, learningRate):
|
||||
b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
|
||||
w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
|
||||
N = float(len(points))
|
||||
for i in range(len(points)):
|
||||
x = points[i][0]
|
||||
y = points[i][1]
|
||||
b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current)
|
||||
w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current))
|
||||
new_b = b_current - (learningRate * b_gradient)
|
||||
new_w = w_current - (learningRate * w_gradient)
|
||||
return [new_b, new_w]
|
||||
|
||||
|
||||
def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
|
||||
b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
|
||||
w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
|
||||
for i in range(num_iterations):
|
||||
b, w = step_gradient(b, w, points, learningRate)
|
||||
return [b, w]
|
||||
|
||||
|
||||
def run():
|
||||
# 修改为生成数据的文件路径
|
||||
points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
|
||||
points = torch.tensor(points_np, device='mps')
|
||||
learning_rate = 0.0001 # 使用较小的学习率
|
||||
initial_b = 0.0
|
||||
initial_w = 0.0
|
||||
num_iterations = 1000
|
||||
print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b, initial_w,
|
||||
compute_error_for_line_given_points(initial_b,
|
||||
initial_w,
|
||||
points)))
|
||||
print("running...")
|
||||
[b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
|
||||
print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(),
|
||||
compute_error_for_line_given_points(b, w, points)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
100
linear regression/data1.csv
Normal file
100
linear regression/data1.csv
Normal file
@@ -0,0 +1,100 @@
|
||||
0.0,2.4360562173289115
|
||||
0.10101010101010101,2.4288710820592065
|
||||
0.20202020202020202,3.677943201375977
|
||||
0.30303030303030304,3.5029234515863217
|
||||
0.40404040404040403,4.007715980839878
|
||||
0.5050505050505051,3.95999321461469
|
||||
0.6060606060606061,3.220853916066527
|
||||
0.7070707070707071,3.2211460206798623
|
||||
0.8080808080808081,4.2516957270374505
|
||||
0.9090909090909091,4.311826715084292
|
||||
1.0101010101010102,4.153966583608258
|
||||
1.1111111111111112,4.224290328721461
|
||||
1.2121212121212122,4.551324602105953
|
||||
1.3131313131313131,5.157200101408408
|
||||
1.4141414141414141,5.199011258508288
|
||||
1.5151515151515151,5.248911218901843
|
||||
1.6161616161616161,5.789628423512512
|
||||
1.7171717171717171,5.126592322934872
|
||||
1.8181818181818181,4.546631494636344
|
||||
1.9191919191919191,5.7260434379514065
|
||||
2.0202020202020203,5.607446671816119
|
||||
2.121212121212121,5.401744626671172
|
||||
2.2222222222222223,5.568078510495838
|
||||
2.323232323232323,6.136817713051054
|
||||
2.4242424242424243,5.399802896696589
|
||||
2.525252525252525,6.7465591899811415
|
||||
2.6262626262626263,6.510002771256968
|
||||
2.727272727272727,6.194107987238278
|
||||
2.8282828282828283,6.280445605022811
|
||||
2.929292929292929,6.413289184504817
|
||||
3.0303030303030303,8.178951965980268
|
||||
3.131313131313131,7.438933818741419
|
||||
3.2323232323232323,8.161193108124682
|
||||
3.3333333333333335,6.466487953447159
|
||||
3.4343434343434343,7.6815673373443385
|
||||
3.5353535353535355,7.412509123916619
|
||||
3.6363636363636362,7.712231039046388
|
||||
3.7373737373737375,7.512155302443977
|
||||
3.8383838383838382,8.169468174953455
|
||||
3.9393939393939394,8.201406255891817
|
||||
4.040404040404041,9.413915839209679
|
||||
4.141414141414141,7.2131607261403
|
||||
4.242424242424242,8.244196707034996
|
||||
4.343434343434343,8.059400613529792
|
||||
4.444444444444445,9.127093042087843
|
||||
4.545454545454545,8.232456814994503
|
||||
4.646464646464646,9.026988954051767
|
||||
4.747474747474747,8.936405824368308
|
||||
4.848484848484849,8.838334259675397
|
||||
4.94949494949495,9.717080564295035
|
||||
5.05050505050505,9.635892324495916
|
||||
5.151515151515151,10.802758752616178
|
||||
5.252525252525253,9.889268431487773
|
||||
5.353535353535354,9.262021983987134
|
||||
5.454545454545454,9.905732041295009
|
||||
5.555555555555555,9.697006564677089
|
||||
5.656565656565657,10.435437946557755
|
||||
5.757575757575758,10.257651825530608
|
||||
5.858585858585858,11.394734709569004
|
||||
5.959595959595959,10.872621683473387
|
||||
6.0606060606060606,10.750944058491058
|
||||
6.161616161616162,11.375400587831757
|
||||
6.262626262626262,11.834436555701465
|
||||
6.363636363636363,11.536088544119654
|
||||
6.4646464646464645,11.261555999325722
|
||||
6.565656565656566,12.529961808490153
|
||||
6.666666666666667,12.19345219105891
|
||||
6.767676767676767,11.950653180245155
|
||||
6.8686868686868685,12.176773142948385
|
||||
6.96969696969697,12.055083206520518
|
||||
7.070707070707071,13.498633194384489
|
||||
7.171717171717171,12.542518727882712
|
||||
7.2727272727272725,13.318372269865769
|
||||
7.373737373737374,12.542630883166513
|
||||
7.474747474747475,12.93490122675753
|
||||
7.575757575757575,14.4040220344926
|
||||
7.6767676767676765,13.314367294113964
|
||||
7.777777777777778,14.061236496574551
|
||||
7.878787878787879,12.686346979737731
|
||||
7.979797979797979,14.024375221983842
|
||||
8.080808080808081,13.7042096336008
|
||||
8.181818181818182,13.342730021126272
|
||||
8.282828282828282,14.136548357864573
|
||||
8.383838383838384,14.619569834949138
|
||||
8.484848484848484,14.01453898823226
|
||||
8.585858585858587,15.154877807203663
|
||||
8.686868686868687,14.081910297898048
|
||||
8.787878787878787,14.474564310016353
|
||||
8.88888888888889,14.966525346412723
|
||||
8.98989898989899,15.526107019435932
|
||||
9.09090909090909,14.352357736719853
|
||||
9.191919191919192,15.843742065895144
|
||||
9.292929292929292,15.787083172159111
|
||||
9.393939393939394,15.211828607109144
|
||||
9.494949494949495,15.845176532492374
|
||||
9.595959595959595,15.622518083688107
|
||||
9.696969696969697,15.589081237426006
|
||||
9.797979797979798,15.511248085690712
|
||||
9.8989898989899,16.27050774059862
|
||||
10.0,16.1105549896166
|
||||
|
30
linear regression/np_genPoints.py
Normal file
30
linear regression/np_genPoints.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
# 定义回归方程参数
|
||||
w = 1.35
|
||||
b = 2.89
|
||||
|
||||
# 生成x值范围
|
||||
x_min = 0
|
||||
x_max = 10
|
||||
|
||||
# 生成100个在x轴附近的点
|
||||
x = np.linspace(x_min, x_max, 100)
|
||||
|
||||
# 根据回归方程计算y值
|
||||
y = w * x + b
|
||||
|
||||
# 添加一些噪声,使数据更真实
|
||||
y += np.random.normal(scale=0.5, size=y.shape)
|
||||
|
||||
# 将x和y合并成一个二维数组
|
||||
data = np.column_stack((x, y))
|
||||
|
||||
# 将数据保存到CSV文件
|
||||
with open('data1.csv', 'w', newline='') as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
# 写入表头
|
||||
# writer.writerow(['x', 'y'])
|
||||
# 写入数据
|
||||
writer.writerows(data)
|
||||
22
linear regression/plt_print.py
Normal file
22
linear regression/plt_print.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 原始数据
|
||||
points = np.genfromtxt("data1.csv", delimiter=',')
|
||||
x = points[:, 0]
|
||||
y = points[:, 1]
|
||||
|
||||
# 拟合直线
|
||||
x_range = np.linspace(min(x), max(x), 100)
|
||||
y_pred = 1.6455038785934448 * x_range + 1.827562689781189
|
||||
|
||||
# 绘图
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.scatter(x, y, color='blue', label='Original data')
|
||||
plt.plot(x_range, y_pred, color='red', label='Fitted line')
|
||||
plt.xlabel('X')
|
||||
plt.ylabel('Y')
|
||||
plt.title('Fitting a line to random data')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig('print1.png')
|
||||
BIN
linear regression/print1.png
Normal file
BIN
linear regression/print1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
pytorch::pytorch
|
||||
torchvision
|
||||
torchaudio
|
||||
pandas
|
||||
matplotlib
|
||||
numpy
|
||||
g4f
|
||||
6
test/macTest.py
Normal file
6
test/macTest.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import torch
|
||||
|
||||
print(torch.backends.mps.is_available())
|
||||
print(torch.backends.mps.is_built())
|
||||
|
||||
print(torch.device("mps"))
|
||||
29
test/performance.py
Normal file
29
test/performance.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
print(torch.__version__)
|
||||
print(torch.backends.mps.is_available())
|
||||
print(torch.cuda.is_available())
|
||||
|
||||
a = torch.randn(10000,1000)
|
||||
b = torch.randn(1000,2000)
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
|
||||
device = torch.device('mps')
|
||||
|
||||
a = a.to(device)
|
||||
b = b.to(device)
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
|
||||
t0 = time.time()
|
||||
c = torch.matmul(a, b)
|
||||
t1 = time.time()
|
||||
print(a.device,t1-t0,c.norm(2))
|
||||
Reference in New Issue
Block a user