230523
This commit is contained in:
17
README.md
17
README.md
@@ -1,8 +1,23 @@
|
|||||||
# pytorch study
|
# pytorch study
|
||||||
|
|
||||||
## ENV
|
## BASE ENV
|
||||||
```shell
|
```shell
|
||||||
conda create -n pt python=3.10 -y
|
conda create -n pt python=3.10 -y
|
||||||
|
|
||||||
conda activate pt
|
conda activate pt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## MAC
|
||||||
|
```shell
|
||||||
|
# 安装 pytorch v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端
|
||||||
|
conda install pytorch::pytorch torchvision torchaudio -c pytorch -y
|
||||||
|
|
||||||
|
pip install numpy
|
||||||
|
pip install pandas
|
||||||
|
pip install matplotlib
|
||||||
|
```
|
||||||
|
|
||||||
|
## gpt4free
|
||||||
|
```
|
||||||
|
pip install -U g4f[all]
|
||||||
|
```
|
||||||
15
autograd.py
Normal file
15
autograd.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import torch
|
||||||
|
from torch import autograd
|
||||||
|
|
||||||
|
device = torch.device('mps')
|
||||||
|
|
||||||
|
x = torch.tensor(1.)
|
||||||
|
a = torch.tensor(2., requires_grad=True)
|
||||||
|
b = torch.tensor(2., requires_grad=True)
|
||||||
|
c = torch.tensor(3., requires_grad=True)
|
||||||
|
|
||||||
|
y = a ** 2 * x + b * x + c ** 3
|
||||||
|
|
||||||
|
print('before:', a.grad, b.grad, c.grad)
|
||||||
|
grads = autograd.grad(y, [a, b, c])
|
||||||
|
print('after:', grads[0], grads[1], grads[2])
|
||||||
11
gpt.py
Normal file
11
gpt.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from g4f.client import Client
|
||||||
|
|
||||||
|
content = "张量在机器学习中的主要用途"
|
||||||
|
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=[{"role": "user", "content": content}],
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
56
linear regression/1.py
Normal file
56
linear regression/1.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def compute_error_for_line_given_points(b, w, points):
|
||||||
|
totalError = 0
|
||||||
|
N = float(len(points))
|
||||||
|
for i in range(len(points)):
|
||||||
|
x = points[i][0]
|
||||||
|
y = points[i][1]
|
||||||
|
totalError += (y - (w * x + b)) ** 2
|
||||||
|
return totalError / N
|
||||||
|
|
||||||
|
|
||||||
|
def step_gradient(b_current, w_current, points, learningRate):
|
||||||
|
b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
|
||||||
|
w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32)
|
||||||
|
N = float(len(points))
|
||||||
|
for i in range(len(points)):
|
||||||
|
x = points[i][0]
|
||||||
|
y = points[i][1]
|
||||||
|
b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current)
|
||||||
|
w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current))
|
||||||
|
new_b = b_current - (learningRate * b_gradient)
|
||||||
|
new_w = w_current - (learningRate * w_gradient)
|
||||||
|
return [new_b, new_w]
|
||||||
|
|
||||||
|
|
||||||
|
def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations):
|
||||||
|
b = torch.tensor(starting_b, device=points.device, dtype=torch.float32)
|
||||||
|
w = torch.tensor(starting_w, device=points.device, dtype=torch.float32)
|
||||||
|
for i in range(num_iterations):
|
||||||
|
b, w = step_gradient(b, w, points, learningRate)
|
||||||
|
return [b, w]
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
# 修改为生成数据的文件路径
|
||||||
|
points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32)
|
||||||
|
points = torch.tensor(points_np, device='mps')
|
||||||
|
learning_rate = 0.0001 # 使用较小的学习率
|
||||||
|
initial_b = 0.0
|
||||||
|
initial_w = 0.0
|
||||||
|
num_iterations = 1000
|
||||||
|
print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b, initial_w,
|
||||||
|
compute_error_for_line_given_points(initial_b,
|
||||||
|
initial_w,
|
||||||
|
points)))
|
||||||
|
print("running...")
|
||||||
|
[b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
|
||||||
|
print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(),
|
||||||
|
compute_error_for_line_given_points(b, w, points)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run()
|
||||||
100
linear regression/data1.csv
Normal file
100
linear regression/data1.csv
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
0.0,2.4360562173289115
|
||||||
|
0.10101010101010101,2.4288710820592065
|
||||||
|
0.20202020202020202,3.677943201375977
|
||||||
|
0.30303030303030304,3.5029234515863217
|
||||||
|
0.40404040404040403,4.007715980839878
|
||||||
|
0.5050505050505051,3.95999321461469
|
||||||
|
0.6060606060606061,3.220853916066527
|
||||||
|
0.7070707070707071,3.2211460206798623
|
||||||
|
0.8080808080808081,4.2516957270374505
|
||||||
|
0.9090909090909091,4.311826715084292
|
||||||
|
1.0101010101010102,4.153966583608258
|
||||||
|
1.1111111111111112,4.224290328721461
|
||||||
|
1.2121212121212122,4.551324602105953
|
||||||
|
1.3131313131313131,5.157200101408408
|
||||||
|
1.4141414141414141,5.199011258508288
|
||||||
|
1.5151515151515151,5.248911218901843
|
||||||
|
1.6161616161616161,5.789628423512512
|
||||||
|
1.7171717171717171,5.126592322934872
|
||||||
|
1.8181818181818181,4.546631494636344
|
||||||
|
1.9191919191919191,5.7260434379514065
|
||||||
|
2.0202020202020203,5.607446671816119
|
||||||
|
2.121212121212121,5.401744626671172
|
||||||
|
2.2222222222222223,5.568078510495838
|
||||||
|
2.323232323232323,6.136817713051054
|
||||||
|
2.4242424242424243,5.399802896696589
|
||||||
|
2.525252525252525,6.7465591899811415
|
||||||
|
2.6262626262626263,6.510002771256968
|
||||||
|
2.727272727272727,6.194107987238278
|
||||||
|
2.8282828282828283,6.280445605022811
|
||||||
|
2.929292929292929,6.413289184504817
|
||||||
|
3.0303030303030303,8.178951965980268
|
||||||
|
3.131313131313131,7.438933818741419
|
||||||
|
3.2323232323232323,8.161193108124682
|
||||||
|
3.3333333333333335,6.466487953447159
|
||||||
|
3.4343434343434343,7.6815673373443385
|
||||||
|
3.5353535353535355,7.412509123916619
|
||||||
|
3.6363636363636362,7.712231039046388
|
||||||
|
3.7373737373737375,7.512155302443977
|
||||||
|
3.8383838383838382,8.169468174953455
|
||||||
|
3.9393939393939394,8.201406255891817
|
||||||
|
4.040404040404041,9.413915839209679
|
||||||
|
4.141414141414141,7.2131607261403
|
||||||
|
4.242424242424242,8.244196707034996
|
||||||
|
4.343434343434343,8.059400613529792
|
||||||
|
4.444444444444445,9.127093042087843
|
||||||
|
4.545454545454545,8.232456814994503
|
||||||
|
4.646464646464646,9.026988954051767
|
||||||
|
4.747474747474747,8.936405824368308
|
||||||
|
4.848484848484849,8.838334259675397
|
||||||
|
4.94949494949495,9.717080564295035
|
||||||
|
5.05050505050505,9.635892324495916
|
||||||
|
5.151515151515151,10.802758752616178
|
||||||
|
5.252525252525253,9.889268431487773
|
||||||
|
5.353535353535354,9.262021983987134
|
||||||
|
5.454545454545454,9.905732041295009
|
||||||
|
5.555555555555555,9.697006564677089
|
||||||
|
5.656565656565657,10.435437946557755
|
||||||
|
5.757575757575758,10.257651825530608
|
||||||
|
5.858585858585858,11.394734709569004
|
||||||
|
5.959595959595959,10.872621683473387
|
||||||
|
6.0606060606060606,10.750944058491058
|
||||||
|
6.161616161616162,11.375400587831757
|
||||||
|
6.262626262626262,11.834436555701465
|
||||||
|
6.363636363636363,11.536088544119654
|
||||||
|
6.4646464646464645,11.261555999325722
|
||||||
|
6.565656565656566,12.529961808490153
|
||||||
|
6.666666666666667,12.19345219105891
|
||||||
|
6.767676767676767,11.950653180245155
|
||||||
|
6.8686868686868685,12.176773142948385
|
||||||
|
6.96969696969697,12.055083206520518
|
||||||
|
7.070707070707071,13.498633194384489
|
||||||
|
7.171717171717171,12.542518727882712
|
||||||
|
7.2727272727272725,13.318372269865769
|
||||||
|
7.373737373737374,12.542630883166513
|
||||||
|
7.474747474747475,12.93490122675753
|
||||||
|
7.575757575757575,14.4040220344926
|
||||||
|
7.6767676767676765,13.314367294113964
|
||||||
|
7.777777777777778,14.061236496574551
|
||||||
|
7.878787878787879,12.686346979737731
|
||||||
|
7.979797979797979,14.024375221983842
|
||||||
|
8.080808080808081,13.7042096336008
|
||||||
|
8.181818181818182,13.342730021126272
|
||||||
|
8.282828282828282,14.136548357864573
|
||||||
|
8.383838383838384,14.619569834949138
|
||||||
|
8.484848484848484,14.01453898823226
|
||||||
|
8.585858585858587,15.154877807203663
|
||||||
|
8.686868686868687,14.081910297898048
|
||||||
|
8.787878787878787,14.474564310016353
|
||||||
|
8.88888888888889,14.966525346412723
|
||||||
|
8.98989898989899,15.526107019435932
|
||||||
|
9.09090909090909,14.352357736719853
|
||||||
|
9.191919191919192,15.843742065895144
|
||||||
|
9.292929292929292,15.787083172159111
|
||||||
|
9.393939393939394,15.211828607109144
|
||||||
|
9.494949494949495,15.845176532492374
|
||||||
|
9.595959595959595,15.622518083688107
|
||||||
|
9.696969696969697,15.589081237426006
|
||||||
|
9.797979797979798,15.511248085690712
|
||||||
|
9.8989898989899,16.27050774059862
|
||||||
|
10.0,16.1105549896166
|
||||||
|
30
linear regression/np_genPoints.py
Normal file
30
linear regression/np_genPoints.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import numpy as np
|
||||||
|
import csv
|
||||||
|
|
||||||
|
# 定义回归方程参数
|
||||||
|
w = 1.35
|
||||||
|
b = 2.89
|
||||||
|
|
||||||
|
# 生成x值范围
|
||||||
|
x_min = 0
|
||||||
|
x_max = 10
|
||||||
|
|
||||||
|
# 生成100个在x轴附近的点
|
||||||
|
x = np.linspace(x_min, x_max, 100)
|
||||||
|
|
||||||
|
# 根据回归方程计算y值
|
||||||
|
y = w * x + b
|
||||||
|
|
||||||
|
# 添加一些噪声,使数据更真实
|
||||||
|
y += np.random.normal(scale=0.5, size=y.shape)
|
||||||
|
|
||||||
|
# 将x和y合并成一个二维数组
|
||||||
|
data = np.column_stack((x, y))
|
||||||
|
|
||||||
|
# 将数据保存到CSV文件
|
||||||
|
with open('data1.csv', 'w', newline='') as csvfile:
|
||||||
|
writer = csv.writer(csvfile)
|
||||||
|
# 写入表头
|
||||||
|
# writer.writerow(['x', 'y'])
|
||||||
|
# 写入数据
|
||||||
|
writer.writerows(data)
|
||||||
22
linear regression/plt_print.py
Normal file
22
linear regression/plt_print.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# 原始数据
|
||||||
|
points = np.genfromtxt("data1.csv", delimiter=',')
|
||||||
|
x = points[:, 0]
|
||||||
|
y = points[:, 1]
|
||||||
|
|
||||||
|
# 拟合直线
|
||||||
|
x_range = np.linspace(min(x), max(x), 100)
|
||||||
|
y_pred = 1.6455038785934448 * x_range + 1.827562689781189
|
||||||
|
|
||||||
|
# 绘图
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
plt.scatter(x, y, color='blue', label='Original data')
|
||||||
|
plt.plot(x_range, y_pred, color='red', label='Fitted line')
|
||||||
|
plt.xlabel('X')
|
||||||
|
plt.ylabel('Y')
|
||||||
|
plt.title('Fitting a line to random data')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
plt.savefig('print1.png')
|
||||||
BIN
linear regression/print1.png
Normal file
BIN
linear regression/print1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
pytorch::pytorch
|
||||||
|
torchvision
|
||||||
|
torchaudio
|
||||||
|
pandas
|
||||||
|
matplotlib
|
||||||
|
numpy
|
||||||
|
g4f
|
||||||
6
test/macTest.py
Normal file
6
test/macTest.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
print(torch.backends.mps.is_available())
|
||||||
|
print(torch.backends.mps.is_built())
|
||||||
|
|
||||||
|
print(torch.device("mps"))
|
||||||
29
test/performance.py
Normal file
29
test/performance.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
print(torch.__version__)
|
||||||
|
print(torch.backends.mps.is_available())
|
||||||
|
print(torch.cuda.is_available())
|
||||||
|
|
||||||
|
a = torch.randn(10000,1000)
|
||||||
|
b = torch.randn(1000,2000)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
c = torch.matmul(a, b)
|
||||||
|
t1 = time.time()
|
||||||
|
print(a.device,t1-t0,c.norm(2))
|
||||||
|
|
||||||
|
device = torch.device('mps')
|
||||||
|
|
||||||
|
a = a.to(device)
|
||||||
|
b = b.to(device)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
c = torch.matmul(a, b)
|
||||||
|
t1 = time.time()
|
||||||
|
print(a.device,t1-t0,c.norm(2))
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
c = torch.matmul(a, b)
|
||||||
|
t1 = time.time()
|
||||||
|
print(a.device,t1-t0,c.norm(2))
|
||||||
Reference in New Issue
Block a user