Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,338 Bytes
fe65e71 0a027ed fe65e71 0a027ed a46043f fe65e71 8e219a5 0a027ed fe65e71 a46043f 0a027ed a46043f 8e219a5 fe65e71 0a027ed fe65e71 a46043f fe65e71 0a027ed fe65e71 0a027ed fe65e71 8e219a5 fe65e71 0a027ed fe65e71 a46043f fe65e71 0a027ed fe65e71 0a027ed fe65e71 0a027ed a46043f 0a027ed fe65e71 0a027ed fe65e71 0a027ed fe65e71 a46043f 0a027ed fe65e71 d8f08b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import os
import time
from typing import List, Dict, Tuple
import gradio as gr
from transformers import pipeline
import spaces
# === Config (override via Space secrets/env vars) ===
# Hugging Face model repository loaded by the text-generation pipeline.
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
# Extra system prompt; empty and not currently added to the chat messages.
STATIC_PROMPT = ""

# Defaults for the "Advanced settings" sliders. Env values arrive as strings
# and are converted to the numeric type each slider expects.
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 512))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", 0.7))
DEFAULT_TOP_P = float(os.getenv("TOP_P", 0.95))
DEFAULT_REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", 1.0))

# GPU time budget (seconds) requested via @spaces.GPU(duration=...).
ZGPU_DURATION = int(os.getenv("ZGPU_DURATION", 120))

# Lazily-initialised globals, filled on the first generation request.
_pipe = None  # cached transformers pipeline
_tok = None  # the pipeline's tokenizer, used to parse Harmony-format output
def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
messages = []
if policy.strip():
messages.append({"role": "system", "content": policy.strip()})
# if STATIC_PROMPT:
# messages.append({"role": "system", "content": STATIC_PROMPT})
messages.append({"role": "user", "content": user_prompt})
return messages
def _parse_harmony_output(last, tokenizer):
analysis, content = None, None
if isinstance(last, dict) and ("content" in last or "thinking" in last):
analysis = last.get("thinking")
content = last.get("content")
else:
parsed = tokenizer.parse_response(last)
analysis = parsed.get("thinking")
content = parsed.get("content")
return analysis, content
@spaces.GPU(duration=ZGPU_DURATION)
def generate_long_prompt(
    policy: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
) -> Tuple[str, str, str]:
    """Run one chat generation and return ``(analysis, answer, meta)``.

    Args:
        policy: System instructions (may be blank).
        prompt: The user's message.
        max_new_tokens: Upper bound on generated tokens; cast to ``int``
            because slider and API values may arrive as floats.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding (``do_sample=False``). transformers rejects a
            non-positive temperature when sampling, yet the UI slider
            allows 0.0.
        top_p: Nucleus-sampling threshold; used only when sampling.
        repetition_penalty: Passed through to generation.

    Returns:
        The analysis text (or ``"(No analysis)"``), the answer text (or
        ``"(No answer)"``), and a one-line metadata string with the model
        id, elapsed time and token budget.
    """
    global _pipe, _tok
    start = time.perf_counter()  # monotonic clock for elapsed time

    if _pipe is None:
        # Lazy one-time load, cached in module globals.
        # NOTE(review): on ZeroGPU the decorated call may execute outside the
        # main process; confirm this cache actually persists between calls.
        _pipe = pipeline(
            task="text-generation",
            model=MODEL_ID,
            torch_dtype="auto",
            device_map="auto",
        )
        _tok = _pipe.tokenizer

    max_new_tokens = int(max_new_tokens)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": float(repetition_penalty),
    }
    if temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
        )
    else:
        # Temperature 0 means "deterministic": decode greedily instead of
        # letting the sampler raise on a non-positive temperature.
        gen_kwargs["do_sample"] = False

    outputs = _pipe(_to_messages(policy, prompt), **gen_kwargs)

    # With chat input, generated_text holds the conversation; the newly
    # generated message is its last element.
    last = outputs[0].get("generated_text", [])
    if isinstance(last, list) and last:
        last = last[-1]
    analysis, content = _parse_harmony_output(last, _tok)

    elapsed = time.perf_counter() - start
    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
    return analysis or "(No analysis)", content or "(No answer)", meta
# Page stylesheet: UI / monospace font stacks and a hidden Gradio footer.
CUSTOM_CSS = (
    "/** Simple styling **/\n"
    ".gradio-container {font-family: ui-sans-serif, system-ui, Inter, Roboto;}\n"
    "textarea {font-family: ui-monospace, monospace;}\n"
    "footer {display:none;}"
)

# Two-column layout: inputs and sampling controls on the left, model outputs
# on the right. Components render in the order they are created.
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# GPT‑OSS Harmony Demo\n"
        "Provide a **Policy**, a **Prompt**, and see both **Analysis** and **Answer** separately."
    )
    with gr.Row():
        # Left column: what the user supplies.
        with gr.Column(scale=1):
            policy = gr.Textbox(
                label="Policy (system)",
                lines=20,
                placeholder="Enter the guiding rules and tone…",
            )
            prompt = gr.Textbox(
                label="Prompt (user)",
                lines=10,
                placeholder="Enter your main prompt…",
            )
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=16,
                    maximum=4096,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    step=8,
                    label="max_new_tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=DEFAULT_TEMPERATURE,
                    step=0.05,
                    label="temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_TOP_P,
                    step=0.01,
                    label="top_p",
                )
                repetition_penalty = gr.Slider(
                    minimum=0.8,
                    maximum=2.0,
                    value=DEFAULT_REPETITION_PENALTY,
                    step=0.05,
                    label="repetition_penalty",
                )
            generate = gr.Button("Generate", variant="primary")

        # Right column: what the model returns.
        with gr.Column(scale=1):
            analysis = gr.Textbox(label="Analysis (Harmony thinking)", lines=10)
            answer = gr.Textbox(label="Answer", lines=10)
            meta = gr.Markdown()

    # concurrency_limit=1: at most one run of this event at a time. The event
    # is also exposed as the "generate" API endpoint.
    generate.click(
        fn=generate_long_prompt,
        inputs=[policy, prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[analysis, answer, meta],
        api_name="generate",
        concurrency_limit=1,
    )
if __name__ == "__main__":
    # Enable the request queue (at most 32 pending requests), then serve.
    queued_app = demo.queue(max_size=32)
    queued_app.launch()
|