1
0
Files
AI-learning/lab/2_liner-regression-multiply.ipynb
2024-12-30 18:49:59 +08:00

207 lines
5.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# 引入库\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"from sklearn.preprocessing import StandardScaler"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/wolves/Downloads/project/python/pt/lab\n"
]
}
],
"source": [
"# Show the working directory that relative paths such as './1_data.txt' resolve against\n",
"cwd = os.getcwd()\n",
"print(cwd)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
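"生成数据的方式：在 $x \\in [1, 8]$ 上均匀取 10 个点，按下式计算并加上高斯噪声 $\\varepsilon \\sim \\mathcal{N}(0, 0.5^2)$：\n",
"$$y = 1.35x + 0.75x^2 + 2.1\\sqrt{x} + 2.89 + \\varepsilon$$\n",
"最后用 StandardScaler 把三个特征列和 $y$ 一起标准化（均值 0、方差 1），因此拟合出的 $w$、$b$ 是标准化尺度下的参数，不会等于上式中的系数。"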
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Generate a synthetic regression dataset\n",
"def generate_data(n_samples=10, x_min=1, x_max=8, noise_scale=0.5, seed=None):\n",
"    \"\"\"Sample y = 1.35*x + 0.75*x**2 + 2.1*sqrt(x) + 2.89 plus Gaussian noise.\n",
"\n",
"    Args:\n",
"        n_samples: number of evenly spaced x values in [x_min, x_max].\n",
"        x_min, x_max: range of x; x_min should be >= 0 because the sqrt(x)\n",
"            feature is nan for negative x.\n",
"        noise_scale: standard deviation of the additive Gaussian noise on y.\n",
"        seed: seed for numpy's random Generator; pass an int for reproducible\n",
"            data, None draws fresh noise on every call.\n",
"\n",
"    Returns:\n",
"        (n_samples, 4) array with columns [x, x**2, sqrt(x), y], every column\n",
"        standardized to zero mean and unit variance by StandardScaler.\n",
"    \"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"    w = np.array([1.35, 0.75, 2.1])  # true weights\n",
"    b = 2.89  # true bias\n",
"    x = np.linspace(x_min, x_max, n_samples)  # evenly spaced inputs\n",
"    X = np.array([x, x**2, np.sqrt(x)])  # feature matrix, shape (3, n_samples)\n",
"    y = np.dot(w, X) + b  # 1-D targets, shape (n_samples,)\n",
"    y += rng.normal(scale=noise_scale, size=y.shape)\n",
"    data = np.column_stack((X.T, y))  # shape (n_samples, 4)\n",
"    scaler = StandardScaler()\n",
"    return scaler.fit_transform(data)\n",
"\n",
"# Save data as a comma-separated text file\n",
"def save_data(filename, data):\n",
"    np.savetxt(filename, data, delimiter=',')\n",
"    print(f\"{filename} 已成功创建并写入数据。\")\n",
"\n",
"# Generate the data; the fixed seed makes Restart & Run All reproduce the results below\n",
"data = generate_data(seed=42)\n",
"#save_data('./1_data.txt', data)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load the data (uncomment the next line to read it back from disk instead)\n",
"#points = np.genfromtxt(\"./1_data.txt\", delimiter=',')\n",
"\n",
"points = data\n",
"\n",
"m = points.shape[0]  # number of samples\n",
"x, y = points[:, :3], points[:, 3]  # features, shape (m, 3); targets, 1-D shape (m,)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"模型函数（假设函数）：\n",
"$$\n",
"\\vec{w} = {\\begin{bmatrix} w_1 & w_2 & w_3 & \\cdots & w_n \\end{bmatrix}}^T\n",
"$$\n",
"\n",
"$$\n",
"\\vec{x} = \\begin{bmatrix} x_1 & x_2 & x_3 & \\cdots & x_n \\end{bmatrix}\n",
"$$\n",
"\n",
"$$\n",
"f_{\\vec{w},b}(\\vec{x}) = \\vec{w} \\cdot \\vec{x} + b\n",
"$$\n",
"\n",
"损失函数（均方误差的一半；系数 $\\frac{1}{2}$ 用来抵消求导时产生的 2）：\n",
"\n",
"$$\n",
"J(\\vec{w}, b) = \\frac{1}{2m} \\sum_{i=1}^{m} \\left( y^{(i)} - \\hat{y}^{(i)} \\right)^2, \\qquad \\hat{y}^{(i)} = f_{\\vec{w},b}(\\vec{x}^{(i)})\n",
"$$\n",
"\n",
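"梯度下降：\n",
"\n",
"记损失函数为 $J(\\vec{w}, b)$，$X$ 为 $m \\times n$ 的特征矩阵。分别对每个 $w_j$ 和 $b$ 求偏导数，然后用学习率 $\\alpha$ 同时更新 $\\vec{w}$ 和 $b$：\n",
"\n",
"$$\n",
"\\frac{\\partial J}{\\partial \\vec{w}} = -\\frac{1}{m} X^T \\left( \\vec{y} - \\hat{\\vec{y}} \\right), \\qquad \\frac{\\partial J}{\\partial b} = -\\frac{1}{m} \\sum_{i=1}^{m} \\left( y^{(i)} - \\hat{y}^{(i)} \\right)\n",
"$$\n",
"\n",
"$$\n",
"\\vec{w} := \\vec{w} - \\alpha \\frac{\\partial J}{\\partial \\vec{w}}, \\qquad b := b - \\alpha \\frac{\\partial J}{\\partial b}\n",
"$$"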
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Linear model, loss function and batch gradient descent\n",
"def _predict(w, b, features):\n",
"    \"\"\"Return the model output features @ w + b, one value per row of features.\"\"\"\n",
"    return np.dot(features, w) + b\n",
"\n",
"def compute_loss(w, b, features=None, targets=None):\n",
"    \"\"\"Return J(w, b) = sum((targets - prediction)**2) / (2 * m).\n",
"\n",
"    features (m, n) and targets (m,) default to the notebook globals x and y,\n",
"    so the original two-argument call compute_loss(w, b) still works.\n",
"    \"\"\"\n",
"    features = x if features is None else np.asarray(features, dtype=float)\n",
"    targets = y if targets is None else np.asarray(targets, dtype=float)\n",
"    residual = targets - _predict(w, b, features)  # shape (m,)\n",
"    return np.sum(residual ** 2) / (2 * len(features))\n",
"\n",
"def gradient_descent(w, b, alpha, num_iter, features=None, targets=None):\n",
"    \"\"\"Run num_iter steps of batch gradient descent on J(w, b) and return (w, b).\n",
"\n",
"    Args:\n",
"        w: initial weights, length n (copied; the caller's array is not modified).\n",
"        b: initial bias.\n",
"        alpha: learning rate.\n",
"        num_iter: number of update steps.\n",
"        features, targets: training data (m, n) and (m,); default to the\n",
"            notebook globals x and y.\n",
"    \"\"\"\n",
"    features = x if features is None else np.asarray(features, dtype=float)\n",
"    targets = y if targets is None else np.asarray(targets, dtype=float)\n",
"    n_rows = len(features)\n",
"    # Work on a float copy: 'w -= ...' would otherwise overwrite the caller's\n",
"    # array in place, and it fails outright for integer arrays.\n",
"    w = np.array(w, dtype=float)\n",
"    for _ in range(num_iter):\n",
"        error = targets - _predict(w, b, features)  # shape (m,)\n",
"        # The gradient needs the matrix product X^T @ error, not an element-wise product\n",
"        dw = -np.dot(features.T, error) / n_rows  # shape (n,)\n",
"        db = -np.sum(error) / n_rows  # scalar\n",
"        w -= alpha * dw\n",
"        b -= alpha * db\n",
"    return w, b"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"w: [0.30933908 0.51500581 0.17832512]\n",
"b: -3.2582270215186813e-16\n",
"loss: 0.0027531700168465624\n"
]
}
],
"source": [
"# Train the model on the standardized data and report the fitted parameters\n",
"alpha = 0.01  # learning rate\n",
"num_iter = 1000  # number of gradient-descent steps\n",
"\n",
"# Start from all-zero weights (one per feature column) and zero bias\n",
"w = np.zeros(x.shape[1])\n",
"b = 0.0\n",
"\n",
"w, b = gradient_descent(w, b, alpha, num_iter)\n",
"print(\"w:\", w)\n",
"print(\"b:\", b)\n",
"\n",
"# Final training loss J(w, b)\n",
"loss = compute_loss(w, b)\n",
"print(\"loss:\", loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 编码中遇到的错误\n",
"\n",
"梯度下降算法中最初误将 `x.T` 与 `error` 做了逐元素相乘；正确做法是矩阵乘法 `np.dot(x.T, error)`（即 $X^T \\cdot \\text{error}$），这样得到的才是与 $\\vec{w}$ 同形状的梯度向量。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pt",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}