willsh1997 committed
Commit fa8a6c8 · 1 Parent(s): 1761d59

TEST cast to bfloat16

Files changed (1)
context_window_gradio.py +1 -1
context_window_gradio.py CHANGED
@@ -13,7 +13,7 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 # quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
 
-torch_dtype = torch.float16 if torch_device in ["cuda", "mps"] else torch.float32
+torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32
 
 llama_model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
 # quantization_config=quantization_config,
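
For context, a minimal sketch of how the changed line fits into the model-loading code after this commit, assuming the rest of the file matches the hunk above. The tokenizer line and the .to(torch_device) call are illustrative assumptions not shown in the diff, and torch.backends.mps.is_available() is used here as the more widely available spelling of the MPS check than the file's torch.mps.is_available().

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Prefer CUDA, then Apple's MPS backend, then CPU.
    torch_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

    # bfloat16 keeps float32's exponent range (at lower mantissa precision),
    # so it avoids the overflow/NaN issues float16 can hit during generation;
    # CPU stays in float32.
    torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

    llama_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct",
        torch_dtype=torch_dtype,
    ).to(torch_device)  # assumption: the device move is not shown in this hunk
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

Note that bfloat16 on MPS depends on a recent macOS/PyTorch combination; on older setups, float16 may be the only half-precision option on that backend.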