ns.py
· 3.1 KiB · Python
Brut
import flax
import deepxde as dde
import jax.numpy as np
import matplotlib.pyplot as plt
# Physical and geometric parameters of the channel-flow problem.
rho = 1  # divides the pressure-gradient and viscous terms in pde(); presumably the fluid density
mu = 1  # multiplies the second-derivative terms in pde(); presumably the dynamic viscosity
u_in = 1  # value imposed on output component 0 at the inlet (x = -L/2)
D = 1  # extent of the domain along the second coordinate (y)
L = 2  # extent of the domain along the first coordinate (x)
# Rectangle centered at the origin: x in [-L/2, L/2], y in [-D/2, D/2].
geom = dde.geometry.Rectangle(xmin=[-L / 2, -D / 2], xmax=[L / 2, D / 2])
def boundary_wall(X, on_boundary):
    """Tell whether ``X`` is a boundary point lying on the lower or upper wall.

    Args:
        X: Coordinates of one point; ``X[1]`` is compared against ``-D / 2``
            and ``D / 2``.
        on_boundary: Flag supplied by the caller saying whether ``X`` lies on
            the geometry boundary.

    Returns:
        Logical AND of ``on_boundary`` with "``X[1]`` is close to either wall".
    """
    y = X[1]
    at_bottom = np.isclose(y, -D / 2, rtol=1e-05, atol=1e-08)
    at_top = np.isclose(y, D / 2, rtol=1e-05, atol=1e-08)
    touches_wall = np.logical_or(at_bottom, at_top)
    return np.logical_and(touches_wall, on_boundary)
def boundary_inlet(X, on_boundary):
    """Tell whether ``X`` is a boundary point lying on the inlet edge x = -L/2.

    Args:
        X: Coordinates of one point; ``X[0]`` is compared against ``-L / 2``.
        on_boundary: Flag supplied by the caller saying whether ``X`` lies on
            the geometry boundary.

    Returns:
        Logical AND of ``on_boundary`` with "``X[0]`` is close to ``-L / 2``".
    """
    at_left_edge = np.isclose(X[0], -L / 2, rtol=1e-05, atol=1e-08)
    return np.logical_and(at_left_edge, on_boundary)
def boundary_outlet(X, on_boundary):
    """Tell whether ``X`` is a boundary point lying on the outlet edge x = L/2.

    Args:
        X: Coordinates of one point; ``X[0]`` is compared against ``L / 2``.
        on_boundary: Flag supplied by the caller saying whether ``X`` lies on
            the geometry boundary.

    Returns:
        Logical AND of ``on_boundary`` with "``X[0]`` is close to ``L / 2``".
    """
    at_right_edge = np.isclose(X[0], L / 2, rtol=1e-05, atol=1e-08)
    return np.logical_and(at_right_edge, on_boundary)
# Dirichlet conditions; `component` selects the network output column
# (0 = u, 1 = v, 2 = p, per the pde() docstring).
# Walls: both velocity components are set to zero.
bc_wall_u = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=0)
bc_wall_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=1)
# Inlet: u equals u_in, v is zero.
bc_inlet_u = dde.DirichletBC(geom, lambda X: u_in, boundary_inlet, component=0)
bc_inlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_inlet, component=1)
# Outlet: p and v are both set to zero.
bc_outlet_p = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=2)
bc_outlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=1)
def pde(X, Y):
    """Build the residuals of the steady momentum and continuity equations.

    Args:
        X: Input coordinates [x, y].
        Y: Network output [u, v, p]. It is unpacked as a ``(values, _)`` pair,
            and every ``dde.grad`` call here likewise returns a pair whose
            first item is used (presumably the DeepXDE JAX-backend
            convention; confirm against the installed backend).

    Returns:
        ``[pde_u, pde_v, pde_cont]``: x-momentum, y-momentum and continuity
        residuals.
    """

    def first_derivative(component, axis):
        # d(output[component]) / d(input[axis])
        value, _ = dde.grad.jacobian(Y, X, i=component, j=axis)
        return value

    def second_derivative(component, axis):
        # d^2(output[component]) / d(input[axis])^2
        value, _ = dde.grad.hessian(Y, X, component=component, i=axis, j=axis)
        return value

    outputs, _ = Y
    u = outputs[:, 0:1]
    v = outputs[:, 1:2]

    du_x, du_y = first_derivative(0, 0), first_derivative(0, 1)
    dv_x, dv_y = first_derivative(1, 0), first_derivative(1, 1)
    dp_x, dp_y = first_derivative(2, 0), first_derivative(2, 1)

    laplacian_u = second_derivative(0, 0) + second_derivative(0, 1)
    laplacian_v = second_derivative(1, 0) + second_derivative(1, 1)

    # Convection + pressure gradient / rho - (mu / rho) * Laplacian.
    convection_u = u * du_x + v * du_y
    convection_v = u * dv_x + v * dv_y
    pde_u = convection_u + 1 / rho * dp_x - (mu / rho) * laplacian_u
    pde_v = convection_v + 1 / rho * dp_y - (mu / rho) * laplacian_v
    pde_cont = du_x + dv_y
    return [pde_u, pde_v, pde_cont]
