Spaces: Running on T4
Commit · fa8a6c8
Parent(s): 1761d59
TEST cast to bfloat16
context_window_gradio.py CHANGED (+1 -1)
@@ -13,7 +13,7 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 # quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
 
-torch_dtype = torch.
+torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32
 
 llama_model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
 # quantization_config=quantization_config,
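The change makes the load dtype device-dependent: bfloat16 when a CUDA or MPS accelerator is present, float32 on CPU. A minimal sketch of how this likely plugs into the model load; only the model id and the commented-out quantization_config are visible in the diff, so the torch_dtype argument and the .to(torch_device) move below are assumptions about the rest of the call:

import torch
from transformers import AutoModelForCausalLM

# Pick the best available accelerator, falling back to CPU. The diff uses
# torch.mps.is_available(); torch.backends.mps.is_available() is the
# longer-standing spelling and works across more PyTorch versions.
torch_device = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu"
)

# bfloat16 keeps float32's exponent range at half the memory, so it is a
# safe reduced-precision choice on CUDA and recent MPS; CPU stays float32.
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

# Load the weights directly in the chosen dtype, then move to the device
# (assumed continuation of the from_pretrained call shown in the diff).
llama_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch_dtype,
).to(torch_device)

Relative to float16, bfloat16 trades mantissa precision for float32's full exponent range, so float32-trained weights can be cast down without overflow, which matches the commit's "TEST cast to bfloat16" intent.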