gpt-4chan-public/src/server/serve_api.py

#!/usr/bin/env python3
from typing import Optional
import threading
import queue
import time
from loguru import logger
from pathlib import Path
import contextlib
import pydantic
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
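
# Runtime configuration, read from environment variables via pydantic's BaseSettings.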
class Settings(pydantic.BaseSettings):
    queue_size: int = 1024
    log_file: str = "logs/serve_api.log"
    api_keys_file: str = "valid_api_keys.txt"
    hf_model: str = ""
    hf_cuda: bool = False
    pre_prompt_length: int = 512


settings = Settings()
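
# An API key is valid if it matches the first whitespace-separated token on
# any non-empty line of the api_keys_file.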
def _check_api_key(key):
    key = key.strip()
    for line in Path(settings.api_keys_file).open():
        if not line.strip():
            continue
        valid_key = line.split()[0]
        if key == valid_key:
            break
    else:
        return False
    return True
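
# All requests are funnelled through a single bounded queue so that only one
# generation runs at a time on the worker thread.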
request_queue = queue.Queue(maxsize=settings.queue_size)
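
# JAX backend: loads the checkpoint at ../model_slim/step_88001/ and yields a
# generate(request) callable inside the device mesh context.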
@contextlib.contextmanager
def jax_generation():
    from model import inference
    import jax

    model = inference.Inference(path="../model_slim/step_88001/")

    def _generate(request):
        response = model.generate(
            prompt=request.prompt,
            length=request.length,
            top_p=request.top_p,
            temperature=request.temperature,
        )
        return response

    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
        yield _generate
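
# Hugging Face backend, used when settings.hf_model is set: loads a
# GPTJForCausalLM checkpoint and yields a generate(request) callable.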
@contextlib.contextmanager
def hf_generation():
    from transformers import GPTJForCausalLM, AutoTokenizer
    import torch

    if settings.hf_cuda:
        model = GPTJForCausalLM.from_pretrained(
            settings.hf_model,
            revision="float16",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        model.cuda()
    else:
        model = GPTJForCausalLM.from_pretrained(settings.hf_model, torch_dtype=torch.float32)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    def _generate(request: CompleteRequest):
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
        max_prompt_length = 2048 - request.length
        input_ids = input_ids[:, -max_prompt_length:]
        if request.pre_prompt:
            pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
            pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
            # pp_input_ids has shape (1, n_tokens); use the token dimension here
            # rather than len(), which would give the batch size (1).
            input_ids = input_ids[:, -(max_prompt_length - pp_input_ids.shape[1]):]
            full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_ids = input_ids[:, -max_prompt_length:]
        if settings.hf_cuda:
            input_ids = input_ids.cuda()
        with torch.no_grad():
            gen_tokens = model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                typical_p=request.typical_p,
                max_new_tokens=request.length,
            ).detach().cpu()
        gen_text = tokenizer.batch_decode(gen_tokens)[0]
        prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
        if not gen_text.startswith(prompt_decoded):
            raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
        gen_text = gen_text[len(prompt_decoded):]
        return gen_text

    yield _generate
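
# Worker thread: selects a backend, then serves requests from request_queue one
# at a time, logging every prompt/response pair to the log file.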
def worker():
    if settings.hf_model:
        generation = hf_generation
    else:
        generation = jax_generation
    with generation() as generate_fn:
        with open(settings.log_file, "a") as logf:
            while True:
                response_queue = None
                try:
                    start_time = time.time()
                    (request, response_queue) = request_queue.get()
                    logger.info(f"getting request took {time.time() - start_time}")
                    start_time = time.time()
                    response = generate_fn(request)
                    logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
                    start_time = time.time()
                    logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
                    logf.write(f"{request.pre_prompt}\n")
                    logf.write("###\n")
                    logf.write(f"{request.prompt}\n")
                    logf.write("#####\n")
                    logf.write(f"{response}\n\n")
                    logf.flush()
                    logger.info(f"writing log took {time.time() - start_time}")
                    start_time = time.time()
                    response_queue.put(response)
                    logger.info(f"putting response took {time.time() - start_time}")
                except KeyboardInterrupt:
                    logger.info("Got KeyboardInterrupt... quitting!")
                    raise
                except Exception:
                    logger.exception("Got exception, will continue")
                    if response_queue is not None:
                        response_queue.put("")
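
# Root endpoint; simply returns a hello-world response.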
@app.get("/")
async def main():
    return {"response": "Hello, world!"}
class CompleteRequest(pydantic.BaseModel):
    prompt: pydantic.constr(min_length=0, max_length=2**14)
    pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = ""
    api_key: pydantic.constr(min_length=1, max_length=128) = "x" * 9
    length: pydantic.conint(ge=1, le=1024) = 128
    top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
    temperature: pydantic.confloat(ge=0.0) = 1.0
    typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
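
# Put a request on the worker queue and block until the worker thread puts the
# generated text on the per-request response queue.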
def _enqueue(request: CompleteRequest):
    response_queue = queue.Queue()
    request_queue.put((request, response_queue))
    response = response_queue.get()
    return response
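
# On startup, launch the worker thread and send a warm-up request so the model
# is loaded before the first real request arrives.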
@app.on_event("startup")
def startup():
    threading.Thread(
        target=worker,
        daemon=True,
    ).start()
    _enqueue(CompleteRequest(prompt="hello"))
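
# Completion endpoint: rejects requests when the queue is full or the API key
# is invalid, otherwise blocks until the worker returns the generated text.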
@app.post("/complete")
def complete(request: CompleteRequest):
    logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}")
    if request_queue.full():
        logger.warning("Request queue full.")
        raise ValueError("Request queue full.")
    if not _check_api_key(request.api_key):
        logger.warning(f"api key not valid: {request.api_key}, discarding...")
        raise ValueError("Invalid API key")
    response = _enqueue(request)
    return {"response": response}