# GK-Transformer (Galerkin Transformer)
The Galerkin Transformer is an attention-based operator learner that replaces softmax attention with a softmax-free, projection-based formulation interpreted as a Galerkin-type projection.
Reference Paper: Cao, NeurIPS 2021.
```bibtex
@article{cao2021choose,
  title={Choose a transformer: {F}ourier or {G}alerkin},
  author={Cao, Shuhao},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={24924--24940},
  year={2021}
}
```
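To make the softmax-free idea concrete, the sketch below implements a minimal Galerkin-type attention layer in PyTorch: keys and values are layer-normalized per head and contracted first (`K^T V`), so the cost is linear rather than quadratic in sequence length. This is a simplified stand-in under our own naming, not the RealPDEBench implementation.

```python
import torch
import torch.nn as nn

class GalerkinAttention(nn.Module):
    """Minimal softmax-free 'Galerkin-type' attention sketch:
    layer-normalize K and V per head, contract K^T V first, then
    multiply by Q. Linear in sequence length n."""

    def __init__(self, dim: int, n_head: int = 4):
        super().__init__()
        assert dim % n_head == 0
        self.n_head, self.d_k = n_head, dim // n_head
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        # Per-head normalization of K and V replaces the softmax.
        self.norm_k = nn.LayerNorm(self.d_k)
        self.norm_v = nn.LayerNorm(self.d_k)
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, n, _ = x.shape

        def split(t):  # [B, n, dim] -> [B, n_head, n, d_k]
            return t.view(B, n, self.n_head, self.d_k).transpose(1, 2)

        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        k, v = self.norm_k(k), self.norm_v(v)
        # K^T V is [B, n_head, d_k, d_k]; the 1/n factor plays the role
        # of a quadrature weight in the integral interpretation.
        kv = torch.einsum("bhnd,bhne->bhde", k, v) / n
        z = torch.einsum("bhnd,bhde->bhne", q, kv)
        return self.out(z.transpose(1, 2).reshape(B, n, -1))

# Usage: y = GalerkinAttention(64, n_head=4)(torch.randn(2, 1024, 64))
```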
## RealPDEBench interface
- Input: `x` with shape `[B, T_in, H, W, C_in]`
- Output: predictions with shape `[B, T_out, H, W, C_out]`
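For concreteness, the snippet below only pins down these tensor shapes; the dimension sizes are made up for illustration, and the real values come from the dataset config.

```python
import torch

# Illustrative sizes only.
B, T_in, T_out, H, W, C_in, C_out = 2, 4, 8, 64, 64, 3, 3

x = torch.randn(B, T_in, H, W, C_in)       # model input
# y = model(x)                             # predictions
# assert y.shape == (B, T_out, H, W, C_out)
assert T_out % T_in == 0                   # multi-step assumption (see notes below)
```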
## Implementation notes (RealPDEBench)
- If no coordinate grid is provided, the model constructs a normalized `(t, x, y)` grid internally (see the sketch after this list).
- The regressor maps latent representations back to field space and then reshapes to the benchmark time dimension.
- Multi-step outputs assume \(T_\text{out}\) is a multiple of \(T_\text{in}\).
- Training uses an MSE-style objective on predicted fields.
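A minimal sketch of such a normalized coordinate grid, assuming coordinates in \([0, 1]\) along each axis; the exact convention (axis ordering, range) used internally may differ.

```python
import torch

def make_grid(t_steps: int, h: int, w: int) -> torch.Tensor:
    """Normalized (t, x, y) grid in [0, 1]^3, shaped [t_steps, h, w, 3].
    Assumption: the model builds something equivalent when no grid is given."""
    t = torch.linspace(0, 1, t_steps)
    x = torch.linspace(0, 1, h)
    y = torch.linspace(0, 1, w)
    tt, xx, yy = torch.meshgrid(t, x, y, indexing="ij")
    return torch.stack([tt, xx, yy], dim=-1)

grid = make_grid(4, 64, 64)
assert grid.shape == (4, 64, 64, 3)
```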
## GK-Transformer specific YAML config
In RealPDEBench, GK-Transformer is implemented as the Galerkin Transformer baseline and is enabled by:
```yaml
model_name: "galerkin_transformer"
```
### Config files
- Cylinder: `realpdebench/configs/cylinder/galerkin_transformer.yaml`
- Controlled Cylinder: `realpdebench/configs/controlled_cylinder/galerkin_transformer.yaml`
- FSI: `realpdebench/configs/fsi/galerkin_transformer.yaml`
- Foil: `realpdebench/configs/foil/galerkin_transformer.yaml`
- Combustion: `realpdebench/configs/combustion/galerkin_transformer.yaml`
### Model-specific keys
These keys are passed as `**kwargs` into `realpdebench.model.galerkin_transformer.GalerkinTransformer3d`.
The wrapper also auto-fills `shape_in`, `shape_out`, `node_feats`, and `n_targets` from the dataset, as sketched below.
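The following is a hypothetical sketch of that wiring, only to make the division of responsibilities concrete; `cfg` and the `dataset` attribute names are invented for illustration and are not part of the documented API.

```python
from realpdebench.model.galerkin_transformer import GalerkinTransformer3d

# Hypothetical wiring (cfg/dataset names invented): dataset-derived shapes
# are auto-filled by the wrapper; everything else comes from the YAML
# model kwargs documented below.
model_kwargs = dict(cfg["model"])            # keys listed below
model_kwargs.update(
    shape_in=dataset.shape_in,               # e.g. (T_in, H, W)
    shape_out=dataset.shape_out,             # e.g. (T_out, H, W)
    node_feats=dataset.n_channels_in,        # input field channels
    n_targets=dataset.n_channels_out,        # output field channels
)
model = GalerkinTransformer3d(**model_kwargs)
```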
- `pos_dim` (int): Dimension of the positional encoding input to attention (when `pos` is provided / enabled).
- `n_hidden` (int): Hidden feature dimension of tokens inside the transformer.
- `num_feat_layers` (int): Number of graph feature-extractor layers when `feat_extract_type` is enabled (e.g., GCN/GAT).
- `num_encoder_layers` (int): Number of transformer encoder layers.
- `n_head` (int): Number of attention heads.
- `dim_feedforward` (int): Feed-forward hidden dimension inside each encoder layer (defaults to \(2 \times n_\text{hidden}\) if unset).
- `feat_extract_type` (str/null): Optional feature extractor type before attention (commonly `null`, `"gcn"`, or `"gat"`).
- `attention_type` (str): Attention kernel type. Supported families include `"fourier"`, `"integral"`, `"local"`, `"global"`, `"cosine"`, `"galerkin"`, `"linear"`, `"softmax"`, and `"official"`.
- `xavier_init` (float): Xavier/Glorot init gain for attention projections (used in some attention types).
- `diagonal_weight` (float): Diagonal initialization weight used by some attention parameterizations.
- `symmetric_init` (bool): Whether to use symmetric initialization in attention (implementation-dependent).
- `layer_norm` (bool): Whether to apply layer normalization in encoder blocks.
- `attn_norm` (bool): Whether to normalize inside attention (often enabled when `layer_norm` is disabled).
- `norm_eps` (float): Epsilon for normalization layers (numerical stability).
- `batch_norm` (bool): Whether to use batch normalization in feed-forward blocks (implementation-dependent).
- `return_attn_weight` (bool): Whether to return/store attention weights (mainly for debugging/analysis).
- `return_latent` (bool): Whether to store intermediate latent states from encoder layers.
- `decoder_type` (str): Decoder/regressor type. In the current wrapper, common values include `"ifft2"` (spectral regressor) and `"pointwise"`.
- `spacial_dim` (int): Spatial coordinate dimension used when `spacial_fc` is enabled (1, 2, or 3).
- `spacial_fc` (bool): Whether to concatenate spatial coordinates into the feature MLPs (coordinate-aware projection).
- `upsample_mode` (str): Upsampling mode used by the optional `UpScaler` (e.g., `"interp"`).
- `downsample_mode` (str): Downsampling mode used by the optional `DownScaler` (e.g., `"interp"`).
- `freq_dim` (int): Frequency-domain hidden dimension used by the spectral regressor (`decoder_type: "ifft2"`).
- `boundary_condition` (str/null): Boundary condition handling mode (e.g., `None`/`"dirichlet"`), used by some regressor variants.
- `num_regressor_layers` (int): Number of layers in the regressor/decoder head.
- `fourier_modes_x` (int): Number of Fourier modes in the x dimension for the spectral regressor.
- `fourier_modes_y` (int): Number of Fourier modes in the y dimension for the spectral regressor.
- `fourier_modes_t` (int): Number of Fourier modes in the time dimension for the spectral regressor.
- `regressor_activation` (str): Activation function used in the regressor head (e.g., `"silu"`).
- `downscaler_activation` (str): Activation function used in the downscaler (if enabled).
- `upscaler_activation` (str): Activation function used in the upscaler (if enabled).
- `last_activation` (bool): Whether to apply an activation at the last stage of the regressor (implementation-dependent).
- `dropout` (float): Dropout probability used in attention blocks (may be overridden per sub-module).
- `downscaler_dropout` (float): Dropout probability in the downscaler (if enabled).
- `upscaler_dropout` (float): Dropout probability in the upscaler (if enabled).
- `ffn_dropout` (float): Dropout probability in the feed-forward sub-layer.
- `encoder_dropout` (float): Dropout probability used in encoder layers.
- `decoder_dropout` (float): Dropout probability used in the decoder/regressor.
- `debug` (bool): Debug flag (enables extra checks / verbose behavior in some components).
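To tie the keys together, here is a hypothetical config excerpt in the style of the files listed above. The values are illustrative rather than tuned defaults, and the exact nesting of the model-kwargs section may differ from the shipped YAML files.

```yaml
# Hypothetical excerpt from a galerkin_transformer.yaml; values illustrative.
model_name: "galerkin_transformer"
model:
  n_hidden: 128
  n_head: 4
  num_encoder_layers: 4
  attention_type: "galerkin"
  feat_extract_type: null
  layer_norm: true
  attn_norm: false
  decoder_type: "ifft2"
  freq_dim: 48
  fourier_modes_x: 12
  fourier_modes_y: 12
  fourier_modes_t: 4
  regressor_activation: "silu"
  dropout: 0.0
```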