diff --git a/README.md b/README.md index 6172939..e87be84 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,23 @@ # pytorch study -## ENV +## BASE ENV ```shell conda create -n pt python=3.10 -y conda activate pt +``` + +## MAC +```shell +# 安装 pytorch v1.12版本已经正式支持了用于mac m1芯片gpu加速的mps后端 +conda install pytorch::pytorch torchvision torchaudio -c pytorch -y + +pip install numpy +pip install pandas +pip install matplotlib +``` + +## gpt4free +``` +pip install -U g4f[all] ``` \ No newline at end of file diff --git a/autograd.py b/autograd.py new file mode 100644 index 0000000..21556e4 --- /dev/null +++ b/autograd.py @@ -0,0 +1,15 @@ +import torch +from torch import autograd + +device = torch.device('mps') + +x = torch.tensor(1.) +a = torch.tensor(2., requires_grad=True) +b = torch.tensor(2., requires_grad=True) +c = torch.tensor(3., requires_grad=True) + +y = a ** 2 * x + b * x + c ** 3 + +print('before:', a.grad, b.grad, c.grad) +grads = autograd.grad(y, [a, b, c]) +print('after:', grads[0], grads[1], grads[2]) diff --git a/gpt.py b/gpt.py new file mode 100644 index 0000000..4a7a0a2 --- /dev/null +++ b/gpt.py @@ -0,0 +1,11 @@ +from g4f.client import Client + +content = "张量在机器学习中的主要用途" + + +client = Client() +response = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": content}], +) +print(response.choices[0].message.content) \ No newline at end of file diff --git a/linear regression/1.py b/linear regression/1.py new file mode 100644 index 0000000..dc455ef --- /dev/null +++ b/linear regression/1.py @@ -0,0 +1,56 @@ +import numpy as np +import torch + + +def compute_error_for_line_given_points(b, w, points): + totalError = 0 + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + totalError += (y - (w * x + b)) ** 2 + return totalError / N + + +def step_gradient(b_current, w_current, points, learningRate): + b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32) + w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32) + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current) + w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current)) + new_b = b_current - (learningRate * b_gradient) + new_w = w_current - (learningRate * w_gradient) + return [new_b, new_w] + + +def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations): + b = torch.tensor(starting_b, device=points.device, dtype=torch.float32) + w = torch.tensor(starting_w, device=points.device, dtype=torch.float32) + for i in range(num_iterations): + b, w = step_gradient(b, w, points, learningRate) + return [b, w] + + +def run(): + # 修改为生成数据的文件路径 + points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32) + points = torch.tensor(points_np, device='mps') + learning_rate = 0.0001 # 使用较小的学习率 + initial_b = 0.0 + initial_w = 0.0 + num_iterations = 1000 + print("Starting gradient descent at b={0},w={1},error={2}".format(initial_b, initial_w, + compute_error_for_line_given_points(initial_b, + initial_w, + points))) + print("running...") + [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations) + print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(), + compute_error_for_line_given_points(b, w, points))) + + +if __name__ == '__main__': + run() diff --git a/linear regression/data1.csv b/linear regression/data1.csv new file mode 100644 index 0000000..338b71d --- /dev/null +++ b/linear regression/data1.csv @@ -0,0 +1,100 @@ +0.0,2.4360562173289115 +0.10101010101010101,2.4288710820592065 +0.20202020202020202,3.677943201375977 +0.30303030303030304,3.5029234515863217 +0.40404040404040403,4.007715980839878 +0.5050505050505051,3.95999321461469 +0.6060606060606061,3.220853916066527 +0.7070707070707071,3.2211460206798623 +0.8080808080808081,4.2516957270374505 +0.9090909090909091,4.311826715084292 +1.0101010101010102,4.153966583608258 +1.1111111111111112,4.224290328721461 +1.2121212121212122,4.551324602105953 +1.3131313131313131,5.157200101408408 +1.4141414141414141,5.199011258508288 +1.5151515151515151,5.248911218901843 +1.6161616161616161,5.789628423512512 +1.7171717171717171,5.126592322934872 +1.8181818181818181,4.546631494636344 +1.9191919191919191,5.7260434379514065 +2.0202020202020203,5.607446671816119 +2.121212121212121,5.401744626671172 +2.2222222222222223,5.568078510495838 +2.323232323232323,6.136817713051054 +2.4242424242424243,5.399802896696589 +2.525252525252525,6.7465591899811415 +2.6262626262626263,6.510002771256968 +2.727272727272727,6.194107987238278 +2.8282828282828283,6.280445605022811 +2.929292929292929,6.413289184504817 +3.0303030303030303,8.178951965980268 +3.131313131313131,7.438933818741419 +3.2323232323232323,8.161193108124682 +3.3333333333333335,6.466487953447159 +3.4343434343434343,7.6815673373443385 +3.5353535353535355,7.412509123916619 +3.6363636363636362,7.712231039046388 +3.7373737373737375,7.512155302443977 +3.8383838383838382,8.169468174953455 +3.9393939393939394,8.201406255891817 +4.040404040404041,9.413915839209679 +4.141414141414141,7.2131607261403 +4.242424242424242,8.244196707034996 +4.343434343434343,8.059400613529792 +4.444444444444445,9.127093042087843 +4.545454545454545,8.232456814994503 +4.646464646464646,9.026988954051767 +4.747474747474747,8.936405824368308 +4.848484848484849,8.838334259675397 +4.94949494949495,9.717080564295035 +5.05050505050505,9.635892324495916 +5.151515151515151,10.802758752616178 +5.252525252525253,9.889268431487773 +5.353535353535354,9.262021983987134 +5.454545454545454,9.905732041295009 +5.555555555555555,9.697006564677089 +5.656565656565657,10.435437946557755 +5.757575757575758,10.257651825530608 +5.858585858585858,11.394734709569004 +5.959595959595959,10.872621683473387 +6.0606060606060606,10.750944058491058 +6.161616161616162,11.375400587831757 +6.262626262626262,11.834436555701465 +6.363636363636363,11.536088544119654 +6.4646464646464645,11.261555999325722 +6.565656565656566,12.529961808490153 +6.666666666666667,12.19345219105891 +6.767676767676767,11.950653180245155 +6.8686868686868685,12.176773142948385 +6.96969696969697,12.055083206520518 +7.070707070707071,13.498633194384489 +7.171717171717171,12.542518727882712 +7.2727272727272725,13.318372269865769 +7.373737373737374,12.542630883166513 +7.474747474747475,12.93490122675753 +7.575757575757575,14.4040220344926 +7.6767676767676765,13.314367294113964 +7.777777777777778,14.061236496574551 +7.878787878787879,12.686346979737731 +7.979797979797979,14.024375221983842 +8.080808080808081,13.7042096336008 +8.181818181818182,13.342730021126272 +8.282828282828282,14.136548357864573 +8.383838383838384,14.619569834949138 +8.484848484848484,14.01453898823226 +8.585858585858587,15.154877807203663 +8.686868686868687,14.081910297898048 +8.787878787878787,14.474564310016353 +8.88888888888889,14.966525346412723 +8.98989898989899,15.526107019435932 +9.09090909090909,14.352357736719853 +9.191919191919192,15.843742065895144 +9.292929292929292,15.787083172159111 +9.393939393939394,15.211828607109144 +9.494949494949495,15.845176532492374 +9.595959595959595,15.622518083688107 +9.696969696969697,15.589081237426006 +9.797979797979798,15.511248085690712 +9.8989898989899,16.27050774059862 +10.0,16.1105549896166 diff --git a/linear regression/np_genPoints.py b/linear regression/np_genPoints.py new file mode 100644 index 0000000..70fd9e2 --- /dev/null +++ b/linear regression/np_genPoints.py @@ -0,0 +1,30 @@ +import numpy as np +import csv + +# 定义回归方程参数 +w = 1.35 +b = 2.89 + +# 生成x值范围 +x_min = 0 +x_max = 10 + +# 生成100个在x轴附近的点 +x = np.linspace(x_min, x_max, 100) + +# 根据回归方程计算y值 +y = w * x + b + +# 添加一些噪声,使数据更真实 +y += np.random.normal(scale=0.5, size=y.shape) + +# 将x和y合并成一个二维数组 +data = np.column_stack((x, y)) + +# 将数据保存到CSV文件 +with open('data1.csv', 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + # 写入表头 + # writer.writerow(['x', 'y']) + # 写入数据 + writer.writerows(data) diff --git a/linear regression/plt_print.py b/linear regression/plt_print.py new file mode 100644 index 0000000..2e1d90c --- /dev/null +++ b/linear regression/plt_print.py @@ -0,0 +1,22 @@ +import numpy as np +import matplotlib.pyplot as plt + +# 原始数据 +points = np.genfromtxt("data1.csv", delimiter=',') +x = points[:, 0] +y = points[:, 1] + +# 拟合直线 +x_range = np.linspace(min(x), max(x), 100) +y_pred = 1.6455038785934448 * x_range + 1.827562689781189 + +# 绘图 +plt.figure(figsize=(8, 6)) +plt.scatter(x, y, color='blue', label='Original data') +plt.plot(x_range, y_pred, color='red', label='Fitted line') +plt.xlabel('X') +plt.ylabel('Y') +plt.title('Fitting a line to random data') +plt.legend() +plt.grid(True) +plt.savefig('print1.png') diff --git a/linear regression/print1.png b/linear regression/print1.png new file mode 100644 index 0000000..a008b51 Binary files /dev/null and b/linear regression/print1.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..11027f8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +pytorch::pytorch +torchvision +torchaudio +pandas +matplotlib +numpy +g4f \ No newline at end of file diff --git a/test/macTest.py b/test/macTest.py new file mode 100644 index 0000000..c4fc7d3 --- /dev/null +++ b/test/macTest.py @@ -0,0 +1,6 @@ +import torch + +print(torch.backends.mps.is_available()) +print(torch.backends.mps.is_built()) + +print(torch.device("mps")) \ No newline at end of file diff --git a/test/performance.py b/test/performance.py new file mode 100644 index 0000000..356a313 --- /dev/null +++ b/test/performance.py @@ -0,0 +1,29 @@ +import torch +import time + +print(torch.__version__) +print(torch.backends.mps.is_available()) +print(torch.cuda.is_available()) + +a = torch.randn(10000,1000) +b = torch.randn(1000,2000) + +t0 = time.time() +c = torch.matmul(a, b) +t1 = time.time() +print(a.device,t1-t0,c.norm(2)) + +device = torch.device('mps') + +a = a.to(device) +b = b.to(device) + +t0 = time.time() +c = torch.matmul(a, b) +t1 = time.time() +print(a.device,t1-t0,c.norm(2)) + +t0 = time.time() +c = torch.matmul(a, b) +t1 = time.time() +print(a.device,t1-t0,c.norm(2)) \ No newline at end of file