# Assemble the PDE problem: geometry, residual function and all boundary
# conditions, with 2000 interior and 200 boundary training points and 200
# test points.
data = dde.data.PDE(
    geom,
    pde,
    [bc_wall_u, bc_wall_v, bc_inlet_u, bc_inlet_v, bc_outlet_p, bc_outlet_v],
    num_domain=2000,
    num_boundary=200,
    num_test=200,
)
# Fully connected network: 2 inputs (x, y) -> 5 hidden layers of 64 tanh units
# -> 3 outputs (u, v, p). `dde.nn` is the current module name; `dde.maps` is
# only a backward-compatibility alias.
net = dde.nn.FNN([2] + [64] * 5 + [3], "tanh", "Glorot uniform")
model = dde.Model(data, net)
model.compile("adam", lr=1e-3)
# `iterations` replaces the deprecated `epochs` keyword of Model.train.
losshistory, train_state = model.train(iterations=10000)
# Scatter plot of the training points stored in data.train_x_all
# (column 0 on the x axis, column 1 on the y axis).
plt.figure(figsize=(10, 8))
plt.scatter(data.train_x_all[:, 0], data.train_x_all[:, 1], s=0.5)
plt.xlabel("x")
plt.ylabel("y")
plt.show()
# Evaluate the trained model on 500000 random points of the domain.
samples = geom.random_points(500000)
result = model.predict(samples)
# Fixed color range for each output column (0 = u, 1 = v, 2 = p per pde()).
color_legend = [[0, 1.5], [-0.3, 0.3], [0, 35]]
for idx in range(3):
    # One figure per output column, colored by the predicted value.
    plt.figure(figsize=(20, 4))
    plt.scatter(samples[:, 0], samples[:, 1], c=result[:, idx], cmap="jet", s=2)
    plt.colorbar()
    plt.clim(color_legend[idx])
    # Axis limits equal the rectangle bounds [-L/2, L/2] x [-D/2, D/2].
    plt.xlim((0 - L / 2, L - L / 2))
    plt.ylim((0 - D / 2, D - D / 2))
    plt.tight_layout()
    # NOTE(review): the source lost its indentation; show() is assumed to be
    # part of the loop body (one blocking window per field) - confirm.
    plt.show()
