Update model.py
model.py CHANGED

@@ -279,7 +279,7 @@ class GPT(nn.Module):
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
         # Create AdamW optimizer and use the fused version if it is available
         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device_type == 'cuda'
+        use_fused = fused_available and device_type == 'cpu'
         extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
         print(f"using fused AdamW: {use_fused}")
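For context, this hunk probes torch.optim.AdamW's signature for a fused keyword (absent in older PyTorch releases) and, after this change, enables it only when training on CPU rather than CUDA. Below is a minimal, self-contained sketch of that pattern; device_type and the toy optim_groups are hypothetical stand-ins for what the surrounding training script supplies, and whether fused=True is actually accepted for CPU tensors depends on the installed PyTorch build (older builds ship CUDA-only fused kernels and will raise an error).

import inspect

import torch

device_type = 'cpu'  # hypothetical: the real script derives this from its configured device

# Older PyTorch releases have no `fused` keyword on AdamW at all, so probe for it.
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cpu'  # fused on CPU needs a recent PyTorch
extra_args = dict(fused=True) if use_fused else dict()

# Toy parameter group standing in for the model's optim_groups.
optim_groups = [torch.nn.Parameter(torch.randn(8, 8))]
optimizer = torch.optim.AdamW(optim_groups, lr=6e-4, betas=(0.9, 0.95), **extra_args)
print(f"using fused AdamW: {use_fused}")

Probing the signature with inspect rather than pinning a PyTorch version keeps the script importable on older installs, where the fused keyword simply does not exist.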