1
0
This commit is contained in:
2024-12-30 18:49:59 +08:00
parent 6361618c8e
commit 0d00e1deb4
2 changed files with 225 additions and 19 deletions

View File

@@ -36,26 +36,26 @@
"metadata": {},
"outputs": [],
"source": [
" # 生成数据\n",
" def generate_data():\n",
" w = 1.35\n",
" b = 2.89\n",
" x_min = 0\n",
" x_max = 10\n",
" x = np.linspace(x_min, x_max, 100)\n",
" y = w * x + b\n",
" y += np.random.normal(scale=0.5, size=y.shape)\n",
" data = np.column_stack((x, y))\n",
" return data\n",
"# 生成数据\n",
"def generate_data():\n",
" w = 1.35\n",
" b = 2.89\n",
" x_min = 0\n",
" x_max = 10\n",
" x = np.linspace(x_min, x_max, 100)\n",
" y = w * x + b\n",
" y += np.random.normal(scale=0.5, size=y.shape)\n",
" data = np.column_stack((x, y))\n",
" return data\n",
"\n",
" # 保存数据\n",
" def save_data(filename, data):\n",
" np.savetxt(filename, data, delimiter=',')\n",
" print(f\"{filename} 已成功创建并写入数据。\")\n",
"# 保存数据\n",
"def save_data(filename, data):\n",
" np.savetxt(filename, data, delimiter=',')\n",
" print(f\"{filename} 已成功创建并写入数据。\")\n",
"\n",
" # 生成并保存数据\n",
" data = generate_data()\n",
" #save_data('./1_data.txt', data)"
"# 生成并保存数据\n",
"data = generate_data()\n",
"#save_data('./1_data.txt', data)"
]
},
{
@@ -100,7 +100,7 @@
"source": [
"# 定义损失函数\n",
"def compute_loss(w,b):\n",
" return np.sum((y-w*x-b)**2)/(2*len(x))\n",
" return np.sum((y-w*x-b)**2)/2*len(x)\n",
"\n",
"# 等效\n",
"def compute_loss_equivalent(w,b):\n",