Последняя активность: 1 неделю назад

ns.py Исходник
1import flax
2import deepxde as dde
3import jax.numpy as np
4import matplotlib.pyplot as plt
5
# Fluid properties and channel geometry (units presumably non-dimensional -- confirm).
rho, mu = 1, 1  # density and dynamic viscosity, as used in the momentum residuals
u_in = 1        # uniform inlet velocity imposed on the u component
D, L = 1, 2     # channel height (y-extent) and length (x-extent)

# Rectangular channel spanning x in [-L/2, L/2] and y in [-D/2, D/2].
geom = dde.geometry.Rectangle(xmin=[-L / 2, -D / 2], xmax=[L / 2, D / 2])
def boundary_wall(X, on_boundary):
    """Select boundary points lying on the top or bottom wall (y = +/-D/2).

    Args:
        X: a single point; ``X[1]`` is its y coordinate.
        on_boundary: flag supplied by deepxde marking whether X is on the
            geometry boundary (presumably a bool -- confirm).

    Returns:
        Boolean mask (a jax.numpy bool) that is true only for wall points.
    """
    y_coord = X[1]
    hits_bottom = np.isclose(y_coord, -D / 2, rtol=1e-05, atol=1e-08)
    hits_top = np.isclose(y_coord, D / 2, rtol=1e-05, atol=1e-08)
    return np.logical_and(np.logical_or(hits_bottom, hits_top), on_boundary)
22
def boundary_inlet(X, on_boundary):
    """Select boundary points on the inlet face x = -L/2.

    Args:
        X: a single point; ``X[0]`` is its x coordinate.
        on_boundary: deepxde's on-boundary flag for X.

    Returns:
        Boolean mask (a jax.numpy bool) that is true only for inlet points.
    """
    at_left_face = np.isclose(X[0], -L / 2, rtol=1e-05, atol=1e-08)
    return np.logical_and(at_left_face, on_boundary)
28
def boundary_outlet(X, on_boundary):
    """Select boundary points on the outlet face x = +L/2.

    Args:
        X: a single point; ``X[0]`` is its x coordinate.
        on_boundary: deepxde's on-boundary flag for X.

    Returns:
        Boolean mask (a jax.numpy bool) that is true only for outlet points.
    """
    at_right_face = np.isclose(X[0], L / 2, rtol=1e-05, atol=1e-08)
    return np.logical_and(at_right_face, on_boundary)
34
# Output components: 0 = u, 1 = v, 2 = p.
# No-slip walls: u = 0 and v = 0 on y = +/-D/2.
bc_wall_u = dde.DirichletBC(geom, lambda x: 0.0, boundary_wall, component=0)
bc_wall_v = dde.DirichletBC(geom, lambda x: 0.0, boundary_wall, component=1)

# Inlet: uniform streamwise velocity u = u_in, no cross-flow (v = 0).
bc_inlet_u = dde.DirichletBC(geom, lambda x: u_in, boundary_inlet, component=0)
bc_inlet_v = dde.DirichletBC(geom, lambda x: 0.0, boundary_inlet, component=1)

# Outlet: zero pressure and no cross-flow (v = 0).
# NOTE(review): the inlet and wall conditions disagree on u at the two inlet
# corners (u_in vs 0); the network can only satisfy them approximately there.
bc_outlet_p = dde.DirichletBC(geom, lambda x: 0.0, boundary_outlet, component=2)
bc_outlet_v = dde.DirichletBC(geom, lambda x: 0.0, boundary_outlet, component=1)
43
def pde(X, Y):
    """Residuals of the steady incompressible Navier-Stokes equations.

    Args:
        X: collocation coordinates; column 0 is x, column 1 is y.
        Y: network output handle. It is unpacked below as ``(values, _)``;
            columns of ``values`` are (u, v, p). Presumably this is the
            ``(array, function)`` pair deepxde's JAX backend passes -- confirm
            against the installed deepxde version.

    Returns:
        ``[x-momentum residual, y-momentum residual, continuity residual]``.
    """

    def first_derivative(component, axis):
        # dde.grad.jacobian returns a (value, function) pair here; keep the value.
        value, _ = dde.grad.jacobian(Y, X, i=component, j=axis)
        return value

    def second_derivative(component, axis):
        # Pure second derivative d^2(component) / d(axis)^2.
        value, _ = dde.grad.hessian(Y, X, component=component, i=axis, j=axis)
        return value

    du_dx, du_dy = first_derivative(0, 0), first_derivative(0, 1)
    dv_dx, dv_dy = first_derivative(1, 0), first_derivative(1, 1)
    dp_dx, dp_dy = first_derivative(2, 0), first_derivative(2, 1)

    d2u_dx2, d2u_dy2 = second_derivative(0, 0), second_derivative(0, 1)
    d2v_dx2, d2v_dy2 = second_derivative(1, 0), second_derivative(1, 1)

    values, _ = Y
    u = values[:, 0:1]
    v = values[:, 1:2]

    inv_rho = 1 / rho
    nu = mu / rho  # kinematic viscosity

    # Convection + pressure gradient - viscous diffusion, per momentum component.
    momentum_x = u * du_dx + v * du_dy + inv_rho * dp_dx - nu * (d2u_dx2 + d2u_dy2)
    momentum_y = u * dv_dx + v * dv_dy + inv_rho * dp_dy - nu * (d2v_dx2 + d2v_dy2)
    # Incompressibility: div(u) = 0.
    continuity = du_dx + dv_dy

    return [momentum_x, momentum_y, continuity]
