diff --git a/linear regression/1.md b/linear regression/1.md new file mode 100644 index 0000000..938f45d --- /dev/null +++ b/linear regression/1.md @@ -0,0 +1,18 @@ +# 1.py + +$$ +y = wx + b +$$ + +## 梯度下降算法 + +$$ +b_gradient += -\frac{2}{N} \left(y - (w_current \cdot x + b_current)\right) +$$ + + + +$$ +w_gradient += -\frac{2}{N} \cdot x \cdot \left(y - (w_current \cdot x + b_current)\right) +$$ + diff --git a/linear regression/1.py b/linear regression/1.py index dc455ef..a7402e6 100644 --- a/linear regression/1.py +++ b/linear regression/1.py @@ -19,8 +19,8 @@ def step_gradient(b_current, w_current, points, learningRate): for i in range(len(points)): x = points[i][0] y = points[i][1] - b_gradient += -(2 / N) * (y - (w_current * x + b_current) + b_current) - w_gradient += -(2 / N) * x * (y - (w_current * x + b_current + b_current)) + b_gradient += -(2 / N) * (y - (w_current * x + b_current)) + w_gradient += -(2 / N) * x * (y - (w_current * x + b_current)) new_b = b_current - (learningRate * b_gradient) new_w = w_current - (learningRate * w_gradient) return [new_b, new_w] diff --git a/linear regression/1_cuda.py b/linear regression/1_cuda.py new file mode 100644 index 0000000..be818aa --- /dev/null +++ b/linear regression/1_cuda.py @@ -0,0 +1,67 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + +# 线性回归训练代码 +def compute_error_for_line_given_points(b, w, points): + totalError = 0 + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + totalError += (y - (w * x + b)) ** 2 + return totalError / N + +def step_gradient(b_current, w_current, points, learningRate): + b_gradient = torch.tensor(0.0, device=points.device) + w_gradient = torch.tensor(0.0, device=points.device) + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + b_gradient += -(2 / N) * (y - (w_current * x + b_current)) + w_gradient += -(2 / N) * x * (y - (w_current * x + b_current)) + new_b = b_current - (learningRate * b_gradient) + new_w = w_current - (learningRate * w_gradient) + return [new_b, new_w] + +def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations): + b = torch.tensor(starting_b, device=points.device) + w = torch.tensor(starting_w, device=points.device) + for i in range(num_iterations): + b, w = step_gradient(b, w, points, learningRate) + return [b, w] + +def run(): + points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32) + points = torch.tensor(points_np, device='cuda') + learning_rate = 0.0001 + initial_b = 0.0 + initial_w = 0.0 + num_iterations = 100000 + [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations) + print("After gradient descent at b={0}, w={1}, error={2}".format(b.item(), w.item(), + compute_error_for_line_given_points(b, w, points))) + return b.item(), w.item() + +# 运行线性回归 +final_b, final_w = run() + +# 绘制图像 +points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32) +x = points_np[:, 0] +y = points_np[:, 1] + +x_range = np.linspace(min(x), max(x), 100) +y_pred = final_w * x_range + final_b + +plt.figure(figsize=(8, 6)) +plt.scatter(x, y, color='blue', label='Original data') +plt.plot(x_range, y_pred, color='red', label='Fitted line') +plt.xlabel('X') +plt.ylabel('Y') +plt.title('Fitting a line to random data') +plt.legend() +plt.grid(True) +plt.savefig('print1.png') +plt.show() diff --git a/linear regression/1x.py b/linear regression/1x.py new file mode 100644 index 0000000..0b1b93a --- /dev/null +++ b/linear regression/1x.py @@ -0,0 +1,72 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + + +# 线性回归训练代码 +def compute_error_for_line_given_points(b, w, points): + totalError = 0 + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + totalError += (y - (w * x + b)) ** 2 + return totalError / N + + +def step_gradient(b_current, w_current, points, learningRate): + b_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32) + w_gradient = torch.tensor(0.0, device=points.device, dtype=torch.float32) + N = float(len(points)) + for i in range(len(points)): + x = points[i][0] + y = points[i][1] + b_gradient += -(2 / N) * (y - (w_current * x + b_current)) + w_gradient += -(2 / N) * x * (y - (w_current * x + b_current)) + new_b = b_current - (learningRate * b_gradient) + new_w = w_current - (learningRate * w_gradient) + return [new_b, new_w] + + +def gradient_descent_runner(points, starting_b, starting_w, learningRate, num_iterations): + b = torch.tensor(starting_b, device=points.device, dtype=torch.float32) + w = torch.tensor(starting_w, device=points.device, dtype=torch.float32) + for i in range(num_iterations): + b, w = step_gradient(b, w, points, learningRate) + return [b, w] + + +def run(): + points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32) + points = torch.tensor(points_np, device='mps') + learning_rate = 0.0001 + initial_b = 0.0 + initial_w = 0.0 + num_iterations = 100000 + [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations) + print("After gradient descent at b={0},w={1},error={2}".format(b.item(), w.item(), + compute_error_for_line_given_points(b, w, points))) + return b.item(), w.item() + + +# 运行线性回归 +final_b, final_w = run() + +# 绘制图像 +points_np = np.genfromtxt("data1.csv", delimiter=',').astype(np.float32) +x = points_np[:, 0] +y = points_np[:, 1] + +x_range = np.linspace(min(x), max(x), 100) +y_pred = final_w * x_range + final_b + +plt.figure(figsize=(8, 6)) +plt.scatter(x, y, color='blue', label='Original data') +plt.plot(x_range, y_pred, color='red', label='Fitted line') +plt.xlabel('X') +plt.ylabel('Y') +plt.title('Fitting a line to random data') +plt.legend() +plt.grid(True) +plt.savefig('print1.png') +plt.show() diff --git a/linear regression/data1.csv b/linear regression/data1.csv index 338b71d..508fd99 100644 --- a/linear regression/data1.csv +++ b/linear regression/data1.csv @@ -1,100 +1,100 @@ -0.0,2.4360562173289115 -0.10101010101010101,2.4288710820592065 -0.20202020202020202,3.677943201375977 -0.30303030303030304,3.5029234515863217 -0.40404040404040403,4.007715980839878 -0.5050505050505051,3.95999321461469 -0.6060606060606061,3.220853916066527 -0.7070707070707071,3.2211460206798623 -0.8080808080808081,4.2516957270374505 -0.9090909090909091,4.311826715084292 -1.0101010101010102,4.153966583608258 -1.1111111111111112,4.224290328721461 -1.2121212121212122,4.551324602105953 -1.3131313131313131,5.157200101408408 -1.4141414141414141,5.199011258508288 -1.5151515151515151,5.248911218901843 -1.6161616161616161,5.789628423512512 -1.7171717171717171,5.126592322934872 -1.8181818181818181,4.546631494636344 -1.9191919191919191,5.7260434379514065 -2.0202020202020203,5.607446671816119 -2.121212121212121,5.401744626671172 -2.2222222222222223,5.568078510495838 -2.323232323232323,6.136817713051054 -2.4242424242424243,5.399802896696589 -2.525252525252525,6.7465591899811415 -2.6262626262626263,6.510002771256968 -2.727272727272727,6.194107987238278 -2.8282828282828283,6.280445605022811 -2.929292929292929,6.413289184504817 -3.0303030303030303,8.178951965980268 -3.131313131313131,7.438933818741419 -3.2323232323232323,8.161193108124682 -3.3333333333333335,6.466487953447159 -3.4343434343434343,7.6815673373443385 -3.5353535353535355,7.412509123916619 -3.6363636363636362,7.712231039046388 -3.7373737373737375,7.512155302443977 -3.8383838383838382,8.169468174953455 -3.9393939393939394,8.201406255891817 -4.040404040404041,9.413915839209679 -4.141414141414141,7.2131607261403 -4.242424242424242,8.244196707034996 -4.343434343434343,8.059400613529792 -4.444444444444445,9.127093042087843 -4.545454545454545,8.232456814994503 -4.646464646464646,9.026988954051767 -4.747474747474747,8.936405824368308 -4.848484848484849,8.838334259675397 -4.94949494949495,9.717080564295035 -5.05050505050505,9.635892324495916 -5.151515151515151,10.802758752616178 -5.252525252525253,9.889268431487773 -5.353535353535354,9.262021983987134 -5.454545454545454,9.905732041295009 -5.555555555555555,9.697006564677089 -5.656565656565657,10.435437946557755 -5.757575757575758,10.257651825530608 -5.858585858585858,11.394734709569004 -5.959595959595959,10.872621683473387 -6.0606060606060606,10.750944058491058 -6.161616161616162,11.375400587831757 -6.262626262626262,11.834436555701465 -6.363636363636363,11.536088544119654 -6.4646464646464645,11.261555999325722 -6.565656565656566,12.529961808490153 -6.666666666666667,12.19345219105891 -6.767676767676767,11.950653180245155 -6.8686868686868685,12.176773142948385 -6.96969696969697,12.055083206520518 -7.070707070707071,13.498633194384489 -7.171717171717171,12.542518727882712 -7.2727272727272725,13.318372269865769 -7.373737373737374,12.542630883166513 -7.474747474747475,12.93490122675753 -7.575757575757575,14.4040220344926 -7.6767676767676765,13.314367294113964 -7.777777777777778,14.061236496574551 -7.878787878787879,12.686346979737731 -7.979797979797979,14.024375221983842 -8.080808080808081,13.7042096336008 -8.181818181818182,13.342730021126272 -8.282828282828282,14.136548357864573 -8.383838383838384,14.619569834949138 -8.484848484848484,14.01453898823226 -8.585858585858587,15.154877807203663 -8.686868686868687,14.081910297898048 -8.787878787878787,14.474564310016353 -8.88888888888889,14.966525346412723 -8.98989898989899,15.526107019435932 -9.09090909090909,14.352357736719853 -9.191919191919192,15.843742065895144 -9.292929292929292,15.787083172159111 -9.393939393939394,15.211828607109144 -9.494949494949495,15.845176532492374 -9.595959595959595,15.622518083688107 -9.696969696969697,15.589081237426006 -9.797979797979798,15.511248085690712 -9.8989898989899,16.27050774059862 -10.0,16.1105549896166 +0.0,3.267264598063918 +0.10101010101010101,3.3715980311448757 +0.20202020202020202,3.624529195538264 +0.30303030303030304,2.2426026435865785 +0.40404040404040403,2.881354128303834 +0.5050505050505051,4.108915299601898 +0.6060606060606061,3.2833072841616344 +0.7070707070707071,3.401314666490608 +0.8080808080808081,3.4471224977820083 +0.9090909090909091,4.597483332850038 +1.0101010101010102,4.1948230194917615 +1.1111111111111112,4.770110614428998 +1.2121212121212122,4.3466984672473545 +1.3131313131313131,4.085374736788284 +1.4141414141414141,4.860667770156386 +1.5151515151515151,5.367460741298345 +1.6161616161616161,5.1076464505556585 +1.7171717171717171,4.517380143483942 +1.8181818181818181,6.028333717306668 +1.9191919191919191,5.268642728341781 +2.0202020202020203,5.2032646463511885 +2.121212121212121,5.776924577040542 +2.2222222222222223,5.914239664440679 +2.323232323232323,6.195294604152318 +2.4242424242424243,6.67461745554651 +2.525252525252525,6.62895898059055 +2.6262626262626263,6.423496434474387 +2.727272727272727,6.520626001853953 +2.8282828282828283,6.252851138402289 +2.929292929292929,7.045662416151556 +3.0303030303030303,7.062687254815803 +3.131313131313131,6.950015155958233 +3.2323232323232323,7.71420449451215 +3.3333333333333335,7.536987534120887 +3.4343434343434343,8.408446293914908 +3.5353535353535355,8.281116903817127 +3.6363636363636362,6.862064335470844 +3.7373737373737375,8.455114086555362 +3.8383838383838382,8.610569256326439 +3.9393939393939394,8.695172603505283 +4.040404040404041,7.987174933011048 +4.141414141414141,8.484042583282307 +4.242424242424242,8.152218590549857 +4.343434343434343,8.810112829362456 +4.444444444444445,9.098520210970904 +4.545454545454545,9.315991463976044 +4.646464646464646,9.266562387635002 +4.747474747474747,8.457655714255173 +4.848484848484849,8.577190426286784 +4.94949494949495,9.992687218959654 +5.05050505050505,9.949888900251127 +5.151515151515151,10.112370557219064 +5.252525252525253,10.250084050804231 +5.353535353535354,9.675169646286898 +5.454545454545454,9.790565255890696 +5.555555555555555,9.91666488079517 +5.656565656565657,10.325538746448835 +5.757575757575758,9.77548528051785 +5.858585858585858,10.55371401462777 +5.959595959595959,10.757696722894282 +6.0606060606060606,10.893354131765726 +6.161616161616162,12.049342074375708 +6.262626262626262,10.936118426966079 +6.363636363636363,11.031578580287063 +6.4646464646464645,11.713471927909302 +6.565656565656566,12.343664117608101 +6.666666666666667,12.067856620638729 +6.767676767676767,11.814430199711552 +6.8686868686868685,11.123516182999314 +6.96969696969697,12.496962644202316 +7.070707070707071,12.767487737755147 +7.171717171717171,12.632934104476211 +7.2727272727272725,12.728225932468364 +7.373737373737374,12.97630338885533 +7.474747474747475,12.896220727223701 +7.575757575757575,13.047808359849581 +7.6767676767676765,13.110443597152527 +7.777777777777778,12.921358752181128 +7.878787878787879,13.30038615173782 +7.979797979797979,13.836945395153705 +8.080808080808081,13.054484897082014 +8.181818181818182,14.01038452336861 +8.282828282828282,13.643336312636018 +8.383838383838384,14.564671365817466 +8.484848484848484,14.040540515755758 +8.585858585858587,14.57992522742261 +8.686868686868687,14.88631275019171 +8.787878787878787,14.021963220008606 +8.88888888888889,15.068155050128949 +8.98989898989899,15.083538874549268 +9.09090909090909,15.417748308319911 +9.191919191919192,14.89714205401168 +9.292929292929292,14.534676091762206 +9.393939393939394,15.556467883324295 +9.494949494949495,15.525938847099864 +9.595959595959595,15.560767751324764 +9.696969696969697,15.982790914773943 +9.797979797979798,16.062079721169738 +9.8989898989899,16.232818049890696 +10.0,17.053472736980353 diff --git a/linear regression/plt_print.py b/linear regression/plt_print.py index 2e1d90c..245fe99 100644 --- a/linear regression/plt_print.py +++ b/linear regression/plt_print.py @@ -3,12 +3,13 @@ import matplotlib.pyplot as plt # 原始数据 points = np.genfromtxt("data1.csv", delimiter=',') + x = points[:, 0] y = points[:, 1] # 拟合直线 x_range = np.linspace(min(x), max(x), 100) -y_pred = 1.6455038785934448 * x_range + 1.827562689781189 +y_pred = 0.3880246877670288 * x_range + 1.7214288711547852 # 绘图 plt.figure(figsize=(8, 6)) diff --git a/linear regression/print1.png b/linear regression/print1.png index a008b51..c2933e1 100644 Binary files a/linear regression/print1.png and b/linear regression/print1.png differ