# app.py (FastAPI server to host the Jina Embedding model)
# Must be set before importing Hugging Face libs
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True pulls in the model's custom code from the Hub,
# which provides the encode_text() helper used below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    # float16 halves memory on GPU; fall back to float32 on CPU, where many
    # ops lack half-precision kernels.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()


# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"          # "retrieval", "text-matching", "code", etc.
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True   # False → for queries (pooled embedding)


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]  # (num_tokens, hidden_dim) if token-level
                                   # (1, hidden_dim) if pooled query
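
# Example /embed request bodies (illustrative values):
#   {"text": "what is a transformer?", "return_token_embeddings": false}
#       -> a single pooled query vector
#   {"text": "<long passage>", "task": "retrieval"}
#       -> one embedding row per token, via the sliding window below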


class TokenizeRequest(BaseModel):
    text: str


class TokenizeResponse(BaseModel):
    input_ids: List[int]


class DecodeRequest(BaseModel):
    input_ids: List[int]


class DecodeResponse(BaseModel):
    text: str


# -----------------------------
# Embedding Endpoint
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # -----------------------------
    # Case 1: Query → directly pooled embedding
    # -----------------------------
    if not req.return_token_embeddings:
        with torch.no_grad():
            emb = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=False,
            )
        # Depending on the model's remote code, encode_text may return a list
        # of per-text tensors rather than one stacked tensor; normalize it.
        if isinstance(emb, list):
            emb = torch.stack(emb)
        return {"embeddings": emb.tolist()}  # shape: (1, hidden_dim)

    # -----------------------------
    # Case 2: Long passages → sliding window token embeddings
    # -----------------------------
    # Token ids stay on CPU; they are only sliced and decoded here,
    # never fed to the model directly.
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0)  # (total_tokens,)
    total_tokens = input_ids.size(0)

    # Window size: the model's context limit, minus headroom for the
    # instruction prefix that encode_text prepends (64 is a conservative
    # assumption; the exact prefix length depends on the prompt).
    max_len = model.config.max_position_embeddings - 64  # limit is ~32k for v4
    stride = 50  # tokens of overlap between consecutive windows
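    # Window arithmetic, with illustrative numbers: if max_len were 8 and
    # stride 2, windows would cover token ranges [0:8], [6:14], [12:20], ...
    # Dropping the first `stride` rows of every window after the first
    # leaves exactly one embedding row per original token.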
    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end]

        # encode_text takes raw strings, so the window is decoded back to
        # text. NOTE: re-tokenization inside encode_text (plus the prompt
        # prefix) can shift token boundaries, so row counts are only
        # approximately aligned with the original token positions.
        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids)],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
            )

        if isinstance(outputs, list):  # some revisions return per-text tensors
            outputs = outputs[0]
        else:
            outputs = outputs.squeeze(0)
        window_embeds = outputs.cpu()  # (window_len, hidden_dim)

        # Drop the overlapping rows except in the first window
        if position > 0:
            window_embeds = window_embeds[stride:]

        embeddings.append(window_embeds)

        if end == total_tokens:
            break  # last window reached; avoid an extra, fully-overlapping pass

        # Advance, keeping `stride` tokens of overlap with the next window
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)  # approx. (total_tokens, hidden_dim); see re-tokenization note above
    return {"embeddings": full_embeddings.tolist()}


# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}


# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}