Last active 1 week ago

zyppe's Avatar zyppe revised this gist 1 week ago. Go to revision

1 file changed, 1 insertion

ssh.ps1(file created)

@@ -0,0 +1 @@
1 + ssh benjamin@cc.findnothing.click

zyppe's Avatar zyppe revised this gist 1 week ago. Go to revision

1 file changed

ns.ode renamed to ns.py

File renamed without changes

zyppe's Avatar zyppe revised this gist 1 week ago. Go to revision

2 files changed, 208 insertions

ns.ode(file created)

@@ -0,0 +1,103 @@
1 + import flax
2 + import deepxde as dde
3 + import jax.numpy as np
4 + import matplotlib.pyplot as plt
5 +
6 + rho = 1
7 + mu = 1
8 + u_in = 1
9 + D = 1
10 + L = 2
11 +
12 + geom = dde.geometry.Rectangle(xmin=[-L / 2, -D / 2], xmax=[L / 2, D / 2])
13 + def boundary_wall(X, on_boundary):
14 + on_wall = np.logical_and(
15 + np.logical_or(
16 + np.isclose(X[1], -D / 2, rtol=1e-05, atol=1e-08),
17 + np.isclose(X[1], D / 2, rtol=1e-05, atol=1e-08),
18 + ),
19 + on_boundary,
20 + )
21 + return on_wall
22 +
23 + def boundary_inlet(X, on_boundary):
24 + on_inlet = np.logical_and(
25 + np.isclose(X[0], -L / 2, rtol=1e-05, atol=1e-08), on_boundary
26 + )
27 + return on_inlet
28 +
29 + def boundary_outlet(X, on_boundary):
30 + on_outlet = np.logical_and(
31 + np.isclose(X[0], L / 2, rtol=1e-05, atol=1e-08), on_boundary
32 + )
33 + return on_outlet
34 +
35 + bc_wall_u = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=0)
36 + bc_wall_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=1)
37 +
38 + bc_inlet_u = dde.DirichletBC(geom, lambda X: u_in, boundary_inlet, component=0)
39 + bc_inlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_inlet, component=1)
40 +
41 + bc_outlet_p = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=2)
42 + bc_outlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=1)
43 +
44 + def pde(X, Y):
45 + """Args:
46 + Y: network output [u, v, p],
47 + X: input coordinates [x, y]"""
48 + du_x,_ = dde.grad.jacobian(Y, X, i=0, j=0)
49 + du_y,_ = dde.grad.jacobian(Y, X, i=0, j=1)
50 + dv_x,_ = dde.grad.jacobian(Y, X, i=1, j=0)
51 + dv_y,_ = dde.grad.jacobian(Y, X, i=1, j=1)
52 + dp_x,_ = dde.grad.jacobian(Y, X, i=2, j=0)
53 + dp_y,_ = dde.grad.jacobian(Y, X, i=2, j=1)
54 +
55 + du_xx,_ = dde.grad.hessian(Y, X, component=0, i=0, j=0)
56 + du_yy,_ = dde.grad.hessian(Y, X, component=0, i=1, j=1)
57 + dv_xx,_ = dde.grad.hessian(Y, X, component=1, i=0, j=0)
58 + dv_yy,_ = dde.grad.hessian(Y, X, component=1, i=1, j=1)
59 +
60 + y_val, _ = Y
61 + u_pred = y_val[:, 0:1]
62 + v_pred = y_val[:, 1:2]
63 + pde_u = u_pred * du_x + v_pred * du_y + 1 / rho * dp_x - (mu / rho) * (du_xx + du_yy)
64 + pde_v = u_pred * dv_x + v_pred * dv_y + 1 / rho * dp_y - (mu / rho) * (dv_xx + dv_yy)
65 + pde_cont = du_x + dv_y
66 +
67 + return [pde_u, pde_v, pde_cont]
68 +
69 + data = dde.data.PDE(
70 + geom,
71 + pde,
72 + [bc_wall_u, bc_wall_v, bc_inlet_u, bc_inlet_v, bc_outlet_p, bc_outlet_v],
73 + num_domain=2000,
74 + num_boundary=200,
75 + num_test=200,
76 + )
77 +
78 + net = dde.maps.FNN([2] + [64] * 5 + [3], "tanh", "Glorot uniform")
79 + model = dde.Model(data, net)
80 +
81 + model.compile("adam", lr=1e-3)
82 + losshistory, train_state = model.train(epochs=10000)
83 +
84 + plt.figure(figsize=(10, 8))
85 + plt.scatter(data.train_x_all[:, 0], data.train_x_all[:, 1], s=0.5)
86 + plt.xlabel("x")
87 + plt.ylabel("y")
88 + plt.show()
89 +
90 +
91 + samples = geom.random_points(500000)
92 + result = model.predict(samples)
93 +
94 + color_legend = [[0, 1.5], [-0.3, 0.3], [0, 35]]
95 + for idx in range(3):
96 + plt.figure(figsize=(20, 4))
97 + plt.scatter(samples[:, 0], samples[:, 1], c=result[:, idx], cmap="jet", s=2)
98 + plt.colorbar()
99 + plt.clim(color_legend[idx])
100 + plt.xlim((0 - L / 2, L - L / 2))
101 + plt.ylim((0 - D / 2, D - D / 2))
102 + plt.tight_layout()
103 + plt.show()

