# File size: 1,636 Bytes
# Commit: faa44eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# src/generation/model.py
import sys

def load_model(model_path="mlx-community/Llama-3.2-3B-Instruct-4bit"):
    """
    Load a model and tokenizer with whichever backend can be imported.

    ``mlx_lm`` is tried first. HuggingFace Transformers is used only when
    ``mlx_lm`` itself cannot be imported; the choice is driven by import
    availability, not by an explicit platform check.

    Args:
        model_path: Model identifier or path, passed unchanged to the
            selected backend's loader.

    Returns:
        A ``(model, tokenizer, backend)`` tuple, where ``backend`` is the
        string ``"mlx"`` or ``"transformers"``.

    Raises:
        ImportError: If neither ``mlx_lm`` nor ``transformers`` can be
            imported. Errors raised while actually loading the model
            propagate to the caller unchanged.
    """
    # Guard only the import: an ImportError raised *inside* load() is a real
    # failure and must not be misreported as "MLX not found".
    try:
        from mlx_lm import load
    except ImportError:
        # Fallback for Docker/Cloud if MLX isn't available
        print("MLX not found. Falling back to Transformers...")
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # NOTE(review): the default model_path points at an mlx-community
        # repo; confirm Transformers can load it, or pass a suitable
        # model_path when running on this backend.
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        return model, tokenizer, "transformers"

    print(f"Loading {model_path} with MLX on Apple Silicon...")
    model, tokenizer = load(model_path)
    return model, tokenizer, "mlx"

if __name__ == "__main__":
    # Manual smoke test: report the MLX device (when MLX is installed), load
    # the default model, and run one short generation on the MLX backend.

    # MLX may be missing (the Docker/Cloud fallback case); importing it
    # unconditionally would crash before load_model() could fall back.
    try:
        import mlx.core as mx
    except ImportError:
        mx = None

    # 1. Check Default Device
    if mx is not None:
        device = mx.default_device()
        print(f"✅ Current MLX Device: {device}")  # Should say "gpu"

    # 2. Run Inference to trigger GPU
    model, tokenizer, backend = load_model()

    if backend == "mlx":
        from mlx_lm import generate

        prompt = "Explain quantum physics in one sentence."
        messages = [{"role": "user", "content": prompt}]
        # add_generation_prompt=True appends the assistant header so the
        # model answers the question instead of continuing the user turn.
        prompt_formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        print("\n🧪 Testing Inference (Watch your GPU stats now)...")
        response = generate(model, tokenizer, prompt=prompt_formatted, verbose=True)
        print(f"\n🤖 Response: {response}")
    else:
        # The inference test below is MLX-specific; say so rather than
        # exiting silently after the model has loaded.
        print(f"Model loaded with backend '{backend}'; skipping MLX inference test.")