import jax import jax.numpy as jnp import optax # 1. 纯函数式:初始化网络参数 (不使用任何网络框架) def init_pinn_params(rng_key, layers): """手动初始化多层感知机的权重和偏置""" params = [] keys = jax.random.split(rng_key, len(layers) - 1) for i in range(len(layers) - 1): in_dim = layers[i] out_dim = layers[i+1] # 使用 Xavier/Glorot 初始化方法 glorot_std = jnp.sqrt(2.0 / (in_dim + out_dim)) w = jax.random.normal(keys[i], (in_dim, out_dim)) * glorot_std b = jnp.zeros((out_dim,)) params.append({'w': w, 'b': b}) return params # 2. 纯函数式:正向传播函数 def pinn_forward(params, x): """输入 x 形状为 (1,) 或 (batch, 1),通过参数字典进行前向计算""" activation = x # 遍历隐藏层 for layer in params[:-1]: activation = jnp.tanh(jnp.dot(activation, layer['w']) + layer['b']) # 输出层(不加激活函数) final_layer = params[-1] return jnp.dot(activation, final_layer['w']) + final_layer['b'] # 3. 目标物理方程 f(x) def f_target(x): return 1.0 / (1.0 + x**4) # 4. 物理残差计算 (单点) def pde_residual_single(params, x_single): # 用 lambda 包装,确保 jax.grad 明确知道是对输入 x_single 求导 # 因为 pinn_forward 返回的是一个数组 [y],我们需要用 [0] 拿到标量 y_grad = jax.grad(lambda x: pinn_forward(params, x)[0])(x_single) return y_grad - f_target(x_single[0]) # 使用 vmap 自动将单点残差扩展到批量点 pde_residual_batch = jax.vmap(pde_residual_single, in_axes=(None, 0)) # 5. 损失函数 def loss_fn(params, x_collocation, x_bc, y_bc): # A. PDE 残差损失 preds_grad = pde_residual_batch(params, x_collocation) loss_pde = jnp.mean(jnp.square(preds_grad)) # B. 边界条件损失 # 因为 x_bc 是批量的 (这里只有1个点),我们直接用 jax.vmap 处理 forward preds_bc = jax.vmap(pinn_forward, in_axes=(None, 0))(params, x_bc) loss_bc = jnp.mean(jnp.square(preds_bc - y_bc)) return loss_pde + 2.0 * loss_bc # 6. 训练主流程 def train_pinn(epochs=3000, lr=3e-3): # 架构:输入1维 -> 两个32维的隐藏层 -> 输出1维 layer_sizes = [1, 32, 32, 1] # 初始化参数 rng_key = jax.random.PRNGKey(42) params = init_pinn_params(rng_key, layer_sizes) # 初始化 optax 优化器 optimizer = optax.adam(lr) opt_state = optimizer.init(params) # 准备数据 x_bc = jnp.array([[-1.0]]) y_bc = jnp.array([[0.0]]) x_collocation = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1) # 纯函数式的单步训练,用 JIT 编译成静态图 @jax.jit def train_step(params, opt_state, x_coll): loss_val, grads = jax.value_and_grad(loss_fn)(params, x_coll, x_bc, y_bc) updates, opt_state = optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) return params, opt_state, loss_val print("开始纯函数式 PINN 训练...") for epoch in range(epochs + 1): params, opt_state, loss_val = train_step(params, opt_state, x_collocation) if epoch % 500 == 0: print(f"Epoch {epoch:4d} | Loss: {loss_val:.6f}") return params # 运行并验证 trained_params = train_pinn() # 预测 x = 1.0 处的值 x_test = jnp.array([[1.0]]) y_pred = pinn_forward(trained_params, x_test[0]) print(f"\n预测完成!") print(f"纯函数 PINN 预测的 y(1) = {y_pred[0]:.6f}") print(f"标准数值解: 1.733946098729092")