
3-mode Training: XPU-vs-CUDA Loss Regression Investigation

Context

An 8-XPU (single-node, 8 tiles) training run of OM_train_3modes.py reached the following loss levels:

  • loss_ang: -0.50 to -0.57
  • loss_dist: 0.7 to 0.8
  • loss_contrastive: 4.8e-5 to 5.1e-5
  • loss_imgsim: -0.10 to -0.15

Continuing from the same checkpoint on 2 CUDA GPUs (same script, lr=1e-5, batchsize=3) produced:

  • loss_ang: -0.42 (worse)
  • loss_dist: 2.1 (3× worse)
  • loss_contrastive: 5.4e-5 (slightly worse)
  • loss_imgsim: -0.21 (better)

The 2-mode script (OM_train_2modes.py, diffusion + registration only) converges fine on 2 CUDA GPUs with the same lr and batchsize, so the regression is specific to what the 3-mode script adds or changes relative to 2-mode.

Material edits in 3-mode vs 2-mode (non-cosmetic)

All line numbers refer to OM_train_3modes.py unless noted.

A. DDP wrapper uses find_unused_parameters=True (line 296)

  • 2-mode: DDP(Deformddpm, device_ids=[rank])
  • 3-mode: DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)

Added because the contrastive pass does not touch every parameter. Interacts badly with the contrastive side-channel backward in (F) below.

B. Diffusion and registration loss weights retuned (lines 93–95)

| constant | 2-mode | 3-mode | change |
|---|---|---|---|
| LOSS_WEIGHTS_DIFF (ang, dist, reg) | [2.0, 2.0, 4.0] | [2.0, 1.0, 4.0] | dist ÷ 2 |
| LOSS_WEIGHTS_REGIST (imgsim, imgmse, ddf) | [1.0, 0.05, 128] | [1.0, 0.01, 100] | imgmse ÷ 5, ddf × 0.78 |

Halving the diffusion dist weight makes loss_dist weakly constrained. Relaxing the registration DDF regulariser and imgmse lets the registration branch push harder on the shared UNet weights toward the imgsim minimum.
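To make the imbalance concrete, here is a hedged sketch of the diffusion-loss composition. The script's exact formula is not reproduced here; a plain weighted sum is assumed:

```python
# Hypothetical composition of the diffusion loss (plain weighted sum assumed;
# the weight values are the ones from the table above).
W_DIFF_2MODE = [2.0, 2.0, 4.0]   # (ang, dist, reg)
W_DIFF_3MODE = [2.0, 1.0, 4.0]   # dist weight halved in 3-mode

def diffusion_loss(l_ang, l_dist, l_reg, weights):
    w_ang, w_dist, w_reg = weights
    return w_ang * l_ang + w_dist * l_dist + w_reg * l_reg

# With unit component losses, dist contributes half as much in 3-mode,
# so gradient pressure holding loss_dist down is correspondingly weaker.
total_2mode = diffusion_loss(1.0, 1.0, 1.0, W_DIFF_2MODE)  # 8.0
total_3mode = diffusion_loss(1.0, 1.0, 1.0, W_DIFF_3MODE)  # 7.0
```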

C. XPU-only per-epoch restart with rank-0 param broadcast (lines 831–839; ddp_load_dict, lines 922–926)

3-mode adds an unconditional sys.exit(EXIT_CODE_RESTART) at the end of every epoch on XPU. The bash wrapper re-launches, ddp_load_dict() loads the checkpoint and does dist.broadcast(param.data, src=0) for every parameter, so every epoch on XPU starts with all ranks bit-identical.

On CUDA this block is skipped entirely (gated on DEVICE_TYPE == 'xpu'), so whatever cross-rank drift accumulates in the contrastive or NaN-skip paths is never wiped. This is why any cross-rank drift bug looks healthy on XPU and broken on CUDA.

D. Registration sampling is deeper in 3-mode (lines 662–664)

  • 2-mode: scale_regist = U(0, 0.7), select_timestep = randint(8, 17)
  • 3-mode: scale_regist = U(0, 0.5), select_timestep = randint(12, 32)

3-mode samples over a larger and deeper timestep range (12–32 vs 8–17) and from a smaller scale_regist. Both changes make each registration forward pass deeper and produce a larger gradient signal when the gate fires. Combined with (B), the registration branch is meaningfully more dominant in 3-mode.

E. Registration accumulator divisor changed (lines 697–700, 793)

  • 2-mode: epoch_loss_regist += loss_regist.item() / total (denominator: total steps)
  • 3-mode: epoch_loss_regist += loss_regist.item(), printed as / total_reg_safe (denominator: only the steps where registration actually fired)

Reporting change only. Does not affect training, but distorts cross-run comparisons against 2-mode numbers.
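The two conventions can be sketched as follows (function names are hypothetical; the scripts inline this logic, and the 2-mode accumulator is assumed to contribute 0.0 on steps where the gate did not fire):

```python
def epoch_mean_2mode(step_losses, total):
    """2-mode convention: every step contributes (0.0 when the
    registration gate did not fire), divided by total step count."""
    return sum(step_losses) / total

def epoch_mean_3mode(fired_losses):
    """3-mode convention: only fired steps are accumulated, divided
    by that (guarded) count, per total_reg_safe."""
    total_reg_safe = max(len(fired_losses), 1)
    return sum(fired_losses) / total_reg_safe

# Same underlying run, different reported numbers, so 2-mode vs 3-mode
# epoch_loss_regist values are not directly comparable:
m2 = epoch_mean_2mode([0.8, 0.0, 0.9, 0.0], total=4)  # ~0.425
m3 = epoch_mean_3mode([0.8, 0.9])                      # ~0.85
```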

F. Contrastive branch with side-channel backward (lines 605–624)

_ = Deformddpm(img_org=..., contrastive_only=True, text=None)      # DDP forward returns ddf
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
img_embd = raw_network.img_embd                                    # side-channel attribute
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
loss_contra.backward()

In Diffusion/diffuser.py:395–399, the contrastive_only=True branch returns self.network(...), which returns ddf. But the training code ignores ddf and reads raw_network.img_embd, a side-channel attribute set in the middle of the forward, before the rec_num=2 loop overwrites it and before the decoder runs.

Consequence:

  1. DDP's Reducer.prepare_for_backward(ddf) walks the autograd graph from ddf and marks the expected-gradient set (essentially the full UNet over both rec iterations).
  2. loss_contra.backward() is called from img_embd, so autograd traverses a different subgraph than the one DDP prepared. The set of params whose gradient hooks actually fire does not match the expected set.
  3. With find_unused_parameters=True the Reducer's fallback synthesises zero grads for missing hooks and allreduces them, but in practice some contrastive grads escape allreduce and stay local per rank. Each rank takes a slightly different optimizer.step().
  4. DDP only syncs gradients, not param.data. The encoder / img2txt weights of the two ranks drift apart, the next diffusion forward produces different outputs per rank, and the drift compounds.

This is the well-known DDP footgun: the DDP module output must be the tensor that forms the loss.
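A minimal single-process toy (no DDP involved; module and attribute names are illustrative stand-ins) shows the graph mismatch: backward from the side-channel attribute reaches only the encoder, while DDP would have prepared hooks for the full graph behind the returned ddf:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyNet(nn.Module):
    """Stand-in for RecMulModMutAttnNet: stashes an embedding on self
    mid-forward, then keeps computing and returns a ddf-like output."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 4)   # "encoder / img2txt" side
        self.dec = nn.Linear(4, 4)   # "decoder" side
    def forward(self, x):
        self.img_embd = self.enc(x)      # side-channel attribute
        return self.dec(self.img_embd)   # the output DDP would trace

torch.manual_seed(0)
toy = ToyNet()
ddf = toy(torch.randn(2, 4))             # DDP: prepare_for_backward(ddf)
embd_dev = torch.randn(2, 4)
loss = F.relu(1 - F.cosine_similarity(toy.img_embd, embd_dev,
                                      dim=-1).mean() - 0.25)
loss.backward()                          # walks only the encoder subgraph
# enc received gradients, dec did not. Under DDP the fired-hook set would
# therefore disagree with the prepared set, and grads can stay local per rank.
```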

G. NaN/Inf skip uses broadcast(src=0) instead of an OR-reduction (lines 554–563)

is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
if use_distributed:
    nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
    dist.broadcast(nan_flag, src=0)
    is_nan = nan_flag.item() > 0

This forces rank 0's decision onto every rank. If a non-rank-0 rank hits NaN while rank 0 does not, every rank proceeds with backward and the NaN gradients from the affected rank allreduce into everyone. The correct reduction is dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX): any rank hit NaN, everyone skips.

Not the primary cause of the current regression (CUDA rarely produces NaN), but it is a correctness bug and should be fixed regardless.
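A hedged sketch of the fixed gate (the helper name is hypothetical; the script inlines this logic around lines 554–563):

```python
import torch
import torch.distributed as dist

def any_rank_nonfinite(loss_tot, use_distributed, device="cpu"):
    """True iff any rank's loss is NaN/Inf. MAX-reduce means a single
    bad rank makes every rank skip the step together."""
    is_bad = bool(torch.isnan(loss_tot) or torch.isinf(loss_tot))
    if use_distributed:
        flag = torch.tensor([1.0 if is_bad else 0.0], device=device)
        # "any rank", not "rank 0 decides"
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        is_bad = flag.item() > 0
    return is_bad
```

In the single-process case this degenerates to the plain local check, so it is safe to use unconditionally.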

H. Gradient checkpointing on XPU only (lines 260–263)

Numerically equivalent to the non-checkpointed forward. Does not explain the loss delta. Noted for completeness.

I. empty_cache() + gc.collect() + synchronize() between phases (lines 591–594, 635–638)

Added to work around an XPU autograd memory leak. On CUDA it just costs time; no correctness impact.

J. TEXT_EMBED_PROB 0.7 β†’ 0.5 (line 91)

Less text conditioning. Minor effect, same on both devices.

K. Multi-node LOCAL_RANK launch branch (lines 943–955)

Only affects how ranks are spawned, not training math.

Ranked suspect list for the CUDA regression

  1. (B) + (D) together: the 3-mode registration branch is heavier, and the weaker dist weight in LOSS_WEIGHTS_DIFF can't hold loss_dist down at the small effective batch size on CUDA. This matches the shape of the regression best: loss_imgsim improves (more pull from registration), loss_dist regresses most (lightest constraint), loss_ang regresses less (still heavily weighted). Most likely primary cause.

  2. (F) + (A): contrastive side-channel backward with find_unused_parameters=True, masked on XPU by (C). A secondary contributor via encoder drift; it probably accounts for the minor contrastive-loss regression (from ~5.0e-5 to 5.4e-5) more than for the loss_dist blowup.

  3. Effective global BS dropped from 24 (8 XPU × 3) to 6 (2 CUDA × 3) with unchanged lr=1e-5. On its own this would be survivable (2-mode proves it), but combined with (1) the dist loss no longer has enough weight to resist the extra gradient noise, and combined with (2) rank drift is no longer being averaged across many ranks.

  4. (G) NaN poison-gradient bug. Unlikely to be the current problem but should be fixed.

Cheap isolation experiments

Run each independently on 2 CUDA, starting from the same 8-XPU checkpoint, for 2–3 epochs. Watch loss_dist specifically.

  1. Revert LOSS_WEIGHTS_DIFF to [2.0, 2.0, 4.0] (the 2-mode value). Leave everything else alone. If loss_dist returns near 0.7–0.9, (B) is confirmed as the main cause.
  2. Set LOSS_WEIGHT_CONTRASTIVE = 0 (or skip the whole contrastive block). If diffusion losses stop degrading, (F) is a significant contributor; if not, contrastive is not the main cause.
  3. Add an explicit per-epoch rank-0 broadcast on CUDA: at the top of every epoch, for p in Deformddpm.parameters(): dist.broadcast(p.data, src=0). If this alone fixes the regression, rank drift is the dominant cause.
  4. Revert LOSS_WEIGHTS_REGIST to [1.0, 0.05, 128] and scale_regist / select_timestep to 2-mode values. If loss_imgsim stops improving and loss_dist stops regressing, (D) is significant.

(1)+(2) can be combined. If that does not fully restore behaviour, cause #3 (effective BS) is material and lr scaling or gradient accumulation is needed.
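If the effective batch size turns out to be material, gradient accumulation is the cheaper remedy. A toy sketch under stated assumptions (function and variable names are hypothetical; under DDP, non-boundary backwards should additionally be wrapped in model.no_sync() to skip the redundant allreduces):

```python
import torch
import torch.nn as nn

def train_with_accumulation(model, batches, lr=1e-5, accum_steps=4):
    """Average gradients over accum_steps micro-batches: 2 ranks x
    batchsize 3 x 4 accumulation steps ~ the 8-XPU global batch of 24."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    opt.zero_grad()
    optimizer_steps = 0
    for i, (x, y) in enumerate(batches):
        # divide so accumulated grads are a mean, not a sum
        loss = nn.functional.mse_loss(model(x), y) / accum_steps
        loss.backward()                   # grads accumulate in .grad
        if (i + 1) % accum_steps == 0:    # boundary: apply and reset
            opt.step()
            opt.zero_grad()
            optimizer_steps += 1
    return optimizer_steps

torch.manual_seed(0)
model = nn.Linear(4, 1)
batches = [(torch.randn(3, 4), torch.randn(3, 1)) for _ in range(8)]
n_steps = train_with_accumulation(model, batches)  # 8 micro-batches / 4
```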

Code changes to make regardless of experiment outcomes

  • Fix the contrastive DDP output. In Diffusion/diffuser.py:395–399, return self.network.img_embd from the contrastive_only=True branch; in OM_train_3modes.py:612–616, use the returned value as the loss input. This makes DDP's prepare_for_backward trace the correct subgraph and allreduce the right set of params.

    # Diffusion/diffuser.py
    def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, contrastive_only=False, **kwargs):
        if contrastive_only:
            self.network(x=img_org, y=cond_imgs, t=T, text=kwargs.get('text'))
            return self.network.img_embd
        ...
    
    # OM_train_3modes.py
    img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(),
                          cond_imgs=cond_img.detach(),
                          T=t_contra, contrastive_only=True, text=None)
    loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(
        1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
    

    After this fix, find_unused_parameters=True on line 296 can and should be removed. If DDP complains about unused params on steps where the registration gate fails, use static_graph=True instead.

  • Fix the NaN-skip reduction. Replace dist.broadcast(nan_flag, src=0) with dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX) so that any rank hitting NaN makes every rank skip.

  • Decide whether the XPU-only per-epoch restart should also exist on CUDA, or, equivalently, whether CUDA runs should periodically do a rank-0 param broadcast. Right now XPU is getting a hidden correctness safety net that CUDA is not, and any cross-rank divergence bug will look fine on XPU and broken on CUDA.
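A sketch of the CUDA-side equivalent of that safety net (helper name hypothetical; the single-process gloo group at the bottom only exercises the call path and is not part of the proposed change):

```python
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn

def resync_from_rank0(model):
    """Make every rank bit-identical again, mirroring what the XPU
    restart + ddp_load_dict path achieves once per epoch."""
    module = model.module if hasattr(model, "module") else model
    for p in module.parameters():
        dist.broadcast(p.data, src=0)
    for b in module.buffers():    # running stats (e.g. BN) drift too
        dist.broadcast(b.data, src=0)

# Exercise the call path with a single-process gloo group:
if not dist.is_initialized():
    dist.init_process_group("gloo",
                            init_method=f"file://{tempfile.mktemp()}",
                            rank=0, world_size=1)
net = nn.Linear(4, 4)
before = net.weight.detach().clone()
resync_from_rank0(net)            # no-op at world_size=1
```

Calling this at the top of every epoch on CUDA would reproduce experiment (3) from the isolation list as a permanent measure.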

References

  • OM_train_3modes.py: 3-mode training script
  • OM_train_2modes.py: 2-mode training script (baseline for diff)
  • Diffusion/diffuser.py:395–399: DeformDDPM.forward(contrastive_only=True)
  • Diffusion/networks.py:893–978: RecMulModMutAttnNet.forward, where self.img_embd is set mid-forward and overwritten by the rec_num=2 loop