2D Acoustic FWI on Marmousi with JAX
Source file:
examples/FWI/2d/acoustic/jax/fwi_marmousi.py
What This Example Does
This example runs acoustic full-waveform inversion with the JAX propagator.
The script:
- loads a true velocity model and a smooth initial model
- builds a PropJax(Acoustic(...)) solver
- generates observed data from the true model
- inverts the initial model by matching synthetic and observed gathers
Main Components
The solver is built from:
- equation: Acoustic(..., backend="jax")
- propagator: PropJax(...)
- wave: a Ricker wavelet
- sources: regularly sampled source coordinates
- receivers: regularly sampled receiver coordinates
- models: the velocity model vp
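The wave component is a Ricker wavelet. The script's own wavelet routine is not reproduced on this page; the following is a minimal NumPy sketch of the standard Ricker formula w(t) = (1 - 2π²f0²(t - t0)²)·exp(-π²f0²(t - t0)²). The function name ricker, the peak frequency f0, and the delay t0 are illustrative choices, not values taken from fwi_marmousi.py.

```python
from typing import Optional

import numpy as np


def ricker(nt: int, dt: float, f0: float, t0: Optional[float] = None) -> np.ndarray:
    """Return a Ricker wavelet sampled at nt points with time step dt.

    f0 is the peak frequency in Hz and t0 the time of the peak in seconds.
    Both are illustrative parameters, not settings from the example script.
    """
    if t0 is None:
        # Common choice: delay by one dominant period so the wavelet starts near zero.
        t0 = 1.0 / f0
    t = np.arange(nt) * dt
    arg = (np.pi * f0 * (t - t0)) ** 2
    return (1.0 - 2.0 * arg) * np.exp(-arg)


# Hypothetical sampling: 2000 steps of 1 ms and a 5 Hz peak frequency.
wavelet = ricker(nt=2000, dt=0.001, f0=5.0)
```

With these values the peak (amplitude 1) falls at sample 200, that is at t0 = 0.2 s.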
Prepare the Marmousi Model Files
This example reads the same prepared Marmousi acoustic files as the Torch FWI example:
- examples/models/marmousi/true.npy
- examples/models/marmousi/smooth.npy
Generate them before running the example:
python3 examples/models/marmousi/download_marmousi.py --extract
python3 examples/models/marmousi/extract_model_segy.py
python3 examples/models/marmousi/convert_segy_to_npy.py
python3 examples/models/marmousi/prepare_fwi_models.py \
--input examples/models/marmousi/npy/vp_1p25m.npy \
--source-dh 1.25 \
--target-dh 25.0 \
--radii 8,8 \
--passes 3
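Before launching the inversion, it can help to confirm that both files exist and describe the same grid. A small sketch, assuming only that the prepared files are plain 2D NumPy arrays (their axis order is not stated on this page); the helper name load_model_pair is hypothetical.

```python
from pathlib import Path

import numpy as np


def load_model_pair(true_path, init_path):
    """Load the true and initial velocity models and check they share one 2D grid."""
    true_path, init_path = Path(true_path), Path(init_path)
    for path in (true_path, init_path):
        if not path.is_file():
            raise FileNotFoundError(f"{path} is missing; run the preparation steps first")
    vp_true = np.load(true_path)
    vp_init = np.load(init_path)
    if vp_true.ndim != 2 or vp_true.shape != vp_init.shape:
        raise ValueError(f"shape mismatch: {vp_true.shape} vs {vp_init.shape}")
    return vp_true, vp_init


# Usage from the repository root:
# vp_true, vp_init = load_model_pair(
#     "examples/models/marmousi/true.npy",
#     "examples/models/marmousi/smooth.npy",
# )
```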
Entry Point
Run the example with:
python3 examples/FWI/2d/acoustic/jax/fwi_marmousi.py
The script keeps its runtime settings in one CONFIG dictionary.
Key Configuration
Important entries include:
- nt, dt: temporal sampling
- dh: spatial sampling
- spatial_order: finite-difference order
- abcn: absorbing boundary width
- src_step, rec_step: acquisition sampling in the x direction
- true_model, init_model: .npy files loaded from examples/models/
- epochs, batchsize, lr: inversion hyperparameters
- use_ckpt: whether JAX chunk rematerialization is enabled
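For orientation, a CONFIG dictionary with these keys might look like the sketch below. Every value is a placeholder chosen for illustration, not the script's actual setting; srcz, recz, and free_surface are included only because the geometry and propagator sections refer to them. The last lines compute the Courant number C = v_max·dt/dh as a quick sanity check on dt and dh. The maximum velocity is an assumption, and for explicit schemes the admissible C shrinks as the finite-difference order grows, so the real limit for a given spatial_order should be checked against the solver.

```python
# Placeholder values only; consult fwi_marmousi.py for the real settings.
CONFIG = {
    "nt": 2000,              # number of time steps
    "dt": 0.002,             # time step in seconds
    "dh": 25.0,              # grid spacing in meters; should match --target-dh used above
    "spatial_order": 8,      # finite-difference order
    "abcn": 50,              # absorbing boundary width
    "free_surface": False,
    "src_step": 10,          # shot spacing along x, in grid points
    "rec_step": 1,           # receiver spacing along x, in grid points
    "srcz": 1,               # hypothetical source depth index
    "recz": 1,               # hypothetical receiver depth index
    "true_model": "marmousi/true.npy",
    "init_model": "marmousi/smooth.npy",
    "epochs": 100,
    "batchsize": 4,
    "lr": 10.0,
    "use_ckpt": True,
}

vmax = 5500.0  # assumed maximum velocity in m/s; use vp_true.max() in practice
courant = vmax * CONFIG["dt"] / CONFIG["dh"]
print(f"Courant number: {courant:.3f}")  # prints "Courant number: 0.440"
```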
Solver Setup
The equation is created with the JAX backend:
equation = Acoustic(
spatial_order=cfg["spatial_order"],
backend="jax",
)
Shared propagator arguments are collected first:
prop_kwargs = dict(
shape=shape,
dev=None,
dh=cfg["dh"],
dt=cfg["dt"],
source_type=["h1"],
receiver_type=["h1"],
abcn=cfg["abcn"],
free_surface=cfg["free_surface"],
use_ckpt=cfg["use_ckpt"],
pml_type="cpmlr",
)
Then the solver is created as:
solver = PropJax(equation, **prop_kwargs)
Geometry
The example uses a fixed-depth surface acquisition:
- sources are placed every src_step grid points
- receivers are placed every rec_step grid points
- all sources use the same source depth srcz
- all receivers use the same receiver depth recz
The final array shapes are:
- sources: (nshots, 2)
- receivers: (nshots, nreceivers, 2)
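The geometry can be sketched with NumPy as follows. The (x, z) column order, the use of grid indices rather than meters, the starting offset of 0, and the helper name surface_geometry are assumptions for illustration; the example script may order or scale its coordinates differently.

```python
import numpy as np


def surface_geometry(nx, src_step, rec_step, srcz, recz):
    """Build fixed-depth surface acquisition arrays in grid-index units.

    Returns sources with shape (nshots, 2) and receivers with shape
    (nshots, nreceivers, 2); each coordinate pair is assumed to be (x, z).
    """
    src_x = np.arange(0, nx, src_step)
    rec_x = np.arange(0, nx, rec_step)
    nshots, nrec = src_x.size, rec_x.size

    sources = np.stack([src_x, np.full(nshots, srcz)], axis=-1)

    # Every shot records on the same receiver line, so the line is repeated per shot.
    line = np.stack([rec_x, np.full(nrec, recz)], axis=-1)  # (nreceivers, 2)
    receivers = np.broadcast_to(line, (nshots, nrec, 2)).copy()
    return sources, receivers


sources, receivers = surface_geometry(nx=300, src_step=10, rec_step=1, srcz=1, recz=1)
print(sources.shape, receivers.shape)  # (30, 2) (30, 300, 2)
```

Repeating the receiver line per shot keeps the (nshots, nreceivers, 2) layout even though the spread is fixed; a moving spread would only change how each shot's row is filled.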
Inversion Workflow
Observed data is generated first from the true model, then the inversion updates
the smooth model with optax.adam.
At each iteration, the script:
- selects a random subset of shots
- computes synthetic data
- evaluates the L2 data-misfit loss
- gets gradients with jax.value_and_grad
- updates the model with optax
Outputs
The script creates an output directory under examples/ and saves:
- ricker.png
- observed_data.png
- loss.png
- epoch_XXXX.png: includes the true model, the current inverted model, and the current gradient
The output directory is:
acoustic_fwi_jax
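When inspecting a run, it is often handy to open the most recent epoch figure. A minimal sketch with a hypothetical helper latest_epoch_figure; it parses the number in epoch_XXXX.png instead of relying on zero padding, since the exact naming and save cadence are assumptions here.

```python
import re
from pathlib import Path

_EPOCH_RE = re.compile(r"epoch_(\d+)\.png")


def latest_epoch_figure(output_dir):
    """Return the epoch_*.png with the highest epoch number, or None if none exist."""
    best = None
    for path in Path(output_dir).glob("epoch_*.png"):
        match = _EPOCH_RE.fullmatch(path.name)
        if match is None:
            continue  # skip files that only look similar, e.g. epoch_final.png
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return None if best is None else best[1]


# Usage (replace with the actual location of acoustic_fwi_jax):
# print(latest_epoch_figure("path/to/acoustic_fwi_jax"))
```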
Running the Example
Step 1. Prepare the Marmousi .npy files listed above if they do not already
exist.
Step 2. Run the JAX Marmousi script from the repository root.
python3 examples/FWI/2d/acoustic/jax/fwi_marmousi.py
Step 3. Check acoustic_fwi_jax for the saved wavelet, observed data, loss, and epoch figures.
Notes:
- the script disables JAX GPU memory preallocation with XLA_PYTHON_CLIENT_PREALLOCATE=false
- use_ckpt controls JAX rematerialization, which recomputes intermediate values during the backward pass to reduce memory usage
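Both notes correspond to general JAX mechanisms. The sketch below shows them on a toy time-stepping loop. The environment variable only takes effect if it is set before JAX initializes its backend, so the usual pattern is to set it before importing jax. jax.checkpoint wraps a chunk of steps so that its intermediate states are recomputed during the backward pass instead of being stored. The toy update rule, the chunk size, and the use of jax.lax.scan are assumptions; how PropJax implements use_ckpt internally is not described on this page.

```python
import os

# Set before JAX initializes its backend; afterwards it has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported after the environment variable on purpose
import jax.numpy as jnp

nsteps, chunk = 64, 8  # nsteps must be a multiple of chunk in this sketch


def one_step(u, c):
    """Toy explicit update standing in for one wavefield time step."""
    return u + 0.1 * c * jnp.tanh(jnp.roll(u, 1) - u)


def run_chunk(u, c):
    """Advance the state by `chunk` steps."""
    def body(u, _):
        return one_step(u, c), None
    u, _ = jax.lax.scan(body, u, None, length=chunk)
    return u


def simulate(c, u0, use_ckpt):
    """Run all chunks; with use_ckpt, each chunk is rematerialized in the backward pass."""
    chunk_fn = jax.checkpoint(run_chunk) if use_ckpt else run_chunk

    def outer(u, _):
        return chunk_fn(u, c), None

    u, _ = jax.lax.scan(outer, u0, None, length=nsteps // chunk)
    return jnp.sum(u ** 2)


u0 = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 32))
c = jnp.ones(32)
g_plain = jax.grad(lambda c_: simulate(c_, u0, False))(c)
g_ckpt = jax.grad(lambda c_: simulate(c_, u0, True))(c)
```

The two gradients agree; checkpointing changes only what is stored versus recomputed, trading extra forward computation for lower peak memory.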