| 1 | import flax |
| 2 | import deepxde as dde |
| 3 | import jax.numpy as np |
| 4 | import matplotlib.pyplot as plt |
| 5 | |
| 6 | rho = 1 |
| 7 | mu = 1 |
| 8 | u_in = 1 |
| 9 | D = 1 |
| 10 | L = 2 |
| 11 | |
| 12 | geom = dde.geometry.Rectangle(xmin=[-L / 2, -D / 2], xmax=[L / 2, D / 2]) |
| 13 | def boundary_wall(X, on_boundary): |
| 14 | on_wall = np.logical_and( |
| 15 | np.logical_or( |
| 16 | np.isclose(X[1], -D / 2, rtol=1e-05, atol=1e-08), |
| 17 | np.isclose(X[1], D / 2, rtol=1e-05, atol=1e-08), |
| 18 | ), |
| 19 | on_boundary, |
| 20 | ) |
| 21 | return on_wall |
| 22 | |
| 23 | def boundary_inlet(X, on_boundary): |
| 24 | on_inlet = np.logical_and( |
| 25 | np.isclose(X[0], -L / 2, rtol=1e-05, atol=1e-08), on_boundary |
| 26 | ) |
| 27 | return on_inlet |
| 28 | |
| 29 | def boundary_outlet(X, on_boundary): |
| 30 | on_outlet = np.logical_and( |
| 31 | np.isclose(X[0], L / 2, rtol=1e-05, atol=1e-08), on_boundary |
| 32 | ) |
| 33 | return on_outlet |
| 34 | |
| 35 | bc_wall_u = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=0) |
| 36 | bc_wall_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_wall, component=1) |
| 37 | |
| 38 | bc_inlet_u = dde.DirichletBC(geom, lambda X: u_in, boundary_inlet, component=0) |
| 39 | bc_inlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_inlet, component=1) |
| 40 | |
| 41 | bc_outlet_p = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=2) |
| 42 | bc_outlet_v = dde.DirichletBC(geom, lambda X: 0.0, boundary_outlet, component=1) |
| 43 | |
| 44 | def pde(X, Y): |
| 45 | """Args: |
| 46 | Y: network output [u, v, p], |
| 47 | X: input coordinates [x, y]""" |
| 48 | du_x,_ = dde.grad.jacobian(Y, X, i=0, j=0) |
| 49 | du_y,_ = dde.grad.jacobian(Y, X, i=0, j=1) |
| 50 | dv_x,_ = dde.grad.jacobian(Y, X, i=1, j=0) |
| 51 | dv_y,_ = dde.grad.jacobian(Y, X, i=1, j=1) |
| 52 | dp_x,_ = dde.grad.jacobian(Y, X, i=2, j=0) |
| 53 | dp_y,_ = dde.grad.jacobian(Y, X, i=2, j=1) |
| 54 | |
| 55 | du_xx,_ = dde.grad.hessian(Y, X, component=0, i=0, j=0) |
| 56 | du_yy,_ = dde.grad.hessian(Y, X, component=0, i=1, j=1) |
| 57 | dv_xx,_ = dde.grad.hessian(Y, X, component=1, i=0, j=0) |
| 58 | dv_yy,_ = dde.grad.hessian(Y, X, component=1, i=1, j=1) |
| 59 | |
| 60 | y_val, _ = Y |
| 61 | u_pred = y_val[:, 0:1] |
| 62 | v_pred = y_val[:, 1:2] |
| 63 | pde_u = u_pred * du_x + v_pred * du_y + 1 / rho * dp_x - (mu / rho) * (du_xx + du_yy) |
| 64 | pde_v = u_pred * dv_x + v_pred * dv_y + 1 / rho * dp_y - (mu / rho) * (dv_xx + dv_yy) |
| 65 | pde_cont = du_x + dv_y |
| 66 | |
| 67 | return [pde_u, pde_v, pde_cont] |
| 68 | |
| 69 | data = dde.data.PDE( |
| 70 | geom, |
| 71 | pde, |
| 72 | [bc_wall_u, bc_wall_v, bc_inlet_u, bc_inlet_v, bc_outlet_p, bc_outlet_v], |
| 73 | num_domain=2000, |
| 74 | num_boundary=200, |
| 75 | num_test=200, |
| 76 | ) |
| 77 | |
| 78 | net = dde.maps.FNN([2] + [64] * 5 + [3], "tanh", "Glorot uniform") |
| 79 | model = dde.Model(data, net) |
| 80 | |
| 81 | model.compile("adam", lr=1e-3) |
| 82 | losshistory, train_state = model.train(epochs=10000) |
| 83 | |
| 84 | plt.figure(figsize=(10, 8)) |
| 85 | plt.scatter(data.train_x_all[:, 0], data.train_x_all[:, 1], s=0.5) |
| 86 | plt.xlabel("x") |
| 87 | plt.ylabel("y") |
| 88 | plt.show() |
| 89 | |
| 90 | |
| 91 | samples = geom.random_points(500000) |
| 92 | result = model.predict(samples) |
| 93 | |
| 94 | color_legend = [[0, 1.5], [-0.3, 0.3], [0, 35]] |
| 95 | for idx in range(3): |
| 96 | plt.figure(figsize=(20, 4)) |
| 97 | plt.scatter(samples[:, 0], samples[:, 1], c=result[:, idx], cmap="jet", s=2) |
| 98 | plt.colorbar() |
| 99 | plt.clim(color_legend[idx]) |
| 100 | plt.xlim((0 - L / 2, L - L / 2)) |
| 101 | plt.ylim((0 - D / 2, D - D / 2)) |
| 102 | plt.tight_layout() |
| 103 | plt.show() |
ode.py
· 3.5 KiB · Python
Brut
import jax
import jax.numpy as jnp
import optax
# 1. Purely functional: initialize the network parameters (no NN framework)
def init_pinn_params(rng_key, layers):
    """Manually initialize the weights and biases of a multilayer perceptron.

    Args:
        rng_key: JAX PRNG key; it is split into one sub-key per layer.
        layers: Sequence of layer widths, e.g. ``[1, 32, 32, 1]``.

    Returns:
        A list with one ``{'w': ..., 'b': ...}`` dict per consecutive pair of
        widths. ``w`` has shape ``(in_dim, out_dim)`` and is drawn from a
        normal distribution scaled by ``sqrt(2 / (in_dim + out_dim))``;
        ``b`` is a zero vector of shape ``(out_dim,)``.
    """
    layer_keys = jax.random.split(rng_key, len(layers) - 1)
    network = []
    for idx, (fan_in, fan_out) in enumerate(zip(layers[:-1], layers[1:])):
        # Xavier/Glorot initialization: std = sqrt(2 / (fan_in + fan_out)).
        scale = jnp.sqrt(2.0 / (fan_in + fan_out))
        weights = jax.random.normal(layer_keys[idx], (fan_in, fan_out)) * scale
        network.append({'w': weights, 'b': jnp.zeros((fan_out,))})
    return network