ode.py(file created)

@@ -0,0 +1,105 @@
1 + import jax
2 + import jax.numpy as jnp
3 + import optax
4 +
5 + # 1. 纯函数式:初始化网络参数 (不使用任何网络框架)
6 + def init_pinn_params(rng_key, layers):
7 + """手动初始化多层感知机的权重和偏置"""
8 + params = []
9 + keys = jax.random.split(rng_key, len(layers) - 1)
10 +
11 + for i in range(len(layers) - 1):
12 + in_dim = layers[i]
13 + out_dim = layers[i+1]
14 +
15 + # 使用 Xavier/Glorot 初始化方法
16 + glorot_std = jnp.sqrt(2.0 / (in_dim + out_dim))
17 +
18 + w = jax.random.normal(keys[i], (in_dim, out_dim)) * glorot_std
19 + b = jnp.zeros((out_dim,))
20 +
21 + params.append({'w': w, 'b': b})
22 + return params
23 +
24 + # 2. 纯函数式:正向传播函数
25 + def pinn_forward(params, x):
26 + """输入 x 形状为 (1,) 或 (batch, 1),通过参数字典进行前向计算"""
27 + activation = x
28 + # 遍历隐藏层
29 + for layer in params[:-1]:
30 + activation = jnp.tanh(jnp.dot(activation, layer['w']) + layer['b'])
31 + # 输出层(不加激活函数)
32 + final_layer = params[-1]
33 + return jnp.dot(activation, final_layer['w']) + final_layer['b']
34 +
35 + # 3. 目标物理方程 f(x)
36 + def f_target(x):
37 + return 1.0 / (1.0 + x**4)
38 +
39 + # 4. 物理残差计算 (单点)
40 + def pde_residual_single(params, x_single):
41 + # 用 lambda 包装,确保 jax.grad 明确知道是对输入 x_single 求导
42 + # 因为 pinn_forward 返回的是一个数组 [y],我们需要用 [0] 拿到标量
43 + y_grad = jax.grad(lambda x: pinn_forward(params, x)[0])(x_single)
44 + return y_grad - f_target(x_single[0])
45 +
46 + # 使用 vmap 自动将单点残差扩展到批量点
47 + pde_residual_batch = jax.vmap(pde_residual_single, in_axes=(None, 0))
48 +
49 + # 5. 损失函数
50 + def loss_fn(params, x_collocation, x_bc, y_bc):
51 + # A. PDE 残差损失
52 + preds_grad = pde_residual_batch(params, x_collocation)
53 + loss_pde = jnp.mean(jnp.square(preds_grad))
54 +
55 + # B. 边界条件损失
56 + # 因为 x_bc 是批量的 (这里只有1个点),我们直接用 jax.vmap 处理 forward
57 + preds_bc = jax.vmap(pinn_forward, in_axes=(None, 0))(params, x_bc)
58 + loss_bc = jnp.mean(jnp.square(preds_bc - y_bc))
59 +
60 + return loss_pde + 2.0 * loss_bc
61 +
62 + # 6. 训练主流程
63 + def train_pinn(epochs=3000, lr=3e-3):
64 + # 架构:输入1维 -> 两个32维的隐藏层 -> 输出1维
65 + layer_sizes = [1, 32, 32, 1]
66 +
67 + # 初始化参数
68 + rng_key = jax.random.PRNGKey(42)
69 + params = init_pinn_params(rng_key, layer_sizes)
70 +
71 + # 初始化 optax 优化器
72 + optimizer = optax.adam(lr)
73 + opt_state = optimizer.init(params)
74 +
75 + # 准备数据
76 + x_bc = jnp.array([[-1.0]])
77 + y_bc = jnp.array([[0.0]])
78 + x_collocation = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1)
79 +
80 + # 纯函数式的单步训练,用 JIT 编译成静态图
81 + @jax.jit
82 + def train_step(params, opt_state, x_coll):
83 + loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc)
84 + updates, opt_state = optimizer.update(grads, opt_state)
85 + params = optax.apply_updates(params, updates)
86 + return params, opt_state, loss_val
87 +
88 + print("开始纯函数式 PINN 训练...")
89 + for epoch in range(epochs + 1):
90 + params, opt_state, loss_val = train_step(params, opt_state, x_collocation)
91 +
92 + if epoch % 500 == 0:
93 + print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}")
94 +
95 + return params
96 +
97 + # 运行并验证
98 + trained_params = train_pinn()
99 +
100 + # 预测 x = 1.0 处的值
101 + x_test = jnp.array([[1.0]])
102 + y_pred = pinn_forward(trained_params, x_test[0])
103 + print(f"\n预测完成!")
104 + print(f"纯函数 PINN 预测的 y(1) = {y_pred[0]:.6f}")
105 + print(f"标准数值解: 1.733946098729092")
Newer Older