# 3-mode Training: XPU-vs-CUDA Loss Regression Investigation

## Context

An 8-XPU (single-node, 8 tiles) training run of `OM_train_3modes.py` reached:

- `loss_ang`: -0.50 to -0.57
- `loss_dist`: 0.7 to 0.8
- `loss_contrastive`: 4.8e-5 to 5.1e-5
- `loss_imgsim`: -0.10 to -0.15

Continuing the same checkpoint on 2 CUDA GPUs (same `lr=1e-5`, `batchsize=3`, same script) produced:

- `loss_ang`: -0.42 (worse)
- `loss_dist`: 2.1 (**~3× worse**)
- `loss_contrastive`: 5.4e-5 (slightly worse)
- `loss_imgsim`: -0.21 (**better**)

The 2-mode script (`OM_train_2modes.py`, diffusion + registration only) converges fine on 2 CUDA GPUs with the same `lr` and `batchsize`. The regression is therefore specific to what the 3-mode script adds or changes relative to 2-mode.

## Material edits in 3-mode vs 2-mode (non-cosmetic)

All line numbers refer to `OM_train_3modes.py` unless noted.

### A. DDP wrapper uses `find_unused_parameters=True` (line 296)

- 2-mode: `DDP(Deformddpm, device_ids=[rank])`
- 3-mode: `DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)`

This was added because the contrastive pass does not touch every parameter. It interacts badly with the contrastive side-channel backward described in (F) below.

### B. Diffusion and registration loss weights retuned (lines 93–95)

| constant | 2-mode | 3-mode | change |
|---|---|---|---|
| `LOSS_WEIGHTS_DIFF` (ang, dist, reg) | `[2.0, 2.0, 4.0]` | `[2.0, 1.0, 4.0]` | **dist ÷ 2** |
| `LOSS_WEIGHTS_REGIST` (imgsim, imgmse, ddf) | `[1.0, 0.05, 128]` | `[1.0, 0.01, 100]` | **imgmse ÷ 5, ddf × 0.78** |

Halving the diffusion `dist` weight leaves `loss_dist` only weakly constrained. Relaxing the registration DDF regulariser and `imgmse` lets the registration branch push harder on the shared UNet weights toward the `imgsim` minimum.

### C. XPU-only per-epoch restart with rank-0 param broadcast (lines 831–839; `ddp_load_dict()`, lines 922–926)

3-mode adds an unconditional `sys.exit(EXIT_CODE_RESTART)` at the end of every epoch **on XPU**.
The bash wrapper relaunches the script. `ddp_load_dict()` then loads the checkpoint and calls `dist.broadcast(param.data, src=0)` for every parameter, so every epoch on XPU starts with all ranks bit-identical. **On CUDA this block is skipped entirely** (it is gated on `DEVICE_TYPE == 'xpu'`), so any cross-rank drift that accumulates in the contrastive or NaN-skip paths is never wiped. This is why, for any cross-rank drift bug, XPU appears healthy while CUDA does not.

### D. Registration sampling is deeper in 3-mode (lines 662–664)

- 2-mode: `scale_regist = U(0, 0.7)`, `select_timestep = randint(8, 17)`
- 3-mode: `scale_regist = U(0, 0.5)`, `select_timestep = randint(12, 32)`

3-mode samples over a larger range of timesteps (12–32 vs 8–17) and from a shallower start (`scale_regist` capped at 0.5 instead of 0.7). Both changes make each registration forward deeper and produce a larger gradient signal when the gate fires. Combined with (B), they make the registration branch meaningfully more dominant in 3-mode.

### E. Registration accumulator divisor changed (lines 697–700, 793)

- 2-mode: `epoch_loss_regist += loss_regist.item() / total`. The denominator is **total steps**.
- 3-mode: `epoch_loss_regist += loss_regist.item()`, printed as `/ total_reg_safe`. The denominator is **steps where registration actually fired**.

This is a reporting change only. It does not affect training, but it distorts cross-run comparisons against 2-mode numbers.

### F. Contrastive branch with side-channel backward (lines 605–624)

```python
_ = Deformddpm(img_org=..., contrastive_only=True, text=None)  # DDP forward returns ddf
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
img_embd = raw_network.img_embd  # side-channel attribute
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(
    1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
loss_contra.backward()
```

In `Diffusion/diffuser.py:395–399`, the `contrastive_only=True` branch returns `self.network(...)`, which returns `ddf`.
The training code ignores `ddf` and reads `raw_network.img_embd` instead. That is a side-channel attribute, assigned partway through the forward before the decoder runs and overwritten on each pass of the `rec_num=2` loop. Consequences:

1. DDP's `Reducer.prepare_for_backward(ddf)` walks the autograd graph **from `ddf`** and marks the expected-gradient set (essentially the full UNet over both rec iterations).
2. `loss_contra.backward()` starts from `img_embd`, so autograd traverses a *different* subgraph from the one DDP prepared. The set of parameters whose gradient hooks actually fire does not match the expected set.
3. With `find_unused_parameters=True`, the Reducer's fallback synthesises zero gradients for missing hooks and allreduces them. In practice, however, some contrastive gradients escape the allreduce and stay local to each rank. Each rank then takes a slightly different `optimizer.step()`.
4. DDP only syncs gradients, not `param.data`. The encoder / `img2txt` weights on the two ranks drift apart, the next diffusion forward produces different outputs per rank, and the drift compounds.

This is the well-known DDP footgun: **the DDP module's output must be the tensor the loss is computed from**.

### G. NaN/Inf skip uses `broadcast(src=0)` instead of an OR-reduction (lines 554–563)

```python
is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
if use_distributed:
    nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
    dist.broadcast(nan_flag, src=0)
    is_nan = nan_flag.item() > 0
```

This forces **rank 0's** decision onto every rank. Suppose a non-zero rank hits NaN while rank 0 does not. All ranks then proceed with backward, and the NaN gradient from the affected rank is allreduced into every rank.

The correct reduction is `dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)`: if any rank hits NaN, every rank skips the step. This is not the primary cause of the current regression (CUDA rarely produces NaN), but it is a correctness bug and should be fixed regardless.

### H. Gradient checkpointing on XPU only (lines 260–263)

Numerically equivalent to the non-checkpointed forward, so it does not explain the loss delta. Noted for completeness.

### I. `empty_cache()` + `gc.collect()` + `synchronize()` between phases (lines 591–594, 635–638)

Added to work around XPU's autograd memory leak. On CUDA it only costs time and has no correctness impact.

### J. `TEXT_EMBED_PROB` 0.7 → 0.5 (line 91)

Less text conditioning. Minor effect, identical on both devices.

### K. Multi-node `LOCAL_RANK` launch branch (lines 943–955)

Only affects how ranks are spawned, not the training math.

## Ranked suspect list for the CUDA regression

1. **(B) + (D) together: the 3-mode registration branch is heavier.** The weaker dist weight in `LOSS_WEIGHTS_DIFF` cannot hold `loss_dist` down at the small effective batch size on CUDA.
   - This matches the shape of the regression best. `loss_imgsim` improves (more pull from registration), `loss_dist` regresses most (lightest constraint), and `loss_ang` regresses less (still heavily weighted).
   - Most likely primary cause.
2. **(F) + (A): contrastive side-channel backward with `find_unused_parameters=True`, masked on XPU by (C).** A secondary contributor via encoder drift. It probably accounts for the minor contrastive-loss regression (about 5.0e-5 → 5.4e-5) more than for the `loss_dist` blowup.
3. **Effective global batch size dropped from 24 (8 XPU × 3) to 6 (2 CUDA × 3) with `lr=1e-5` unchanged.** On its own this would be survivable (2-mode proves it). Combined with (1), the dist loss no longer has enough weight to resist the extra gradient noise. Combined with (2), rank drift is no longer averaged across many ranks.
4. **(G): the NaN poison-gradient bug.** Unlikely to be the current problem, but it should be fixed.

## Cheap isolation experiments

Run each experiment independently on 2 CUDA GPUs for 2–3 epochs, starting from the same 8-XPU checkpoint. Watch `loss_dist` specifically.

1. **Revert `LOSS_WEIGHTS_DIFF` to `[2.0, 2.0, 4.0]`** (the 2-mode value). Leave everything else alone.
   If `loss_dist` returns to near 0.7–0.9, (B) is confirmed as the main cause.
2. **Set `LOSS_WEIGHT_CONTRASTIVE = 0`** (or skip the whole contrastive block). If the diffusion losses stop degrading, (F) is a significant contributor. If not, contrastive is not the main cause.
3. **Add an explicit per-epoch rank-0 broadcast on CUDA.** At the top of every epoch, run `for p in Deformddpm.parameters(): dist.broadcast(p.data, src=0)`. If this alone fixes the regression, rank drift is the dominant cause.
4. **Revert `LOSS_WEIGHTS_REGIST` to `[1.0, 0.05, 128]`** and `scale_regist` / `select_timestep` to their 2-mode values. If `loss_imgsim` stops improving and `loss_dist` stops regressing, the registration-side changes in (B) and (D) are significant.

Experiments (1) and (2) can be combined. If that does not fully restore behaviour, cause #3 (effective batch size) is material, and `lr` scaling or gradient accumulation is needed.

## Code changes to make regardless of experiment outcomes

- **Fix the contrastive DDP output.** Two changes are needed:
  - In `Diffusion/diffuser.py:395–399`, return `self.network.img_embd` from the `contrastive_only=True` branch.
  - In `OM_train_3modes.py:612–616`, use the returned value as the loss input.

  This makes DDP's `prepare_for_backward` trace the correct subgraph and allreduce the right set of parameters.

  ```python
  # Diffusion/diffuser.py
  def forward(self, img_org, cond_imgs=None, proc_type=None, T=None,
              contrastive_only=False, **kwargs):
      if contrastive_only:
          # Run the network so it sets img_embd, then return the embedding itself
          # so that DDP traces the backward graph from the tensor the loss uses.
          self.network(x=img_org, y=cond_imgs, t=T, text=kwargs.get('text'))
          return self.network.img_embd
      ...

  # OM_train_3modes.py
  img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(),
                        cond_imgs=cond_img.detach(),
                        T=t_contra, contrastive_only=True, text=None)
  loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(
      1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean() - 0.25)
  loss_contra.backward()
  ```

  After this fix, `find_unused_parameters=True` on line 296 may be removable. Check this before relying on it. The contrastive pass now returns an encoder-side tensor, so decoder parameters can still go unused on that pass. Without `find_unused_parameters=True`, DDP typically raises an error on the next iteration in that case.

  If DDP complains about unused parameters (on that pass or on steps where the registration gate fails), `static_graph=True` is an alternative. It is only valid when the set of used and unused parameters is the same on every iteration.
- **Fix the NaN-skip reduction.** Replace `dist.broadcast(nan_flag, src=0)` with `dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)` so that if any rank hits NaN, every rank skips the step.
- **Decide whether the XPU-only per-epoch restart should also exist on CUDA.** Equivalently, decide whether CUDA runs should periodically do a rank-0 param broadcast. Right now XPU gets a hidden correctness safety net that CUDA does not. As a result, any cross-rank divergence bug will look fine on XPU and broken on CUDA.

## References

- `OM_train_3modes.py`: 3-mode training script
- `OM_train_2modes.py`: 2-mode training script (baseline for the diff)
- `Diffusion/diffuser.py:395–399`: `DeformDDPM.forward(contrastive_only=True)`
- `Diffusion/networks.py:893–978`: `RecMulModMutAttnNet.forward`, where `self.img_embd` is set mid-forward and overwritten by the `rec_num=2` loop
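## Appendix: batch-size arithmetic for cause #3

Cause #3 in the ranked suspect list offers two remedies, `lr` scaling or gradient accumulation. The arithmetic behind both can be sketched as below.

This is a minimal sketch under stated assumptions:

- The helper names are hypothetical and are not part of `OM_train_3modes.py`.
- The lr option assumes the linear scaling rule (lr proportional to global batch size). That rule is a common heuristic, not something validated for this model.

```python
"""Hypothetical helpers for the batch-size arithmetic behind cause #3.

Assumes the linear scaling rule (lr proportional to global batch size). That is a
common heuristic, not something validated for OM_train_3modes.py.
"""


def global_batch(world_size: int, per_device_bs: int, accum_steps: int = 1) -> int:
    """Samples contributing to one optimizer.step() across all ranks."""
    return world_size * per_device_bs * accum_steps


def linearly_scaled_lr(base_lr: float, base_global_bs: int, new_global_bs: int) -> float:
    """Scale lr in proportion to the global batch size (linear scaling rule)."""
    return base_lr * new_global_bs / base_global_bs


def accum_steps_to_match(target_global_bs: int, world_size: int, per_device_bs: int) -> int:
    """Gradient-accumulation steps needed to reach target_global_bs exactly."""
    per_step = world_size * per_device_bs
    if target_global_bs % per_step != 0:
        raise ValueError(f"{target_global_bs} is not a multiple of {per_step}")
    return target_global_bs // per_step


if __name__ == "__main__":
    xpu_bs = global_batch(world_size=8, per_device_bs=3)   # 8 XPU tiles x 3 = 24
    cuda_bs = global_batch(world_size=2, per_device_bs=3)  # 2 CUDA GPUs x 3 = 6
    print("global batch size:", xpu_bs, "->", cuda_bs)

    # Option 1: keep lr=1e-5 and accumulate gradients back to a global BS of 24.
    print("accumulation steps:", accum_steps_to_match(xpu_bs, world_size=2, per_device_bs=3))

    # Option 2: keep global BS 6 and scale lr down linearly (heuristic).
    print("linearly scaled lr:", linearly_scaled_lr(1e-5, xpu_bs, cuda_bs))
```

For the run in question, this gives 4 accumulation steps at `lr=1e-5`, or roughly `lr=2.5e-6` at the current global batch size of 6.

If gradient accumulation is chosen, the usual DDP pattern is to wrap all but the last micro-step in `Deformddpm.no_sync()`, so that gradients are allreduced once per optimizer step. How that interacts with the separate contrastive `backward()` in (F) would need checking.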