1
0
Files
AI-learning/lab/1_liner-regression-single.ipynb
2025-01-25 22:33:14 +08:00

182 lines
4.3 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# 引入库\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the current working directory (relative paths below resolve against it)\n",
"print(os.getcwd())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Generate synthetic data\n",
"def generate_data(w=1.35, b=2.89, x_min=0, x_max=10, num_points=100, noise_scale=0.5, seed=None):\n",
"    \"\"\"Create noisy samples of the line y = w*x + b.\n",
"\n",
"    Args:\n",
"        w, b: slope and intercept of the underlying line.\n",
"        x_min, x_max: inclusive range of the evenly spaced x values.\n",
"        num_points: number of samples to generate.\n",
"        noise_scale: standard deviation of the Gaussian noise added to y.\n",
"        seed: optional seed for reproducible noise. When None, the global\n",
"            np.random state is used, exactly as before.\n",
"\n",
"    Returns:\n",
"        Array of shape (num_points, 2): column 0 is x, column 1 is noisy y.\n",
"    \"\"\"\n",
"    # Both np.random and a Generator expose .normal(scale=..., size=...)\n",
"    rng = np.random if seed is None else np.random.default_rng(seed)\n",
"    x = np.linspace(x_min, x_max, num_points)\n",
"    y = w * x + b\n",
"    y += rng.normal(scale=noise_scale, size=y.shape)\n",
"    data = np.column_stack((x, y))\n",
"    return data\n",
"\n",
"# Save data\n",
"def save_data(filename, data):\n",
"    \"\"\"Write data to filename as comma-separated text, then print a confirmation.\"\"\"\n",
"    np.savetxt(fname=filename, X=data, delimiter=',')\n",
"    message = f\"{filename} 已成功创建并写入数据。\"\n",
"    print(message)\n",
"\n",
"# Generate the data set; uncomment the save_data call below to also write it to disk\n",
"data = generate_data()\n",
"#save_data('./1_data.txt', data)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Load the data set produced above\n",
"# (to read it back from disk instead: points = np.genfromtxt(\"./1_data.txt\", delimiter=','))\n",
"points = data\n",
"\n",
"# Column 0 holds the inputs x, column 1 the targets y\n",
"x, y = points[:, 0], points[:, 1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"损失函数: \n",
"$$J(w,b) = \\frac{1}{2m} \\sum_{i=1}^{m} (\\hat{y}_{w,b}(x^{(i)}) - y^{(i)})^2, \\quad \\hat{y}_{w,b}(x) = wx + b$$\n",
"\n",
"梯度下降:\n",
"\n",
"分别对w和b求偏导数然后更新w和b\n",
"$$\n",
"w = w - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial w}\n",
"$$\n",
"\n",
"$$\n",
"b = b - \\alpha\\cdot\\frac{\\partial J(w,b)}{\\partial b}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Loss function\n",
"def compute_loss(w,b):\n",
"    \"\"\"Return the cost J(w, b) = sum((y - (w*x + b))**2) / (2*m).\n",
"\n",
"    Reads the module-level arrays x and y; m is len(x).\n",
"    \"\"\"\n",
"    # The denominator must be parenthesised: `/2*len(x)` divides by 2 and then\n",
"    # multiplies by m, which inflates the loss by a factor of m**2.\n",
"    return np.sum((y - w*x - b)**2) / (2*len(x))\n",
"\n",
"# Loop-based version of the loss\n",
"def compute_loss_equivalent(w,b):\n",
"    \"\"\"Add up the squared residuals sample by sample and divide by 2*len(x).\n",
"\n",
"    Reads the module-level arrays x and y.\n",
"    \"\"\"\n",
"    squared_errors = ((y[i] - (w*x[i] + b))**2 for i in range(len(x)))\n",
"    return sum(squared_errors, 0) / (2*len(x))\n",
"\n",
"# Gradient descent\n",
"def gradient_descent(w,b,alpha,num_iter):\n",
"    \"\"\"Apply num_iter gradient-descent updates to (w, b) and return the final pair.\n",
"\n",
"    Reads the module-level arrays x and y; alpha is the step size.\n",
"    \"\"\"\n",
"    n_samples = len(x)\n",
"    for _step in range(num_iter):\n",
"        # Residuals of the current fit, shared by both partial derivatives\n",
"        residual = y - w*x - b\n",
"        grad_w = -np.sum(x*residual)/n_samples\n",
"        grad_b = -np.sum(residual)/n_samples\n",
"        # Update both parameters from the gradients computed above\n",
"        w, b = w - alpha*grad_w, b - alpha*grad_b\n",
"    return w,b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Entry point\n",
"if __name__ == \"__main__\":\n",
"    # Hyperparameters: learning rate and number of iterations\n",
"    alpha, num_iter = 0.01, 1000\n",
"    # Start from w = b = 0 and fit by gradient descent\n",
"    w, b = gradient_descent(0, 0, alpha, num_iter)\n",
"    print(\"w:\", w)\n",
"    print(\"b:\", b)\n",
"    # Loss at the fitted parameters\n",
"    loss = compute_loss(w, b)\n",
"    print(\"loss:\", loss)\n",
"\n",
"    plt.figure(dpi=600)\n",
"    #plt.switch_backend('Agg')  # optionally switch to the Agg renderer\n",
"    # Scatter the samples, then overlay the fitted line\n",
"    plt.scatter(x, y, color='blue', label='original data')\n",
"    plt.plot(x, w*x + b, color='red', label='regression line')\n",
"\n",
"    # Title and axis labels\n",
"    plt.title('linear regression')\n",
"    plt.xlabel('x')\n",
"    plt.ylabel('y')\n",
"\n",
"    # Legend, then render the figure\n",
"    plt.legend()\n",
"    plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}