68
# PDE problem: 2000 interior and 200 boundary training points, 200 test points.
data = dde.data.PDE(
    geom,
    pde,
    [
        bc_wall_u, bc_wall_v,      # no-slip walls
        bc_inlet_u, bc_inlet_v,    # inlet velocity profile
        bc_outlet_p, bc_outlet_v,  # outlet pressure and cross-flow
    ],
    num_domain=2000,
    num_boundary=200,
    num_test=200,
)

# 2 inputs (x, y) -> five tanh hidden layers of width 64 -> 3 outputs (u, v, p).
net = dde.maps.FNN([2, *([64] * 5), 3], "tanh", "Glorot uniform")
model = dde.Model(data, net)

# NOTE(review): recent deepxde releases expose ``dde.nn.FNN`` and
# ``model.train(iterations=...)``; ``dde.maps`` / ``epochs=`` are kept here as in
# the original -- confirm they are still accepted by the installed version.
model.compile("adam", lr=1e-3)
losshistory, train_state = model.train(epochs=10000)
83
# Show where the training points (interior + boundary) were sampled.
plt.figure(figsize=(10, 8))
plt.scatter(data.train_x_all[:, 0], data.train_x_all[:, 1], s=0.5)
plt.xlabel("x")
plt.ylabel("y")
plt.show()

# Evaluate the trained network on a dense random cloud of points from geom.
samples = geom.random_points(500000)
result = model.predict(samples)

# Fixed colour ranges for the u, v and p fields, in output-column order.
color_legend = [[0, 1.5], [-0.3, 0.3], [0, 35]]
for idx, limits in enumerate(color_legend):
    plt.figure(figsize=(20, 4))
    plt.scatter(samples[:, 0], samples[:, 1], c=result[:, idx], cmap="jet", s=2)
    plt.colorbar()
    plt.clim(limits)
    # Axis limits match the channel extent used to build geom.
    plt.xlim((-L / 2, L / 2))
    plt.ylim((-D / 2, D / 2))
    plt.tight_layout()
    plt.show()
ode.py Исходник
1import jax
2import jax.numpy as jnp
3import optax
4
# 1. Purely functional: initialise network parameters (no NN framework used).
def init_pinn_params(rng_key, layers):
    """Create Glorot/Xavier-normal weights and zero biases for an MLP.

    Args:
        rng_key: JAX PRNG key; it is split into one sub-key per layer.
        layers: sequence of layer widths, e.g. ``[1, 32, 32, 1]``.

    Returns:
        A list with one ``{'w': (in_dim, out_dim), 'b': (out_dim,)}`` dict
        per pair of consecutive widths.
    """
    layer_keys = jax.random.split(rng_key, len(layers) - 1)
    dims = zip(layers[:-1], layers[1:])
    return [
        {
            # Glorot normal: std = sqrt(2 / (fan_in + fan_out)).
            'w': jax.random.normal(key, (fan_in, fan_out))
            * jnp.sqrt(2.0 / (fan_in + fan_out)),
            'b': jnp.zeros((fan_out,)),
        }
        for key, (fan_in, fan_out) in zip(layer_keys, dims)
    ]