# 2. Purely functional: forward pass
def pinn_forward(params, x):
    """Run the multilayer perceptron described by ``params`` on ``x``.

    Args:
        params: List of ``{'w': ..., 'b': ...}`` layer dicts; every layer but
            the last is followed by ``tanh``.
        x: Input of shape ``(1,)`` or ``(batch, 1)``.

    Returns:
        The output of the last layer, which is affine (no activation).
    """
    hidden_layers = params[:-1]
    output_layer = params[-1]

    h = x
    for layer in hidden_layers:
        pre_activation = jnp.dot(h, layer['w']) + layer['b']
        h = jnp.tanh(pre_activation)
    # Output layer: affine map only.
    return jnp.dot(h, output_layer['w']) + output_layer['b']
# 3. Target right-hand side f(x) of the equation
def f_target(x):
    """Return ``1 / (1 + x**4)``."""
    denominator = 1.0 + x**4
    return 1.0 / denominator
# 4. Physics residual (single point)
def pde_residual_single(params, x_single):
    """Return the residual ``dy/dx - f(x)`` of the network at one point.

    Args:
        params: Network parameters accepted by ``pinn_forward``.
        x_single: One input point, indexed as ``x_single[0]``.

    Returns:
        Gradient of the network output with respect to ``x_single`` minus
        ``f_target(x_single[0])``.
    """

    def scalar_output(x):
        # pinn_forward returns an array [y]; take element 0 so that jax.grad
        # differentiates a scalar with respect to the input x.
        return pinn_forward(params, x)[0]

    dy_dx = jax.grad(scalar_output)(x_single)
    return dy_dx - f_target(x_single[0])
