记得刚开始学习机器学习时,最让我头疼的就是那些抽象的数学概念。特别是"凸函数"和"凸集"这两个词,在教材里反复出现,但光看定义总觉得隔靴搔痒。直到有一天,我尝试用Python把这些概念画出来,一切突然变得清晰可见。这篇文章就是要把这种"可视化学习法"分享给你,让你也能通过代码亲手触摸这些数学概念的形状。
在开始之前,我们需要准备好Python环境和必要的库。推荐使用Anaconda创建虚拟环境,这样可以避免库版本冲突的问题:
bash复制conda create -n convex_env python=3.8
conda activate convex_env
pip install numpy matplotlib ipykernel
核心工具包的功能简介:
提示:如果你使用Jupyter Notebook,可以添加
%matplotlib inline魔法命令,让图表直接显示在笔记本中。
让我们先定义一个简单的绘图函数,后续可以重复使用:
python复制import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def plot_function(f, x_range=(-5,5), y_range=(-5,5), step=0.1):
"""绘制二元函数的三维曲面"""
x = np.arange(x_range[0], x_range[1], step)
y = np.arange(y_range[0], y_range[1], step)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
fig = plt.figure(figsize=(12,6))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
fig.colorbar(surf)
plt.title(f"Function: {f.__name__}")
plt.show()
数学定义说凸集是"集合内任意两点连线上的点都在集合内",这听起来很抽象。让我们用代码来具象化这个概念。
先创建几个常见的凸集:
python复制def plot_convex_sets():
# 圆形(凸集)
circle = plt.Circle((0,0), 2, fill=False, edgecolor='blue', linewidth=2)
# 矩形(凸集)
rectangle = plt.Rectangle((-1,-1), 2, 2, fill=False, edgecolor='green', linewidth=2)
# 非凸集示例(月牙形)
theta = np.linspace(0, 2*np.pi, 100)
moon_x = 2*np.cos(theta)
moon_y = 2*np.sin(theta) + np.where(theta < np.pi, 0.5, -0.5)
fig, ax = plt.subplots(figsize=(10,5))
ax.add_patch(circle)
ax.add_patch(rectangle)
ax.plot(moon_x, moon_y, 'r-', linewidth=2)
ax.set_xlim(-3,3)
ax.set_ylim(-3,3)
ax.set_aspect('equal')
ax.grid(True)
plt.title("凸集(蓝圆/绿方) vs 非凸集(红月牙)")
plt.show()
plot_convex_sets()
编写一个函数来验证给定点集是否为凸集:
python复制def is_convex_set(points, tolerance=1e-6):
"""
验证点集是否为凸集
:param points: numpy数组,形状为(n,2)
:param tolerance: 浮点误差容忍度
:return: bool
"""
n = len(points)
if n < 3: # 少于3个点自动视为凸集
return True
for i in range(n):
for j in range(i+1, n):
# 对每对点,检查线段上的点是否都在集合内
lambda_vals = np.linspace(0, 1, 20)
for lam in lambda_vals:
point_on_segment = lam*points[i] + (1-lam)*points[j]
# 找到集合中最近的点
distances = np.linalg.norm(points - point_on_segment, axis=1)
if np.min(distances) > tolerance:
return False
return True
测试我们的函数:
python复制# 凸集测试(单位圆上的点)
theta = np.linspace(0, 2*np.pi, 20)
circle_points = np.column_stack([np.cos(theta), np.sin(theta)])
print(f"圆形点集是否为凸集: {is_convex_set(circle_points)}") # 应该返回True
# 非凸集测试(月牙形点集)
moon_points = np.column_stack([2*np.cos(theta), 2*np.sin(theta) + np.where(theta < np.pi, 0.5, -0.5)])
print(f"月牙形点集是否为凸集: {is_convex_set(moon_points)}") # 应该返回False
凸函数的定义是"函数图像上任意两点间的线段都在图像上方"。让我们用Python来实现这个概念的验证。
定义几个典型的凸函数:
python复制# 二次函数
def quadratic(x, y):
return x**2 + y**2
# 指数函数
def exponential(x, y):
return np.exp(x) + np.exp(y)
# 负对数函数
def neg_log(x, y):
return -np.log(x+3) - np.log(y+3) # +3保证定义域内有效
绘制这些函数:
python复制plot_function(quadratic)
plot_function(exponential)
plot_function(neg_log, x_range=(0.1,5), y_range=(0.1,5))
实现一个验证函数凸性的工具:
python复制def is_convex_function(f, x_range=(-1,1), y_range=(-1,1), num_points=10):
"""
通过采样验证二元函数在给定范围内的凸性
"""
x_samples = np.linspace(x_range[0], x_range[1], num_points)
y_samples = np.linspace(y_range[0], y_range[1], num_points)
X, Y = np.meshgrid(x_samples, y_samples)
for i in range(num_points-1):
for j in range(num_points-1):
# 取四个相邻点形成两个三角形
p1 = np.array([X[i,j], Y[i,j]])
p2 = np.array([X[i+1,j], Y[i+1,j]])
p3 = np.array([X[i,j+1], Y[i,j+1]])
# 检查凸性条件
for lam in np.linspace(0, 1, 5):
# 第一个三角形
p = lam*p1 + (1-lam)*p2
z_actual = f(p[0], p[1])
z_interp = lam*f(p1[0],p1[1]) + (1-lam)*f(p2[0],p2[1])
if z_actual > z_interp + 1e-6:
return False
# 第二个三角形
p = lam*p1 + (1-lam)*p3
z_actual = f(p[0], p[1])
z_interp = lam*f(p1[0],p1[1]) + (1-lam)*f(p3[0],p3[1])
if z_actual > z_interp + 1e-6:
return False
return True
测试我们的验证函数:
python复制print(f"二次函数是否为凸函数: {is_convex_function(quadratic)}")
print(f"指数函数是否为凸函数: {is_convex_function(exponential)}")
print(f"负对数函数是否为凸函数: {is_convex_function(neg_log, x_range=(0.1,5), y_range=(0.1,5))}")
# 测试一个非凸函数
def non_convex(x, y):
return np.sin(x) + np.cos(y)
print(f"正弦余弦函数是否为凸函数: {is_convex_function(non_convex)}")
虽然可视化方法直观,但在高维空间或复杂函数中,我们需要更可靠的数学工具。Hessian矩阵就是这样的工具。
实现一个数值计算Hessian矩阵的函数:
python复制def numerical_hessian(f, x, y, epsilon=1e-5):
"""
数值计算二元函数在某点的Hessian矩阵
"""
# 一阶导数
fx = (f(x+epsilon, y) - f(x-epsilon, y))/(2*epsilon)
fy = (f(x, y+epsilon) - f(x, y-epsilon))/(2*epsilon)
# 二阶导数
fxx = (f(x+epsilon, y) - 2*f(x,y) + f(x-epsilon, y))/(epsilon**2)
fyy = (f(x, y+epsilon) - 2*f(x,y) + f(x, y-epsilon))/(epsilon**2)
fxy = (f(x+epsilon, y+epsilon) - f(x+epsilon, y-epsilon) -
f(x-epsilon, y+epsilon) + f(x-epsilon, y-epsilon))/(4*epsilon**2)
return np.array([[fxx, fxy], [fxy, fyy]])
实现判断矩阵是否半正定的函数:
python复制def is_positive_semidefinite(matrix):
"""
判断矩阵是否半正定
"""
eigenvalues = np.linalg.eigvals(matrix)
return np.all(eigenvalues >= -1e-8) # 考虑数值误差
结合Hessian矩阵和正定性判断的函数凸性:
python复制def is_convex_by_hessian(f, x_range=(-1,1), y_range=(-1,1), num_points=5):
"""
通过采样点Hessian矩阵验证函数凸性
"""
x_samples = np.linspace(x_range[0], x_range[1], num_points)
y_samples = np.linspace(y_range[0], y_range[1], num_points)
for x in x_samples:
for y in y_samples:
hessian = numerical_hessian(f, x, y)
if not is_positive_semidefinite(hessian):
return False
return True
比较两种验证方法的结果:
python复制test_functions = [
("二次函数", quadratic),
("指数函数", exponential),
("负对数函数", lambda x,y: neg_log(x,y)),
("非凸函数", non_convex)
]
for name, func in test_functions:
visual_result = is_convex_function(func)
hessian_result = is_convex_by_hessian(func)
print(f"{name}: 可视化验证={visual_result}, Hessian验证={hessian_result}")
理解了凸函数和凸集的概念后,让我们看看它们在优化问题中的应用。
考虑以下凸优化问题:
最小化 f(x,y) = x² + y²
约束条件:x + y ≥ 1
用SciPy的优化工具求解:
python复制from scipy.optimize import minimize
# 定义目标函数和约束
def objective(x):
return x[0]**2 + x[1]**2
def constraint(x):
return x[0] + x[1] - 1 # x + y ≥ 1 → x + y -1 ≥ 0
# 设置优化问题
cons = {'type': 'ineq', 'fun': constraint}
x0 = [0, 0] # 初始猜测
# 求解
solution = minimize(objective, x0, constraints=cons)
print(f"最优解: x={solution.x[0]:.4f}, y={solution.x[1]:.4f}")
print(f"最优值: {solution.fun:.4f}")
让我们绘制优化问题的图形表示:
python复制def plot_optimization_problem():
x = np.linspace(-1, 2, 100)
y = np.linspace(-1, 2, 100)
X, Y = np.meshgrid(x, y)
Z = X**2 + Y**2
plt.figure(figsize=(10,6))
# 绘制等高线
contours = plt.contour(X, Y, Z, levels=20, cmap='viridis')
plt.colorbar(contours)
# 绘制约束条件
plt.plot(x, 1 - x, 'r-', label='约束: x + y ≥ 1')
# 标记最优解
plt.plot(0.5, 0.5, 'ro', markersize=10, label='最优解')
plt.xlabel('x')
plt.ylabel('y')
plt.title('凸优化问题可视化')
plt.legend()
plt.grid(True)
plt.show()
plot_optimization_problem()
考虑一个非凸优化问题作为对比:
最小化 f(x,y) = sin(x) + cos(y) + (x² + y²)/10
python复制def non_convex_objective(x):
return np.sin(x[0]) + np.cos(x[1]) + (x[0]**2 + x[1]**2)/10
# 从不同初始点出发,观察结果
initial_points = [[0,0], [1,1], [-1,-1], [2,2], [-2,-2]]
for x0 in initial_points:
solution = minimize(non_convex_objective, x0)
print(f"初始点{x0} → 解:{solution.x}, 值:{solution.fun:.4f}")
这个例子展示了非凸函数可能存在的多个局部最优解,具体结果取决于初始点选择。