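"""Export quantized seq2seq model artifacts for CPU (or, for the bitsandbytes modes, CUDA) inference.

Example invocations (illustrative only; the script name and output directories are
assumptions, not part of the repository layout). Run from the repository root so
that ``src.quantization_utils`` is importable:

    python export_quantized.py --mode fp32 --out_dir artifacts/fp32
    python export_quantized.py --mode int8_dynamic --out_dir artifacts/int8-dynamic
    python export_quantized.py --mode int4_bnb --device cuda --out_dir artifacts/int4-bnb
"""
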
from __future__ import annotations

import argparse
import os
from pathlib import Path

import torch

from src.quantization_utils import (
    load_bnb_quantized_model,
    load_fp32_model,
    quantize_dynamic_int8,
    quantize_dynamic_int8_decoder_only,
    save_quant_artifact,
)


def main() -> None:
    p = argparse.ArgumentParser(description="Export quantized Seq2Seq model artifacts for CPU inference.")
    p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    p.add_argument("--adapter", default="", help="Optional LoRA adapter directory.")
    p.add_argument("--out_dir", required=True, help="Output directory for artifact.")
    p.add_argument(
        "--mode",
        required=True,
        choices=["fp32", "int8_dynamic", "int8_decoder_dynamic", "int8_bnb", "int4_bnb"],
    )
    p.add_argument("--device", default="cpu", help="Target device: cpu or cuda (bnb modes require cuda).")
    p.add_argument("--local_only", action="store_true", help="Do not hit the network; use only the local HF cache.")
    args = p.parse_args()

    adapter = args.adapter.strip() or None
    out_dir = Path(args.out_dir)

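    # fp32 baseline: load the unquantized model (with the optional LoRA adapter) and save it as-is.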
    if args.mode == "fp32":
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device=args.device, local_only=args.local_only)
        save_quant_artifact(out_dir, mode="fp32", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
        return

    if args.mode == "int8_dynamic":
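        # Dynamic int8 (int8 weights, activations quantized on the fly) targets CPU
        # inference, so the model is loaded on CPU regardless of --device.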
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
        model = quantize_dynamic_int8(model)
        save_quant_artifact(out_dir, mode="int8_dynamic", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
        return

    if args.mode == "int8_decoder_dynamic":
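        # Same idea, but only the decoder is dynamically quantized; the encoder stays fp32.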
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
        model = quantize_dynamic_int8_decoder_only(model)
        save_quant_artifact(
            out_dir,
            mode="int8_decoder_dynamic",
            base_model=args.base_model,
            adapter_path=adapter,
            tokenizer=tok,
            model=model,
        )
        return

    if args.mode == "int8_bnb":
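        # bitsandbytes 8-bit weights; as the --device help notes, the bnb paths require CUDA.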
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_8bit=True,
        )
        # Note: saving bnb-quantized weights in a portable way is non-trivial; we still save the state_dict for reference.
        save_quant_artifact(out_dir, mode="int8_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
        return

    if args.mode == "int4_bnb":
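        # bitsandbytes 4-bit weights; like the 8-bit path, this requires CUDA.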
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_4bit=True,
        )
        save_quant_artifact(out_dir, mode="int4_bnb", base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)
        return


if __name__ == "__main__":
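    # The export is inference-only, so autograd is disabled globally up front.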
    torch.set_grad_enabled(False)
    main()