# Use vmap to lift the single-point residual to a batch of points:
# params is shared (None), the points are mapped over their leading axis (0).
pde_residual_batch = jax.vmap(pde_residual_single, in_axes=(None, 0))
# 5. Loss function
def loss_fn(params, x_collocation, x_bc, y_bc):
    """Return the total training loss: residual loss plus weighted boundary loss.

    Args:
        params: Network parameters accepted by ``pinn_forward``.
        x_collocation: Batch of collocation points for the residual term.
        x_bc: Batch of boundary inputs (leading axis is the batch axis).
        y_bc: Target values at ``x_bc``.

    Returns:
        ``mean(residual**2) + 2.0 * mean((prediction(x_bc) - y_bc)**2)``.
    """
    # A. Residual loss over the collocation points.
    residuals = pde_residual_batch(params, x_collocation)
    loss_pde = jnp.mean(jnp.square(residuals))

    # B. Boundary-condition loss; x_bc is batched, so the forward pass is
    # vmapped over its leading axis.
    batched_forward = jax.vmap(pinn_forward, in_axes=(None, 0))
    bc_error = batched_forward(params, x_bc) - y_bc
    loss_bc = jnp.mean(jnp.square(bc_error))

    return loss_pde + 2.0 * loss_bc
# 6. Main training routine
def train_pinn(epochs=3000, lr=3e-3):
    """Train the PINN with Adam and return the trained parameters.

    Args:
        epochs: Index of the last training step; the loop runs
            ``epochs + 1`` steps (0 through ``epochs`` inclusive).
        lr: Learning rate passed to ``optax.adam``.

    Returns:
        The parameter list after the final update.
    """
    # Data: one boundary point (x = -1 with target 0) and 200 evenly spaced
    # collocation points on [-1, 1], shaped (200, 1).
    x_bc = jnp.array([[-1.0]])
    y_bc = jnp.array([[0.0]])
    x_collocation = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1)

    # Architecture: 1 input -> two hidden layers of width 32 -> 1 output.
    params = init_pinn_params(jax.random.PRNGKey(42), [1, 32, 32, 1])

    # optax optimizer and its state.
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # Purely functional single training step, compiled with JIT.
    @jax.jit
    def train_step(params, opt_state, x_coll):
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc)
        updates, next_opt_state = optimizer.update(grads, opt_state)
        next_params = optax.apply_updates(params, updates)
        return next_params, next_opt_state, loss_val

    print("开始纯函数式 PINN 训练...")
    for epoch in range(epochs + 1):
        params, opt_state, loss_val = train_step(params, opt_state, x_collocation)
        if epoch % 500 != 0:
            continue
        # The reported loss was evaluated before this step's update.
        print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}")

    return params
