Multi-GPU FWI · DDP vs. a single device
How much does throwing more GPUs at a small FWI step actually buy you? This notebook compares:
- **1-GPU baseline** — the inner FWI loop everyone knows: `solver(...)` → loss → `loss.backward()` over `nshots` shots, sequentially on `cuda:0`.
- **`torch.distributed` (DDP) on N GPUs** — the same loop, but with `nshots / N` shots per rank. The per-rank gradients are all-reduced to a global sum at the end.
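Each DDP rank is its own Python interpreter. Communication happens only at the end, when the gradient (and the scalar loss) are all-reduced. For shot-parallel FWI, each rank's work is independent, so that reduction is the only real synchronisation point and scaling is close to linear. The driver calls `torch.distributed` collectives directly rather than wrapping anything in `DistributedDataParallel`; "DDP" here refers to the data-parallel pattern.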
import os, time, subprocess, textwrap
import numpy as np
import torch
from sweep.equations import Acoustic
from sweep.propagator.torch import PropTorch
from sweep.signal import ricker
# Count the visible GPUs up front: the baseline needs at least one, and the
# DDP section launches one rank per visible device.
ngpu = torch.cuda.device_count()
print(f'visible CUDA devices: {ngpu}')
if not ngpu:
    raise SystemExit('No CUDA device visible — this notebook needs at least one GPU.')
elif ngpu == 1:
    # Still runnable, but the "parallel" run degenerates to a single rank.
    print('NOTE: only 1 GPU visible — the DDP run below will use --nproc_per_node=1 and produce no speedup.')
visible CUDA devices: 4
1. Build a small FWI problem
Same shape / time step / source layout for both the 1-GPU baseline and the DDP run, so the speedup numbers compare like-for-like.
# Grid, time axis and source wavelet parameters, shared by both runs.
shape = (384, 512)                 # 384 z × 512 x cells
dh, dt, nt = 6.0, 0.8e-3, 2000     # 0.8 ms step × 2000 samples = 1.6 s record
freq, delay = 14.0, 0.08           # 14 Hz Ricker, delayed by 0.08 (same units as t)
nshots = 8                         # divisible by 4 → clean DDP split on a 4-GPU box

# "True" model: 1500 in the top half, 2500 in the bottom half.
# Starting model: homogeneous 1500.
vp_true_np = np.full(shape, 1500.0, dtype=np.float32)
vp_true_np[shape[0] // 2:] = 2500.0
vp_init_np = np.full_like(vp_true_np, 1500.0)

# Sources: nshots evenly spaced in x (20 cells in from each edge), at z index 2.
src_x = np.linspace(20, shape[1] - 21, nshots, dtype=np.int64)
sources = np.column_stack((src_x, np.full(nshots, 2, dtype=np.int64)))

# Receivers: 128 evenly spaced across the full width at z index 4, with the
# same spread repeated for every shot -> array of shape (nshots, 128, 2).
rec_x = np.linspace(0, shape[1] - 1, 128, dtype=np.int64)
one_shot_rec = np.column_stack((rec_x, np.full_like(rec_x, 4)))
receivers = np.tile(one_shot_rec, (nshots, 1, 1))

t = np.arange(nt, dtype=np.float32) * dt
wavelet = ricker(t - delay, f=freq).astype(np.float32)
print(f'{nshots} shots, {len(rec_x)} receivers/shot')
8 shots, 128 receivers/shot
2. Baseline — single-GPU shot loop
Run forward + backward sequentially on `cuda:0` for all 8 shots. This is the wall-clock number every other configuration is compared against. Before timing, we generate the observed data and run one untimed warm-up shot, so the measured loop doesn't include one-off CUDA start-up cost.
dev = torch.device('cuda:0')
solver = PropTorch(Acoustic(device=dev), shape=shape, dh=dh, dt=dt, dev=dev,
                   use_ckpt=False, impl='c')

# Untimed, gradient-free prelude: observed data for every shot from the true
# model, plus one warm-up shot on the starting model.
with torch.no_grad():
    obs_full = solver(wavelet, sources, receivers,
                      models=[torch.tensor(vp_true_np, device=dev)])
    _ = solver(wavelet, sources[:1], receivers[:1],
               models=[torch.tensor(vp_init_np, device=dev)])
torch.cuda.synchronize(dev)

# Timed region: one forward + backward per shot. backward() accumulates into
# vp_t.grad, so after the loop it holds the sum of all per-shot gradients.
vp_t = torch.tensor(vp_init_np, device=dev, requires_grad=True)
t0 = time.perf_counter()
single_loss = 0.0
for shot in range(nshots):
    window = slice(shot, shot + 1)
    residual = solver(wavelet, sources[window], receivers[window],
                      models=[vp_t]) - obs_full[window]
    loss = 0.5 * residual.pow(2).sum()
    loss.backward()
    single_loss += float(loss.detach())
torch.cuda.synchronize(dev)
elapsed_single = time.perf_counter() - t0

single_grad = vp_t.grad.detach().clone()
print(f'1 GPU: loss={single_loss:.4e} elapsed={elapsed_single:.3f} s')
1 GPU: loss=2.6533e+01 elapsed=0.665 s
3. DDP — same loop, N ranks, one GPU each
DDP requires a separate Python process per GPU. We do that the standard way:
- Write a small driver script to disk. It has the same shot loop as above, plus a `dist.all_reduce` on the gradient at the end.
- Launch it with `torchrun --nproc_per_node=N`.
- Parse the elapsed time and accumulated loss from the JSON line that rank 0 prints to stdout.
The script is intentionally short so the only thing that differs from the baseline is the process-level parallelism.
# Hand the workers the *exact* arrays the baseline used instead of re-deriving
# them in the driver. A hard-coded copy of the problem would silently drift if
# the setup cell were edited, and the comparison would no longer be like-for-like.
DDP_INPUTS = 'fwi_ddp_inputs.npz'   # the driver below loads this same path
np.savez(DDP_INPUTS,
         shape=np.asarray(shape, dtype=np.int64),
         dh=np.float64(dh), dt=np.float64(dt),
         vp_true=vp_true_np, vp_init=vp_init_np,
         sources=sources, receivers=receivers, wavelet=wavelet)

# Driver run once per rank by torchrun. Each rank owns one GPU
# (LOCAL_RANK) and a contiguous block of shots. Per-rank gradients and losses
# are summed with all_reduce, and rank 0 prints a one-line JSON summary.
DDP_SCRIPT = textwrap.dedent('''\
    import os, time, json
    import numpy as np
    import torch, torch.distributed as dist
    from sweep.equations import Acoustic
    from sweep.propagator.torch import PropTorch

    # Problem arrays written by the notebook (same data as the 1-GPU baseline).
    with np.load("fwi_ddp_inputs.npz") as inp:
        SHAPE = tuple(int(n) for n in inp["shape"])
        DH, DT = float(inp["dh"]), float(inp["dt"])
        vp_true, vp_init = inp["vp_true"], inp["vp_init"]
        sources, receivers = inp["sources"], inp["receivers"]
        wavelet = inp["wavelet"]
    NSHOTS = len(sources)

    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    local = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local)
    dev = torch.device(f"cuda:{local}")

    solver = PropTorch(Acoustic(device=dev), shape=SHAPE, dh=DH, dt=DT, dev=dev,
                       use_ckpt=False, impl="c")

    # Untimed prelude: observed data for all shots + one warm-up shot.
    with torch.no_grad():
        obs = solver(wavelet, sources, receivers,
                     models=[torch.tensor(vp_true, device=dev)])
        _ = solver(wavelet, sources[:1], receivers[:1],
                   models=[torch.tensor(vp_init, device=dev)])
    torch.cuda.synchronize(dev); dist.barrier()

    # Contiguous block of shots per rank; if world > NSHOTS some ranks get none.
    my_shots = np.array_split(np.arange(NSHOTS), world)[rank].tolist()
    vp = torch.tensor(vp_init, device=dev, requires_grad=True)

    t0 = time.perf_counter()
    local_loss = 0.0
    for s in my_shots:
        syn = solver(wavelet, sources[s:s+1], receivers[s:s+1], models=[vp])
        loss = 0.5 * (syn - obs[s:s+1]).pow(2).sum()
        loss.backward()
        local_loss += float(loss.detach())
    # A rank with no shots never called backward(), so vp.grad is still None.
    # It must still join the collective, contributing zeros.
    if vp.grad is None:
        vp.grad = torch.zeros_like(vp)
    torch.cuda.synchronize(dev)
    dist.all_reduce(vp.grad, op=dist.ReduceOp.SUM)
    # float64 so the reduced loss keeps the precision of the Python-float sums.
    loss_t = torch.tensor([local_loss], device=dev, dtype=torch.float64)
    dist.all_reduce(loss_t, op=dist.ReduceOp.SUM)
    dist.barrier()
    elapsed = time.perf_counter() - t0

    if rank == 0:
        print(json.dumps({"world": world,
                          "loss": float(loss_t.item()),
                          "elapsed": elapsed,
                          "grad_abs_sum": float(vp.grad.abs().sum().item())}))
    dist.destroy_process_group()
''')

with open('fwi_ddp.py', 'w', encoding='utf-8') as f:
    f.write(DDP_SCRIPT)
print(f'DDP driver written to ./fwi_ddp.py ({len(DDP_SCRIPT)} bytes)')
print(f'problem inputs written to ./{DDP_INPUTS}')
DDP driver written to ./fwi_ddp.py (2290 bytes)
# Launch torchrun and parse the single JSON line rank 0 prints at the end.
# stdout and stderr are both captured. Launcher / process-group start-up chatter
# presumably goes to stderr, which is only shown if the launch fails.
import json as _json

# Ranks beyond the number of shots would have no work (only extra process
# start-up), so cap the world size at nshots.
nproc = max(1, min(torch.cuda.device_count(), nshots))
cmd = ['torchrun', '--standalone', f'--nproc_per_node={nproc}', 'fwi_ddp.py']
print('cmd:', ' '.join(cmd))
res = subprocess.run(cmd, capture_output=True, text=True)
if res.returncode != 0:
    print('STDERR:', res.stderr[-800:])
    raise RuntimeError(f'torchrun exited with {res.returncode}')

# The summary is the first stdout line that looks like a JSON object.
summary = None
for line in res.stdout.splitlines():
    line = line.strip()
    if line.startswith('{') and line.endswith('}'):
        summary = _json.loads(line)
        break
if summary is None:
    raise RuntimeError(f'no JSON summary in torchrun stdout:\n{res.stdout}')


def _rel_diff(reference, other):
    """Return |reference - other| / |reference|, guarding a zero reference."""
    return abs(reference - other) / max(abs(reference), 1e-30)


elapsed_ddp = summary['elapsed']
ddp_loss = summary['loss']
speedup = elapsed_single / elapsed_ddp
# Coarse gradient check: the driver reports sum(|g|) of the reduced gradient,
# so compare against the same statistic of the 1-GPU gradient.
single_grad_abs_sum = float(single_grad.abs().sum())
print(f'DDP world={summary["world"]}: loss={ddp_loss:.4e} elapsed={elapsed_ddp:.3f} s')
print(f'speedup vs 1 GPU: {speedup:.2f}×')
print(f'loss match (1 GPU vs DDP): rel diff {_rel_diff(single_loss, ddp_loss):.2e}')
print(f'grad match, sum|g| (1 GPU vs DDP): rel diff '
      f'{_rel_diff(single_grad_abs_sum, summary["grad_abs_sum"]):.2e}')
cmd: torchrun --standalone --nproc_per_node=4 fwi_ddp.py
DDP world=4: loss=2.6531e+01 elapsed=0.176 s
speedup vs 1 GPU: 3.79×
loss match (1 GPU vs DDP): rel diff 5.46e-05
Download this notebook — 12_multi_gpu.ipynb · or view on GitHub