23
# 2. Purely functional forward pass.
def pinn_forward(params, x):
    """Evaluate the MLP described by ``params`` at ``x``.

    ``x`` has shape ``(1,)`` for a single point or ``(batch, 1)`` for a batch.
    Every layer except the last applies tanh; the output layer is linear.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    h = x
    for layer in hidden_layers:
        h = jnp.tanh(jnp.dot(h, layer['w']) + layer['b'])
    return jnp.dot(h, output_layer['w']) + output_layer['b']
34
# 3. Right-hand side of the target ODE, y'(x) = f(x).
def f_target(x):
    """Return ``1 / (1 + x**4)``; works elementwise on array inputs."""
    denominator = 1.0 + x ** 4
    return 1.0 / denominator
38
# 4. Physics residual at a single collocation point.
def pde_residual_single(params, x_single):
    """Return ``dy/dx - f_target(x)`` at one point ``x_single`` of shape ``(1,)``.

    ``pinn_forward`` returns an array ``[y]``; its first entry is taken so that
    ``jax.grad`` differentiates a scalar. The derivative is with respect to the
    input point, not the parameters (those are closed over).
    """
    def network_scalar(x):
        return pinn_forward(params, x)[0]

    dy_dx = jax.grad(network_scalar)(x_single)
    return dy_dx - f_target(x_single[0])
45
# Vectorise the single-point residual over a batch: params are shared
# (in_axes None) and points are mapped along axis 0.
pde_residual_batch = jax.vmap(pde_residual_single, (None, 0))
48
# 5. Loss function.
def loss_fn(params, x_collocation, x_bc, y_bc, bc_weight=2.0):
    """Composite PINN loss: mean squared ODE residual plus weighted BC error.

    Args:
        params: network parameters as produced by ``init_pinn_params``.
        x_collocation: collocation points, one per row (``train_pinn`` builds
            them with shape ``(N, 1)``).
        x_bc: boundary-condition input points, one per row (``(1, 1)`` in
            ``train_pinn``).
        y_bc: target outputs at ``x_bc``, same shape as the network's batched
            prediction for ``x_bc``.
        bc_weight: multiplier on the boundary-condition term. Defaults to 2.0,
            the weight that was previously hard-coded.

    Returns:
        Scalar ``mean(residual**2) + bc_weight * mean((y(x_bc) - y_bc)**2)``.
    """
    # A. ODE residual loss over all collocation points.
    residuals = pde_residual_batch(params, x_collocation)
    loss_pde = jnp.mean(jnp.square(residuals))

    # B. Boundary-condition loss; vmap applies pinn_forward to each row of x_bc
    # (here there is only one boundary point).
    preds_bc = jax.vmap(pinn_forward, in_axes=(None, 0))(params, x_bc)
    loss_bc = jnp.mean(jnp.square(preds_bc - y_bc))

    return loss_pde + bc_weight * loss_bc
61
# 6. Main training routine.
def train_pinn(epochs=3000, lr=3e-3, *, layer_sizes=(1, 32, 32, 1), seed=42,
               num_collocation=200, log_every=500):
    """Train the PINN for y'(x) = f_target(x) on [-1, 1] with y(-1) = 0.

    Args:
        epochs: index of the last training step. The loop runs ``epochs + 1``
            Adam updates (steps 0..epochs inclusive) so the final step is logged.
        lr: Adam learning rate.
        layer_sizes: MLP widths passed to ``init_pinn_params``; the first width
            must be 1 to match the 1-D collocation points.
        seed: integer seed for ``jax.random.PRNGKey`` (weight initialisation).
        num_collocation: number of evenly spaced collocation points in [-1, 1].
        log_every: print the loss every this many steps; 0 or None disables it.

    Returns:
        The trained parameter list (same structure as ``init_pinn_params``).
    """
    params = init_pinn_params(jax.random.PRNGKey(seed), list(layer_sizes))

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # Boundary condition y(-1) = 0 and a uniform collocation grid on [-1, 1].
    x_bc = jnp.array([[-1.0]])
    y_bc = jnp.array([[0.0]])
    x_collocation = jnp.linspace(-1.0, 1.0, num_collocation).reshape(-1, 1)

    # One pure update step, JIT-compiled. x_bc, y_bc and optimizer are closed
    # over, so they are baked into the compiled function as constants.
    @jax.jit
    def train_step(params, opt_state, x_coll):
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val

    print("开始纯函数式 PINN 训练...")
    for epoch in range(epochs + 1):
        params, opt_state, loss_val = train_step(params, opt_state, x_collocation)

        # Guard against log_every == 0, which would otherwise divide by zero.
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}")

    return params
96
# Run training, then compare the prediction with the known solution.
trained_params = train_pinn()

# Predict y at x = 1.0; pinn_forward receives a single point of shape (1,).
x_test = jnp.array([[1.0]])
y_pred = pinn_forward(trained_params, x_test[0])
print("\n预测完成!")
print("纯函数 PINN 预测的 y(1) = {:.6f}".format(y_pred[0]))
# Reference: y(1) = integral over [-1, 1] of dx / (1 + x**4)
#          = (pi + 2*ln(1 + sqrt(2))) / (2*sqrt(2)).
print("标准数值解: 1.733946098729092")
ssh.ps1 Исходник
# Open an SSH session to cc.findnothing.click as user "benjamin" (no remote command given).
ssh benjamin@cc.findnothing.click