# Run the training and check the result
trained_params = train_pinn()
# Predict the value at x = 1.0
x_test = jnp.array([[1.0]])
# x_test[0] has shape (1,), one of the input shapes pinn_forward documents.
y_pred = pinn_forward(trained_params, x_test[0])
print(f"\n预测完成!")
print(f"纯函数 PINN 预测的 y(1) = {y_pred[0]:.6f}")
# Hard-coded reference value printed for comparison; with y(-1) = 0 and
# y' = 1 / (1 + x**4) it should equal the integral of f_target over [-1, 1].
print(f"标准数值解: 1.733946098729092")
| 1 | import jax |
| 2 | import jax.numpy as jnp |
| 3 | import optax |
| 4 | |
| 5 | # 1. 纯函数式:初始化网络参数 (不使用任何网络框架) |
| 6 | def init_pinn_params(rng_key, layers): |
| 7 | """手动初始化多层感知机的权重和偏置""" |
| 8 | params = [] |
| 9 | keys = jax.random.split(rng_key, len(layers) - 1) |
| 10 | |
| 11 | for i in range(len(layers) - 1): |
| 12 | in_dim = layers[i] |
| 13 | out_dim = layers[i+1] |
| 14 | |
| 15 | # 使用 Xavier/Glorot 初始化方法 |
| 16 | glorot_std = jnp.sqrt(2.0 / (in_dim + out_dim)) |
| 17 | |
| 18 | w = jax.random.normal(keys[i], (in_dim, out_dim)) * glorot_std |
| 19 | b = jnp.zeros((out_dim,)) |
| 20 | |
| 21 | params.append({'w': w, 'b': b}) |
| 22 | return params |
| 23 | |
| 24 | # 2. 纯函数式:正向传播函数 |
| 25 | def pinn_forward(params, x): |
| 26 | """输入 x 形状为 (1,) 或 (batch, 1),通过参数字典进行前向计算""" |
| 27 | activation = x |
| 28 | # 遍历隐藏层 |
| 29 | for layer in params[:-1]: |
| 30 | activation = jnp.tanh(jnp.dot(activation, layer['w']) + layer['b']) |
| 31 | # 输出层(不加激活函数) |
| 32 | final_layer = params[-1] |
| 33 | return jnp.dot(activation, final_layer['w']) + final_layer['b'] |
| 34 | |
| 35 | # 3. 目标物理方程 f(x) |
| 36 | def f_target(x): |
| 37 | return 1.0 / (1.0 + x**4) |
| 38 | |
| 39 | # 4. 物理残差计算 (单点) |
| 40 | def pde_residual_single(params, x_single): |
| 41 | # 用 lambda 包装,确保 jax.grad 明确知道是对输入 x_single 求导 |
| 42 | # 因为 pinn_forward 返回的是一个数组 [y],我们需要用 [0] 拿到标量 |
| 43 | y_grad = jax.grad(lambda x: pinn_forward(params, x)[0])(x_single) |
| 44 | return y_grad - f_target(x_single[0]) |
| 45 | |
| 46 | # 使用 vmap 自动将单点残差扩展到批量点 |
| 47 | pde_residual_batch = jax.vmap(pde_residual_single, in_axes=(None, 0)) |
| 48 | |
| 49 | # 5. 损失函数 |
| 50 | def loss_fn(params, x_collocation, x_bc, y_bc): |
| 51 | # A. PDE 残差损失 |
| 52 | preds_grad = pde_residual_batch(params, x_collocation) |
| 53 | loss_pde = jnp.mean(jnp.square(preds_grad)) |
| 54 | |
| 55 | # B. 边界条件损失 |
| 56 | # 因为 x_bc 是批量的 (这里只有1个点),我们直接用 jax.vmap 处理 forward |
| 57 | preds_bc = jax.vmap(pinn_forward, in_axes=(None, 0))(params, x_bc) |
| 58 | loss_bc = jnp.mean(jnp.square(preds_bc - y_bc)) |
| 59 | |
| 60 | return loss_pde + 2.0 * loss_bc |
| 61 | |
| 62 | # 6. 训练主流程 |
| 63 | def train_pinn(epochs=3000, lr=3e-3): |
| 64 | # 架构:输入1维 -> 两个32维的隐藏层 -> 输出1维 |
| 65 | layer_sizes = [1, 32, 32, 1] |
| 66 | |
| 67 | # 初始化参数 |
| 68 | rng_key = jax.random.PRNGKey(42) |
| 69 | params = init_pinn_params(rng_key, layer_sizes) |
| 70 | |
| 71 | # 初始化 optax 优化器 |
| 72 | optimizer = optax.adam(lr) |
| 73 | opt_state = optimizer.init(params) |
| 74 | |
| 75 | # 准备数据 |
| 76 | x_bc = jnp.array([[-1.0]]) |
| 77 | y_bc = jnp.array([[0.0]]) |
| 78 | x_collocation = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1) |
| 79 | |
| 80 | # 纯函数式的单步训练,用 JIT 编译成静态图 |
| 81 | @jax.jit |
| 82 | def train_step(params, opt_state, x_coll): |
| 83 | loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc) |
| 84 | updates, opt_state = optimizer.update(grads, opt_state) |
| 85 | params = optax.apply_updates(params, updates) |
| 86 | return params, opt_state, loss_val |
| 87 | |
| 88 | print("开始纯函数式 PINN 训练...") |
| 89 | for epoch in range(epochs + 1): |
| 90 | params, opt_state, loss_val = train_step(params, opt_state, x_collocation) |
| 91 | |
| 92 | if epoch % 500 == 0: |
| 93 | print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}") |
| 94 | |
| 95 | return params |
| 96 | |
| 97 | # 运行并验证 |
| 98 | trained_params = train_pinn() |
| 99 | |
| 100 | # 预测 x = 1.0 处的值 |
| 101 | x_test = jnp.array([[1.0]]) |
| 102 | y_pred = pinn_forward(trained_params, x_test[0]) |
| 103 | print(f"\n预测完成!") |
| 104 | print(f"纯函数 PINN 预测的 y(1) = {y_pred[0]:.6f}") |
| 105 | print(f"标准数值解: 1.733946098729092") |