from flask import Flask, render_template, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from scipy.stats import percentileofscore

app = Flask(__name__)

DEFAULT_MODEL = "gpt2"

# Keep loaded models and tokenizers in memory so repeated requests for the
# same model do not reload the weights.
model_cache = {}
tokenizer_cache = {}

def get_model_and_tokenizer(model_name):
    """Load (and cache) the model and tokenizer for the given name."""
    if model_name not in model_cache:
        # microsoft/phi-1_5 ships custom modeling code on the Hub, so it is
        # the only listed model loaded with trust_remote_code enabled.
        trust_code = model_name == "microsoft/phi-1_5"
        model_cache[model_name] = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
    return model_cache[model_name], tokenizer_cache[model_name]
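
# A minimal usage sketch (illustration only, not part of the app): a second
# call with the same model name is served from the cache rather than reloaded.
#
#   model, tok = get_model_and_tokenizer(DEFAULT_MODEL)
#   model_again, _ = get_model_and_tokenizer(DEFAULT_MODEL)
#   assert model is model_again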

@app.route("/")
def index():
    return render_template(
        "index.html",
        models=[
            DEFAULT_MODEL,
            # "gpt2-medium",
            # "gpt2-large",
            # "gpt2-xl",
            # "EleutherAI/pythia-1.4b",
            # "facebook/opt-1.3b",
            # "bigscience/bloom-1b7",
            # "microsoft/phi-1_5",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ],
    )

# NOTE: the route path here is an assumption; it must match the URL the
# frontend posts its analysis requests to.
@app.route("/analyze", methods=["POST"])
def analyze():
    data = request.get_json()
    text = data["text"]
    model_name = data["model"]
    model, tokenizer = get_model_and_tokenizer(model_name)
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits  # shape: (1, seq_len, vocab_size)
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
    log_probs = []            # log P(token i+1 | prefix) for each position
    all_log_probs_list = []   # pooled log-probs over the full vocab at every position
    top_k_predictions = []
    for i in range(len(input_ids) - 1):
        # The logits at position i parameterize the distribution of token i+1.
        probs_at_position = F.log_softmax(logits[0, i, :], dim=-1)
        all_log_probs_list.extend(probs_at_position.tolist())
        top_k_values, top_k_indices = torch.topk(probs_at_position, 5)
        top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())
        top_k_predictions.append(
            [
                {"token": t, "log_prob": v.item()}
                for t, v in zip(top_k_tokens, top_k_values)
            ]
        )
        log_prob = probs_at_position[input_ids[i + 1]].item()
        log_probs.append(log_prob)
    # Percentile rank of each observed log-prob within the pooled distributions.
    percentiles = [percentileofscore(all_log_probs_list, lp) for lp in log_probs]
    joint_log_likelihood = sum(log_probs)
    average_log_likelihood = (
        joint_log_likelihood / len(log_probs) if log_probs else 0
    )
    return jsonify({
        "tokens": tokens,
        "percentiles": percentiles,
        "log_probs": log_probs,
        "top_k_predictions": top_k_predictions,
        "joint_log_likelihood": joint_log_likelihood,
        "average_log_likelihood": average_log_likelihood,
    })
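
# For reference (not computed by the app): the average log-likelihood relates
# to perplexity via ppl = exp(-average_log_likelihood), so an average of -3.0
# corresponds to a perplexity of roughly 20.1.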

if __name__ == "__main__":
    # Port 7860 is the standard port for apps hosted on Hugging Face Spaces.
    app.run(host="0.0.0.0", port=7860)
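
# Example request (a sketch; assumes the /analyze route above and a server
# running on localhost):
#
#   curl -X POST http://localhost:7860/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The quick brown fox", "model": "gpt2"}'
#
# The response pairs every token after the first with its log-probability,
# its percentile rank, and the model's top-5 next-token predictions.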