PropJax
class PropJax(
    equation,
    shape,
    source_type=[],
    receiver_type=[],
    abcn=50,
    free_surface=False,
    dh=10.0,
    dt=0.002,
    dev=None,
    use_ckpt=True,
    ckpt_chunks=100,
    pml_type="spml",
)
Implementation: src/sweep/propagator/jax.py
JAX propagator built around jax.lax.scan and chunk-style rematerialization.
Note
PropJax shares the same solver concepts as the PyTorch backend, but its
runtime behavior is shaped by JAX transforms rather than Python-side loops.
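The chunked time loop can be pictured with the plain-Python sketch below. It is not the code in src/sweep/propagator/jax.py: `scan` is a stand-in that mimics the `(carry, x) -> (carry, y)` contract of `jax.lax.scan`, and `chunked_scan` assumes that `ckpt_chunks` is the number of time steps per chunk. In real JAX code the per-chunk inner scan would typically be wrapped in `jax.checkpoint`, so that only the carry at each chunk boundary is kept for the backward pass and the steps inside a chunk are recomputed.

```python
import math


def scan(step, carry, xs):
    """Plain-Python stand-in for jax.lax.scan: thread `carry` through `step`
    for every element of `xs` and collect the per-step outputs."""
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def chunked_scan(step, carry, xs, chunk_size):
    """Run the same loop as `scan`, but in chunks of `chunk_size` steps.

    `boundaries` holds the carry at the start of each chunk. Under
    rematerialization these are the only states that would be stored for the
    backward pass; everything inside a chunk would be recomputed.
    """
    boundaries = []
    ys = []
    for start in range(0, len(xs), chunk_size):
        boundaries.append(carry)
        carry, chunk_ys = scan(step, carry, xs[start:start + chunk_size])
        ys.extend(chunk_ys)
    return carry, ys, boundaries


# Toy "time step": the state accumulates the input and is recorded every step.
def toy_step(state, x):
    new_state = state + x
    return new_state, new_state


nt, chunk_size = 10, 4
xs = list(range(1, nt + 1))
final, ys, boundaries = chunked_scan(toy_step, 0, xs, chunk_size)
print(final, len(boundaries), math.ceil(nt / chunk_size))  # 55 3 3
```

In this model the chunked loop yields exactly the same outputs as the plain loop; what changes is memory: roughly one stored carry per chunk (ceil(nt / ckpt_chunks) of them) plus one chunk of recomputed steps, instead of all nt intermediate states.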
Parameters
- equation (equation instance): The equation instance to be stepped in JAX.
- shape (tuple[int, ...]): Physical model shape before absorbing boundaries are added. Use (nz, nx) in 2D and (nz, ny, nx) in 3D.
- source_type (list[str], optional): Wavefield names used for source injection. These names must exist in equation.wavefields; PropJax does not auto-fill defaults.
- receiver_type (list[str], optional): Wavefield names sampled at receiver locations. These must also match equation.wavefields.
- abcn (int, optional): Absorbing boundary width.
- free_surface (bool, optional): Whether the top boundary is treated as a free surface. This affects the internal coordinate offsets applied before source injection and receiver sampling.
- dh (float, optional): Grid spacing, given as a single scalar.
- dt (float, optional): Time step in seconds.
- dev (device or context, optional): Stored device/context argument. Actual JAX execution placement is still driven by JAX arrays and transforms.
- use_ckpt (bool, optional): Enables chunk-based rematerialization in the scanned time loop.
- ckpt_chunks (int, optional): Chunk size used when use_ckpt=True.
- pml_type (str, optional): PML implementation passed into the equation setup.
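Because source_type and receiver_type are not auto-filled and must match equation.wavefields, a small pre-flight check can catch typos before a long run. The function below is a hypothetical helper, not part of PropJax: the wavefield names ("p", "vx", "vz") are invented for illustration, and rejecting an empty list is this sketch's own policy. It also checks that shape has the documented 2D or 3D form.

```python
def check_propagator_args(shape, source_type, receiver_type, wavefields):
    """Hypothetical pre-flight check mirroring the documented PropJax constraints.

    Returns the spatial dimension (2 or 3) inferred from `shape`.
    """
    # shape is (nz, nx) in 2D or (nz, ny, nx) in 3D, before absorbing boundaries.
    if len(shape) not in (2, 3):
        raise ValueError(f"shape must be (nz, nx) or (nz, ny, nx), got {shape!r}")

    available = set(wavefields)
    for role, names in (("source_type", source_type), ("receiver_type", receiver_type)):
        # PropJax does not auto-fill defaults; this sketch rejects empty lists.
        if not names:
            raise ValueError(f"{role} is empty; list wavefield names explicitly")
        unknown = [name for name in names if name not in available]
        if unknown:
            raise ValueError(f"{role} has names not in equation.wavefields: {unknown}")
    return len(shape)


# Illustrative wavefield names; real names come from equation.wavefields.
wavefields = ["p", "vx", "vz"]
dim = check_propagator_args((200, 400), ["p"], ["p"], wavefields)
print(dim)  # 2

try:
    check_propagator_args((200, 400), ["pressure"], ["p"], wavefields)
except ValueError as err:
    print(err)  # source_type has names not in equation.wavefields: ['pressure']
```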
Forward Parameters
forward(
    wavelet,
    sources,
    receivers,
    models=None,
    source_encoding=False,
    return_wavefield=False,
    adj=False,
    **kwargs,
)
- wavelet (array-like): Source time function. Common layouts are (nt,), (B, nt), (B, nsrc, nt), and the source-encoding super-shot layout (1, nsrc, nt).
- sources (array-like): Source coordinates. Common layouts are (B, dim), (B, nsrc, dim), and (1, nsrc, dim) for a source-encoding super-shot.
- receivers (array-like): Receiver coordinates. Typical shape: (B, nreceivers, dim), including (1, nreceivers, dim) for a source-encoding super-shot.
- models (list of arrays, optional): Model arrays, in the exact order required by equation.models.
- source_encoding (bool, optional): If True, collapses shots into a single encoded batch. PropJax also auto-detects source encoding when wavelet, sources, and receivers have the shapes (1, nsrc, nt), (1, nsrc, dim), and (1, nreceivers, dim), respectively.
- return_wavefield (bool, optional): If True, returns an auxiliary wavefield output in addition to the recorded data.
- adj (bool, optional): Adjoint-style forward switch.
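In the shape descriptions above:
- B is the runtime batch size
- nsrc is the number of sources inside one batch element
- nreceivers is the number of receivers inside one batch element
- nt is the number of time samples in the wavelet
- dim is 2 in 2D and 3 in 3D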
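The documented auto-detection rule for source encoding can be restated as a test on plain shape tuples. The sketch below is a hypothetical reading of that rule, not the check in jax.py, which may apply additional conditions. It requires a leading batch dimension of 1 on all three inputs, the same nsrc in the wavelet and the source coordinates, and a coordinate dimension that matches dim.

```python
def looks_like_super_shot(wavelet_shape, sources_shape, receivers_shape, dim):
    """Hypothetical restatement of the documented auto-detection rule:
    wavelet (1, nsrc, nt), sources (1, nsrc, dim), receivers (1, nreceivers, dim)."""
    if not (len(wavelet_shape) == len(sources_shape) == len(receivers_shape) == 3):
        return False
    if wavelet_shape[0] != 1 or sources_shape[0] != 1 or receivers_shape[0] != 1:
        return False
    # The wavelet and the source coordinates must describe the same nsrc sources.
    if wavelet_shape[1] != sources_shape[1]:
        return False
    return sources_shape[2] == dim and receivers_shape[2] == dim


# 2D super-shot: 8 encoded sources, 1000 time samples, 120 receivers.
print(looks_like_super_shot((1, 8, 1000), (1, 8, 2), (1, 120, 2), dim=2))  # True
# Ordinary batch of 4 shots with one source each.
print(looks_like_super_shot((4, 1000), (4, 2), (4, 120, 2), dim=2))  # False
```

One consequence of the documented layouts: an ordinary batch with B = 1 and several sources per shot, (1, nsrc, nt), has exactly the super-shot layout, so shapes alone cannot tell the two cases apart.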
Return Value
- default: record
- if return_wavefield=True: (record, snapshots)
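Since the return type depends on return_wavefield, code that supports both modes has to unpack the result differently. The wrapper below is a hypothetical convenience helper, not part of PropJax; it only assumes the forward(wavelet, sources, receivers, models=..., return_wavefield=...) call documented above. StubProp is an invented stand-in for a real PropJax instance so the snippet runs without JAX.

```python
def run_forward(prop, wavelet, sources, receivers, models=None, return_wavefield=False):
    """Always return (record, snapshots); snapshots is None unless requested."""
    out = prop.forward(
        wavelet, sources, receivers, models=models, return_wavefield=return_wavefield
    )
    if return_wavefield:
        record, snapshots = out  # documented form: (record, snapshots)
        return record, snapshots
    return out, None  # documented default: record only


class StubProp:
    """Stand-in for PropJax that returns placeholder strings instead of arrays."""

    def forward(self, wavelet, sources, receivers, models=None,
                return_wavefield=False, **kwargs):
        record = "record"
        return (record, "snapshots") if return_wavefield else record


prop = StubProp()
print(run_forward(prop, None, None, None))  # ('record', None)
print(run_forward(prop, None, None, None, return_wavefield=True))  # ('record', 'snapshots')
```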