3-mode Training: XPU-vs-CUDA Loss Regression Investigation
Context
An 8-XPU (single-node, 8 tiles) training run of OM_train_3modes.py reached:
- `loss_ang`: -0.50 to -0.57
- `loss_dist`: 0.7 to 0.8
- `loss_contrastive`: 4.8e-5 to 5.1e-5
- `loss_imgsim`: -0.10 to -0.15
Continuing the same checkpoint on 2 CUDA GPUs (same lr=1e-5, batchsize=3, same script) produced:
- `loss_ang`: -0.42 (worse)
- `loss_dist`: 2.1 (3× worse)
- `loss_contrastive`: 5.4e-5 (slightly worse)
- `loss_imgsim`: -0.21 (better)
The 2-mode script (`OM_train_2modes.py`, diffusion + registration only) converges fine on 2 CUDA GPUs with the same lr and batchsize, so the regression is specific to what the 3-mode script adds or changes relative to 2-mode.
Material edits in 3-mode vs 2-mode (non-cosmetic)
All line numbers refer to OM_train_3modes.py unless noted.
A. DDP wrapper uses `find_unused_parameters=True` (line 296)
- 2-mode: `DDP(Deformddpm, device_ids=[rank])`
- 3-mode: `DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)`
Added because the contrastive pass does not touch every parameter. Interacts badly with the contrastive side-channel backward in (F) below.
B. Diffusion and registration loss weights retuned (lines 93–95)
| constant | 2-mode | 3-mode | change |
|---|---|---|---|
| `LOSS_WEIGHTS_DIFF` (ang, dist, reg) | `[2.0, 2.0, 4.0]` | `[2.0, 1.0, 4.0]` | dist ÷ 2 |
| `LOSS_WEIGHTS_REGIST` (imgsim, imgmse, ddf) | `[1.0, 0.05, 128]` | `[1.0, 0.01, 100]` | imgmse ÷ 5, ddf × 0.78 |
Halving the diffusion dist weight makes loss_dist weakly constrained. Relaxing the registration DDF regulariser and imgmse lets the registration branch push harder on the shared UNet weights toward the imgsim minimum.
C. XPU-only per-epoch restart with rank-0 param broadcast (lines 831–839, `ddp_load_dict` lines 922–926)
3-mode adds an unconditional `sys.exit(EXIT_CODE_RESTART)` at the end of every epoch on XPU. The bash wrapper re-launches, `ddp_load_dict()` loads the checkpoint and does `dist.broadcast(param.data, src=0)` for every parameter, so every epoch on XPU starts with all ranks bit-identical.
On CUDA this block is skipped entirely (gated on `DEVICE_TYPE == 'xpu'`), so whatever cross-rank drift accumulates in the contrastive or NaN-skip paths is never wiped. This is why any cross-rank drift bug looks healthy on XPU and broken on CUDA.
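For reference, the resync that `ddp_load_dict()` performs is essentially the loop below (a minimal sketch; checkpoint loading is omitted and the helper name is ours, not the script's):

```python
# Minimal sketch of the rank-0 resync ddp_load_dict() does on XPU after each
# restart: rank 0's weights are copied in place onto every other rank, so all
# ranks start the next epoch bit-identical. Helper name is hypothetical.
import torch.distributed as dist

def broadcast_params_from_rank0(model):
    for p in model.parameters():
        dist.broadcast(p.data, src=0)  # overwrite this rank's tensor with rank 0's copy

broadcast_params_from_rank0(Deformddpm)  # works on the DDP-wrapped model as well
```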
D. Registration sampling is deeper in 3-mode (lines 662–664)
- 2-mode: `scale_regist = U(0, 0.7)`, `select_timestep = randint(8, 17)`
- 3-mode: `scale_regist = U(0, 0.5)`, `select_timestep = randint(12, 32)`
Samples over a larger range of timesteps (12–32 vs 8–17) and from a shallower start. Both directions make each registration forward deeper and produce a larger gradient signal when the gate fires. Combined with (B), the registration branch is meaningfully more dominant in 3-mode.
E. Registration accumulator divisor changed (lines 697–700, 793)
- 2-mode: `epoch_loss_regist += loss_regist.item() / total` – denominator is total steps
- 3-mode: `epoch_loss_regist += loss_regist.item()`, printed as `/ total_reg_safe` – denominator is steps where registration actually fired
Reporting change only. Does not affect training, but distorts cross-run comparisons against 2-mode numbers.
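A toy example with made-up numbers shows how far the two conventions can diverge:

```python
# Hypothetical numbers only: the same three registration losses averaged with the
# 2-mode denominator (all steps) vs the 3-mode denominator (steps where the
# registration gate fired) print very different epoch values.
reg_losses = [0.8, 0.9, 0.7]   # steps where registration fired
total = 10                     # all steps in the epoch
total_reg_safe = len(reg_losses)

two_mode_avg = sum(reg_losses) / total             # 0.24
three_mode_avg = sum(reg_losses) / total_reg_safe  # 0.80
```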
F. Contrastive branch with side-channel backward (lines 605–624)
```python
_ = Deformddpm(img_org=..., contrastive_only=True, text=None)  # DDP forward returns ddf
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
img_embd = raw_network.img_embd                                # side-channel attribute
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
loss_contra.backward()
```
In `Diffusion/diffuser.py:395–399`, the `contrastive_only=True` branch returns `self.network(...)`, which returns `ddf`. But the training code ignores `ddf` and reads `raw_network.img_embd` – a side-channel attribute set in the middle of the forward, before the `rec_num=2` loop overwrites it and before the decoder runs.
Consequence:
- DDP's `Reducer.prepare_for_backward(ddf)` walks the autograd graph from `ddf` and marks the expected-gradient set (essentially the full UNet over both rec iterations). `loss_contra.backward()` is called from `img_embd`, so autograd traverses a different subgraph than the one DDP prepared. The set of params whose gradient hooks actually fire does not match the expected set.
- With `find_unused_parameters=True` the Reducer's fallback synthesises zero grads for missing hooks and allreduces them, but in practice some contrastive grads escape allreduce and stay local per rank. Each rank takes a slightly different `optimizer.step()`.
- DDP only syncs gradients, not `param.data`. The encoder / `img2txt` weights of the two ranks drift apart, the next diffusion forward produces different outputs per rank, and the drift compounds.
This is the well-known DDP footgun: the DDP module output must be the tensor that forms the loss.
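The drift is easy to confirm empirically. A hypothetical probe along these lines (not in the repo), called every few hundred steps on the 2-GPU run, should show a growing spread if contrastive gradients are staying local:

```python
# Hypothetical drift probe: compares a per-parameter checksum across ranks.
# A spread that grows over training steps means some gradients escaped the
# allreduce and the ranks' weights are diverging.
import torch
import torch.distributed as dist

def max_param_spread(model):
    worst = 0.0
    for p in model.parameters():
        s = p.detach().double().sum().reshape(1)   # cheap per-tensor checksum
        lo, hi = s.clone(), s.clone()
        dist.all_reduce(lo, op=dist.ReduceOp.MIN)
        dist.all_reduce(hi, op=dist.ReduceOp.MAX)
        worst = max(worst, (hi - lo).item())
    return worst  # 0.0 means all ranks hold identical checksums

spread = max_param_spread(Deformddpm.module)  # the DDP-wrapped model from the script
```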
G. NaN/Inf skip uses `broadcast(src=0)` instead of an OR-reduction (lines 554–563)
```python
is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
if use_distributed:
    nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
    dist.broadcast(nan_flag, src=0)
    is_nan = nan_flag.item() > 0
```
This forces rank 0's decision onto every rank. If a non-rank-0 rank has NaN while rank 0 does not, every rank proceeds with backward and that rank's NaN gradient allreduces into everyone. The correct reduction is `dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)` – "any rank hit NaN → everyone skips".
Not the primary cause of the current regression (CUDA rarely produces NaN), but it is a correctness bug and should be fixed regardless.
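The fix is a one-line change; a sketch reusing the script's existing names (`loss_tot`, `use_distributed`, `DEVICE_TYPE`, `rank`):

```python
# MAX-reduce the flag so a NaN/Inf on any rank raises it on every rank and all
# ranks take the same skip decision, keeping optimizer steps in lockstep.
is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
if use_distributed:
    nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
    dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)  # acts as an OR across ranks
    is_nan = nan_flag.item() > 0
# the existing skip branch then sees the same is_nan value on every rank
```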
H. Gradient checkpointing on XPU only (lines 260–263)
Numerically equivalent to non-checkpointed forward. Does not explain the loss delta. Noted for completeness.
I. `empty_cache()` + `gc.collect()` + `synchronize()` between phases (lines 591–594, 635–638)
Added to work around an autograd memory leak on XPU. On CUDA it only costs time; no correctness impact.
J. `TEXT_EMBED_PROB` 0.7 → 0.5 (line 91)
Less text conditioning. Minor effect, same on both devices.
K. Multi-node `LOCAL_RANK` launch branch (lines 943–955)
Only affects how ranks are spawned, not training math.
Ranked suspect list for the CUDA regression
1. (B) + (D) together – the 3-mode registration branch is heavier, and the weaker dist weight in `LOSS_WEIGHTS_DIFF` can't hold `loss_dist` down on small-effective-BS CUDA. This matches the shape of the regression best: `loss_imgsim` improves (more pull from registration), `loss_dist` regresses most (lightest constraint), `loss_ang` regresses less (still heavily weighted). Most likely primary cause.
2. (F) + (A) – contrastive side-channel backward with `find_unused_parameters=True`, masked on XPU by (C). Secondary contributor via encoder drift. Probably accounts for the minor contrastive-loss regression (5.0e-5 → 5.4e-5) more than the `loss_dist` blowup.
3. Effective global BS dropped from 24 (8 XPU × 3) to 6 (2 CUDA × 3) with unchanged `lr=1e-5`. On its own this would be survivable (2-mode proves it), but combined with (1) the dist loss no longer has enough weight to resist the extra gradient noise, and combined with (2) rank drift is no longer being averaged across many ranks.
4. (G) NaN poison-gradient bug. Unlikely to be the current problem but should be fixed.
Cheap isolation experiments
Run each independently on 2 CUDA GPUs, starting from the same 8-XPU checkpoint, for 2–3 epochs. Watch `loss_dist` specifically.
1. Revert `LOSS_WEIGHTS_DIFF` to `[2.0, 2.0, 4.0]` (the 2-mode value). Leave everything else alone. If `loss_dist` returns near 0.7–0.9, (B) is confirmed as the main cause.
2. Set `LOSS_WEIGHT_CONTRASTIVE = 0` (or skip the whole contrastive block). If diffusion losses stop degrading, (F) is a significant contributor; if not, contrastive is not the main cause.
3. Add an explicit per-epoch rank-0 broadcast on CUDA: at the top of every epoch, `for p in Deformddpm.parameters(): dist.broadcast(p.data, src=0)`. If this alone fixes the regression, rank drift is the dominant cause.
4. Revert `LOSS_WEIGHTS_REGIST` to `[1.0, 0.05, 128]` and `scale_regist`/`select_timestep` to 2-mode values. If `loss_imgsim` stops improving and `loss_dist` stops regressing, (D) is significant.
Experiments (1) and (2) can be combined. If that does not fully restore behaviour, cause #3 (effective BS) is material and lr scaling or gradient accumulation is needed.
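If cause #3 does turn out to matter, gradient accumulation over 4 micro-batches restores the effective global batch of 24 on 2 GPUs without touching the lr. A sketch under assumed names (`compute_losses`, `micro_batches` and `ACCUM_STEPS` are placeholders, not from the script):

```python
import contextlib

ACCUM_STEPS = 4  # 2 ranks x batchsize 3 x 4 accumulation steps = effective BS 24

optimizer.zero_grad(set_to_none=True)
for i, batch in enumerate(micro_batches):
    sync = (i + 1) % ACCUM_STEPS == 0
    # skip the DDP allreduce on non-final micro-batches to avoid redundant comms
    ctx = contextlib.nullcontext() if sync else Deformddpm.no_sync()
    with ctx:
        loss_tot = compute_losses(Deformddpm, batch)  # placeholder for the script's loss code
        (loss_tot / ACCUM_STEPS).backward()
    if sync:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```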
Code changes to make regardless of experiment outcomes
1. Fix the contrastive DDP output. In `Diffusion/diffuser.py:395–399`, return `self.network.img_embd` from the `contrastive_only=True` branch; in `OM_train_3modes.py:612–616`, use the returned value as the loss input. This makes DDP's `prepare_for_backward` trace the correct subgraph and allreduce the right set of params.

   ```python
   # Diffusion/diffuser.py
   def forward(self, img_org, cond_imgs=None, proc_type=None, T=None,
               contrastive_only=False, **kwargs):
       if contrastive_only:
           self.network(x=img_org, y=cond_imgs, t=T, text=kwargs.get('text'))
           return self.network.img_embd
       ...

   # OM_train_3modes.py
   img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(),
                         T=t_contra, contrastive_only=True, text=None)
   loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(
       1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
   ```

   After this fix, `find_unused_parameters=True` on line 296 can and should be removed. If DDP complains about unused params on steps where the registration gate fails, use `static_graph=True` instead.

2. Fix the NaN-skip reduction. Replace `dist.broadcast(nan_flag, src=0)` with `dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)` so that "any rank hit NaN → everyone skips".

3. Decide whether the XPU-only per-epoch restart should also exist on CUDA, or, equivalently, whether CUDA runs should periodically do a rank-0 param broadcast. Right now XPU is getting a hidden correctness safety net that CUDA is not, and any cross-rank divergence bug will look fine on XPU and broken on CUDA.
References
- `OM_train_3modes.py` – 3-mode training script
- `OM_train_2modes.py` – 2-mode training script (baseline for diff)
- `Diffusion/diffuser.py:395–399` – `DeformDDPM.forward(contrastive_only=True)`
- `Diffusion/networks.py:893–978` – `RecMulModMutAttnNet.forward`, where `self.img_embd` is set mid-forward and overwritten by the `rec_num=2` loop