182 lines
4.3 KiB
Plaintext
182 lines
4.3 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 引入库\n",
|
||
"import numpy as np\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import os"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Print the current working directory\n",
"cwd = os.getcwd()\n",
"print(cwd)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Generate synthetic data\n",
"def generate_data(w=1.35, b=2.89, x_min=0, x_max=10, n_samples=100,\n",
"                  noise_scale=0.5, seed=None):\n",
"    \"\"\"Sample noisy points from the line y = w * x + b.\n",
"\n",
"    Args:\n",
"        w: Slope of the underlying line.\n",
"        b: Intercept of the underlying line.\n",
"        x_min: Smallest x value (inclusive).\n",
"        x_max: Largest x value (inclusive).\n",
"        n_samples: Number of evenly spaced x values to generate.\n",
"        noise_scale: Standard deviation of the Gaussian noise added to y.\n",
"        seed: Optional seed for a dedicated random generator. When None,\n",
"            NumPy's global random state is used.\n",
"\n",
"    Returns:\n",
"        An array of shape (n_samples, 2) whose columns are x and noisy y.\n",
"    \"\"\"\n",
"    x = np.linspace(x_min, x_max, n_samples)\n",
"    # Fall back to the global random state unless a seed is requested,\n",
"    # so calling generate_data() with no arguments behaves as it always did.\n",
"    rng = np.random if seed is None else np.random.default_rng(seed)\n",
"    noise = rng.normal(scale=noise_scale, size=x.shape)\n",
"    y = w * x + b + noise\n",
"    return np.column_stack((x, y))\n",
|
||
"\n",
|
||
"# Save data to disk\n",
"def save_data(filename, data):\n",
"    \"\"\"Write `data` to `filename` as comma-separated text and print a confirmation.\"\"\"\n",
"    np.savetxt(fname=filename, X=data, delimiter=',')\n",
"    message = f\"{filename} 已成功创建并写入数据。\"\n",
"    print(message)\n",
|
||
"\n",
|
||
"# Generate the data set. The global NumPy seed is fixed first so that\n",
"# Restart & Run All reproduces the same noisy samples (and fitted w, b).\n",
"np.random.seed(42)\n",
"data = generate_data()\n",
"# Optional: uncomment to also write the samples to disk.\n",
"#save_data('./1_data.txt', data)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Load the data\n",
"#points = np.genfromtxt(\"./1_data.txt\", delimiter=',')\n",
"\n",
"# Use the in-memory samples: column 0 is x, column 1 is y.\n",
"points = data\n",
"x, y = points[:, 0], points[:, 1]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"损失函数: \n",
|
||
"$$J(w,b) = \\frac{1}{2m} \\sum_{i=1}^{m} (y_{w,b}(x^{(i)}) - y^{(i)})^2$$\n",
|
||
"\n",
|
||
"梯度下降:\n",
|
||
"\n",
|
||
"分别对w和b求偏导数,然后更新w和b\n",
|
||
"$$\n",
|
||
"w = w - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial w}\n",
|
||
"$$\n",
|
||
"\n",
|
||
"$$\n",
|
||
"b = b - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial b}\n",
|
||
"$$"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Define the loss function\n",
"def compute_loss(w, b):\n",
"    \"\"\"Return the cost J(w, b) = 1/(2m) * sum((y - (w*x + b))**2).\n",
"\n",
"    Reads the module-level sample arrays `x` and `y`; m is len(x).\n",
"    \"\"\"\n",
"    m = len(x)\n",
"    # The denominator must be parenthesised: `/2*m` divides by 2 and then\n",
"    # multiplies by m, which is not the 1/(2m) factor of the cost.\n",
"    return np.sum((y - w * x - b) ** 2) / (2 * m)\n",
|
||
"\n",
|
||
"# Equivalent loop-based version\n",
"def compute_loss_equivalent(w, b):\n",
"    \"\"\"Compute 1/(2m) times the sum of squared residuals with an explicit loop.\n",
"\n",
"    Reads the module-level sample arrays `x` and `y`; m is len(x).\n",
"    \"\"\"\n",
"    total = 0\n",
"    for i, xi in enumerate(x):\n",
"        prediction = w * xi + b\n",
"        total += (y[i] - prediction) ** 2\n",
"    return total / (2 * len(x))\n",
|
||
"\n",
|
||
"# Define gradient descent\n",
"def gradient_descent(w, b, alpha, num_iter):\n",
"    \"\"\"Run `num_iter` batch gradient-descent steps on the cost J(w, b).\n",
"\n",
"    Reads the module-level sample arrays `x` and `y`.\n",
"\n",
"    Args:\n",
"        w: Initial slope.\n",
"        b: Initial intercept.\n",
"        alpha: Learning rate.\n",
"        num_iter: Number of update steps.\n",
"\n",
"    Returns:\n",
"        The tuple (w, b) after the final update.\n",
"    \"\"\"\n",
"    n_samples = len(x)\n",
"    for _step in range(num_iter):\n",
"        # Residuals of the current fit; both partial derivatives use them.\n",
"        residual = y - w * x - b\n",
"        grad_w = -np.sum(x * residual) / n_samples\n",
"        grad_b = -np.sum(residual) / n_samples\n",
"        # Update both parameters from the gradients computed above.\n",
"        w, b = w - alpha * grad_w, b - alpha * grad_b\n",
"    return w, b"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Main routine\n",
"if __name__ == \"__main__\":\n",
"    # Hyperparameters: learning rate and number of iterations\n",
"    alpha, num_iter = 0.01, 1000\n",
"\n",
"    # Start from w = b = 0 and fit with gradient descent\n",
"    w, b = gradient_descent(0, 0, alpha, num_iter)\n",
"    print(\"w:\", w)\n",
"    print(\"b:\", b)\n",
"\n",
"    # Evaluate the loss at the fitted parameters\n",
"    loss = compute_loss(w, b)\n",
"    print(\"loss:\", loss)\n",
"\n",
"    # Plot the samples together with the fitted line\n",
"    plt.figure(dpi=600)\n",
"    #plt.switch_backend('Agg')  # use the Agg renderer\n",
"    plt.scatter(x, y, color='blue', label='original data')\n",
"    fitted = w * x + b\n",
"    plt.plot(x, fitted, color='red', label='regression line')\n",
"\n",
"    # Title, axis labels and legend\n",
"    plt.title('linear regression')\n",
"    plt.xlabel('x')\n",
"    plt.ylabel('y')\n",
"    plt.legend()\n",
"\n",
"    # Render the figure\n",
"    plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "pt",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.14"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|