zyppe revised this gist 1 week ago. Go to revision
1 file changed, 1 insertion
ssh.ps1(file created)
| @@ -0,0 +1 @@ | |||
| 1 | + | ssh benjamin@cc.findnothing.click | |
zyppe revised this gist 1 week ago. Go to revision
1 file changed
ns.ode renamed to ns.py
File renamed without changes
zyppe revised this gist 1 week ago. Go to revision
2 files changed, 208 insertions
ns.ode(file created)
| @@ -0,0 +1,103 @@ | |||
| 1 | + | import flax | |
| 2 | + | import deepxde as dde | |
| 3 | + | import jax.numpy as np | |
| 4 | + | import matplotlib.pyplot as plt | |
| 5 | + | ||
| 6 | + | rho = 1 | |
| 7 | + | mu = 1 | |
| 8 | + | u_in = 1 | |
| 9 | + | D = 1 | |
| 10 | + | L = 2 | |
| 11 | + | ||
| 12 | + | geom = dde.geometry.Rectangle(xmin=[-L / 2, -D / 2], xmax=[L / 2, D / 2]) | |
| 13 | + | def boundary_wall(X, on_boundary): | |
| 14 | + | on_wall = np.logical_and( | |
| 15 | + | np.logical_or( | |
| 16 | + | np.isclose(X[1], -D / 2, rtol=1e-05, atol=1e-08), | |
| 17 | + | np.isclose(X[1], D / 2, rtol=1e-05, atol=1e-08), | |
| 18 | + | ), | |
| 19 | + | on_boundary, | |
| 20 | + | ) | |
| 21 | + | return on_wall | |
| 22 | + | ||
| 23 | + | def boundary_inlet(X, on_boundary): | |
| 24 | + | on_inlet = np.logical_and( | |
| 25 | + | np.isclose(X[0], -L / 2, rtol=1e-05, atol=1e-08), on_boundary | |
| 26 | + | ) | |
| 27 | + | return on_inlet | |
| 28 | + | ||
| 29 | + | def boundary_outlet(X, on_boundary): | |
| 30 | + | on_outlet = np.logical_and( | |
| 31 | + | np.isclose(X[0], L / 2, rtol=1e-05, atol=1e-08), on_boundary | |
| 32 | + | ) | |
| 33 | + | return on_outlet | |
| 34 | + | ||
| 35 | + | bc_wall_u = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=0) | |
| 36 | + | bc_wall_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=1) | |
| 37 | + | ||
| 38 | + | bc_inlet_u = dde.DirichletBC(geom, lambda X: u_in, boundary_inlet, component=0) | |
| 39 | + | bc_inlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_inlet, component=1) | |
| 40 | + | ||
| 41 | + | bc_outlet_p = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=2) | |
| 42 | + | bc_outlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=1) | |
| 43 | + | ||
| 44 | + | def pde(X, Y): | |
| 45 | + | """Args: | |
| 46 | + | Y: network output [u, v, p], | |
| 47 | + | X: input coordinates [x, y]""" | |
| 48 | + | du_x,_ = dde.grad.jacobian(Y, X, i=0, j=0) | |
| 49 | + | du_y,_ = dde.grad.jacobian(Y, X, i=0, j=1) | |
| 50 | + | dv_x,_ = dde.grad.jacobian(Y, X, i=1, j=0) | |
| 51 | + | dv_y,_ = dde.grad.jacobian(Y, X, i=1, j=1) | |
| 52 | + | dp_x,_ = dde.grad.jacobian(Y, X, i=2, j=0) | |
| 53 | + | dp_y,_ = dde.grad.jacobian(Y, X, i=2, j=1) | |
| 54 | + | ||
| 55 | + | du_xx,_ = dde.grad.hessian(Y, X, component=0, i=0, j=0) | |
| 56 | + | du_yy,_ = dde.grad.hessian(Y, X, component=0, i=1, j=1) | |
| 57 | + | dv_xx,_ = dde.grad.hessian(Y, X, component=1, i=0, j=0) | |
| 58 | + | dv_yy,_ = dde.grad.hessian(Y, X, component=1, i=1, j=1) | |
| 59 | + | ||
| 60 | + | y_val, _ = Y | |
| 61 | + | u_pred = y_val[:, 0:1] | |
| 62 | + | v_pred = y_val[:, 1:2] | |
| 63 | + | pde_u = u_pred * du_x + v_pred * du_y + 1 / rho * dp_x - (mu / rho) * (du_xx + du_yy) | |
| 64 | + | pde_v = u_pred * dv_x + v_pred * dv_y + 1 / rho * dp_y - (mu / rho) * (dv_xx + dv_yy) | |
| 65 | + | pde_cont = du_x + dv_y | |
| 66 | + | ||
| 67 | + | return [pde_u, pde_v, pde_cont] | |
| 68 | + | ||
| 69 | + | data = dde.data.PDE( | |
| 70 | + | geom, | |
| 71 | + | pde, | |
| 72 | + | [bc_wall_u, bc_wall_v, bc_inlet_u, bc_inlet_v, bc_outlet_p, bc_outlet_v], | |
| 73 | + | num_domain=2000, | |
| 74 | + | num_boundary=200, | |
| 75 | + | num_test=200, | |
| 76 | + | ) | |
| 77 | + | ||
| 78 | + | net = dde.maps.FNN([2] + [64] * 5 + [3], "tanh", "Glorot uniform") | |
| 79 | + | model = dde.Model(data, net) | |
| 80 | + | ||
| 81 | + | model.compile("adam", lr=1e-3) | |
| 82 | + | losshistory, train_state = model.train(epochs=10000) | |
| 83 | + | ||
| 84 | + | plt.figure(figsize=(10, 8)) | |
| 85 | + | plt.scatter(data.train_x_all[:, 0], data.train_x_all[:, 1], s=0.5) | |
| 86 | + | plt.xlabel("x") | |
| 87 | + | plt.ylabel("y") | |
| 88 | + | plt.show() | |
| 89 | + | ||
| 90 | + | ||
| 91 | + | samples = geom.random_points(500000) | |
| 92 | + | result = model.predict(samples) | |
| 93 | + | ||
| 94 | + | color_legend = [[0, 1.5], [-0.3, 0.3], [0, 35]] | |
| 95 | + | for idx in range(3): | |
| 96 | + | plt.figure(figsize=(20, 4)) | |
| 97 | + | plt.scatter(samples[:, 0], samples[:, 1], c=result[:, idx], cmap="jet", s=2) | |
| 98 | + | plt.colorbar() | |
| 99 | + | plt.clim(color_legend[idx]) | |
| 100 | + | plt.xlim((0 - L / 2, L - L / 2)) | |
| 101 | + | plt.ylim((0 - D / 2, D - D / 2)) | |
| 102 | + | plt.tight_layout() | |
| 103 | + | plt.show() | |
ode.py(file created)
| @@ -0,0 +1,105 @@ | |||
| 1 | + | import jax | |
| 2 | + | import jax.numpy as jnp | |
| 3 | + | import optax | |
| 4 | + | ||
| 5 | + | # 1. 纯函数式:初始化网络参数 (不使用任何网络框架) | |
| 6 | + | def init_pinn_params(rng_key, layers): | |
| 7 | + | """手动初始化多层感知机的权重和偏置""" | |
| 8 | + | params = [] | |
| 9 | + | keys = jax.random.split(rng_key, len(layers) - 1) | |
| 10 | + | ||
| 11 | + | for i in range(len(layers) - 1): | |
| 12 | + | in_dim = layers[i] | |
| 13 | + | out_dim = layers[i+1] | |
| 14 | + | ||
| 15 | + | # 使用 Xavier/Glorot 初始化方法 | |
| 16 | + | glorot_std = jnp.sqrt(2.0 / (in_dim + out_dim)) | |
| 17 | + | ||
| 18 | + | w = jax.random.normal(keys[i], (in_dim, out_dim)) * glorot_std | |
| 19 | + | b = jnp.zeros((out_dim,)) | |
| 20 | + | ||
| 21 | + | params.append({'w': w, 'b': b}) | |
| 22 | + | return params | |
| 23 | + | ||
| 24 | + | # 2. 纯函数式:正向传播函数 | |
| 25 | + | def pinn_forward(params, x): | |
| 26 | + | """输入 x 形状为 (1,) 或 (batch, 1),通过参数字典进行前向计算""" | |
| 27 | + | activation = x | |
| 28 | + | # 遍历隐藏层 | |
| 29 | + | for layer in params[:-1]: | |
| 30 | + | activation = jnp.tanh(jnp.dot(activation, layer['w']) + layer['b']) | |
| 31 | + | # 输出层(不加激活函数) | |
| 32 | + | final_layer = params[-1] | |
| 33 | + | return jnp.dot(activation, final_layer['w']) + final_layer['b'] | |
| 34 | + | ||
| 35 | + | # 3. 目标物理方程 f(x) | |
| 36 | + | def f_target(x): | |
| 37 | + | return 1.0 / (1.0 + x**4) | |
| 38 | + | ||
| 39 | + | # 4. 物理残差计算 (单点) | |
| 40 | + | def pde_residual_single(params, x_single): | |
| 41 | + | # 用 lambda 包装,确保 jax.grad 明确知道是对输入 x_single 求导 | |
| 42 | + | # 因为 pinn_forward 返回的是一个数组 [y],我们需要用 [0] 拿到标量 | |
| 43 | + | y_grad = jax.grad(lambda x: pinn_forward(params, x)[0])(x_single) | |
| 44 | + | return y_grad - f_target(x_single[0]) | |
| 45 | + | ||
| 46 | + | # 使用 vmap 自动将单点残差扩展到批量点 | |
| 47 | + | pde_residual_batch = jax.vmap(pde_residual_single, in_axes=(None, 0)) | |
| 48 | + | ||
| 49 | + | # 5. 损失函数 | |
| 50 | + | def loss_fn(params, x_collocation, x_bc, y_bc): | |
| 51 | + | # A. PDE 残差损失 | |
| 52 | + | preds_grad = pde_residual_batch(params, x_collocation) | |
| 53 | + | loss_pde = jnp.mean(jnp.square(preds_grad)) | |
| 54 | + | ||
| 55 | + | # B. 边界条件损失 | |
| 56 | + | # 因为 x_bc 是批量的 (这里只有1个点),我们直接用 jax.vmap 处理 forward | |
| 57 | + | preds_bc = jax.vmap(pinn_forward, in_axes=(None, 0))(params, x_bc) | |
| 58 | + | loss_bc = jnp.mean(jnp.square(preds_bc - y_bc)) | |
| 59 | + | ||
| 60 | + | return loss_pde + 2.0 * loss_bc | |
| 61 | + | ||
| 62 | + | # 6. 训练主流程 | |
| 63 | + | def train_pinn(epochs=3000, lr=3e-3): | |
| 64 | + | # 架构:输入1维 -> 两个32维的隐藏层 -> 输出1维 | |
| 65 | + | layer_sizes = [1, 32, 32, 1] | |
| 66 | + | ||
| 67 | + | # 初始化参数 | |
| 68 | + | rng_key = jax.random.PRNGKey(42) | |
| 69 | + | params = init_pinn_params(rng_key, layer_sizes) | |
| 70 | + | ||
| 71 | + | # 初始化 optax 优化器 | |
| 72 | + | optimizer = optax.adam(lr) | |
| 73 | + | opt_state = optimizer.init(params) | |
| 74 | + | ||
| 75 | + | # 准备数据 | |
| 76 | + | x_bc = jnp.array([[-1.0]]) | |
| 77 | + | y_bc = jnp.array([[0.0]]) | |
| 78 | + | x_collocation = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1) | |
| 79 | + | ||
| 80 | + | # 纯函数式的单步训练,用 JIT 编译成静态图 | |
| 81 | + | @jax.jit | |
| 82 | + | def train_step(params, opt_state, x_coll): | |
| 83 | + | loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc) | |
| 84 | + | updates, opt_state = optimizer.update(grads, opt_state) | |
| 85 | + | params = optax.apply_updates(params, updates) | |
| 86 | + | return params, opt_state, loss_val | |
| 87 | + | ||
| 88 | + | print("开始纯函数式 PINN 训练...") | |
| 89 | + | for epoch in range(epochs + 1): | |
| 90 | + | params, opt_state, loss_val = train_step(params, opt_state, x_collocation) | |
| 91 | + | ||
| 92 | + | if epoch % 500 == 0: | |
| 93 | + | print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}") | |
| 94 | + | ||
| 95 | + | return params | |
| 96 | + | ||
| 97 | + | # 运行并验证 | |
| 98 | + | trained_params = train_pinn() | |
| 99 | + | ||
| 100 | + | # 预测 x = 1.0 处的值 | |
| 101 | + | x_test = jnp.array([[1.0]]) | |
| 102 | + | y_pred = pinn_forward(trained_params, x_test[0]) | |
| 103 | + | print(f"\n预测完成!") | |
| 104 | + | print(f"纯函数 PINN 预测的 y(1) = {y_pred[0]:.6f}") | |
| 105 | + | print(f"标准数值解: 1.733946098729092") | |