mirror of https://github.com/yk/gpt-4chan-public.git
synced 2024-12-22 10:50:04 +00:00

initial commit
parent 0538b0e31e
commit b1089a2d89
@@ -1,2 +1,9 @@
# gpt-4chan-public

Code for GPT-4chan

Note: This repository only contains helper code and small changes I made to other libraries.

The source code for the actual model is at [https://github.com/kingoflolz/mesh-transformer-jax/](https://github.com/kingoflolz/mesh-transformer-jax/)

Data here: [https://zenodo.org/record/3606810](https://zenodo.org/record/3606810)

Model here: [https://huggingface.co/ykilcher/gpt-4chan](https://huggingface.co/ykilcher/gpt-4chan)

Website here: [https://gpt-4chan.com](https://gpt-4chan.com)
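As a minimal sketch of using the Hugging Face checkpoint linked above (mirroring how `src/server/serve_api.py` loads it; `torch` and `transformers` are assumed to be installed, and the prompt string is only an example):

```python
# Minimal sketch: load the GPT-J-based checkpoint from the Hugging Face Hub.
# The model id comes from the link above; the prompt string is just an example.
import torch
from transformers import GPTJForCausalLM, AutoTokenizer

model = GPTJForCausalLM.from_pretrained("ykilcher/gpt-4chan", torch_dtype=torch.float32)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

input_ids = tokenizer("--- 12345\n", return_tensors="pt").input_ids
gen_tokens = model.generate(input_ids, do_sample=True, max_new_tokens=64, top_p=0.95, temperature=0.8)
print(tokenizer.batch_decode(gen_tokens)[0])
```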
77 src/compute_metrics.py Executable file
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

import json
from pathlib import Path
from loguru import logger
from tabulate import tabulate
import lm_eval.tasks

m1 = 'GPT-J-6B'
m2 = 'GPT-4chan'

log_dir = Path('./eval_logs')
all_tasks = set()
model_data = {}
for fn in log_dir.rglob('log_*.stdout.txt'):
    try:
        file_text = fn.read_text()
        data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
        model = data['config']['model_args'].split('=')[1]
        model = m2 if 'fp16' in model else m1
        if model not in model_data:
            model_data[model] = {}
        results = data['results']
        tasks = list(results.keys())
        assert len(tasks) == 1, 'Only one task supported'
        task = tasks[0]
        if task in model_data[model]:
            raise ValueError(f'Duplicate task {task}')
        task_version = data['versions'][task]
        results = results[task]
        results_data = {}
        for result_key in results:
            if result_key.endswith('_stderr'):
                continue
            result_value = results[result_key]
            results_data[result_key] = {'value': result_value}
            stderr_key = f'{result_key}_stderr'
            if stderr_key in results:
                results_data[result_key]['stderr'] = results[stderr_key]
            else:
                logger.warning(f'No stderr for {result_key} in {results}')
        model_data[model][task] = {'version': task_version, 'results': results_data}
        all_tasks.add(task)
    except Exception:
        logger.exception(f'Failed to parse {fn}')
        continue

all_models = list(sorted(model_data.keys()))
table_data = []
for task in all_tasks:
    try:
        higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
    except Exception:
        logger.warning(f'Failed to get higher_is_better for {task}')
        continue
    if any(task not in model_data[model] for model in all_models):
        logger.warning(f'No results for {task}')
        continue
    results = model_data[m1][task]['results']
    results2 = model_data[m2][task]['results']
    for metric in results:
        result_value = results[metric]['value']
        stderr_value = results[metric].get('stderr', 0.0)
        result2_value = results2[metric]['value']
        stderr2_value = results2[metric].get('stderr', 0.0)
        significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
        if higher_is_better[metric]:
            significance *= -1
        if abs(significance) > 1:
            significant = '+' if significance > 0 else '-'
        else:
            significant = ''
        table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])

table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
print(table_str)
Path('./results.table.txt').write_text(table_str)
59 src/process_data.py Executable file
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import json
import bs4
from loguru import logger
import multiprocessing as mp
import tqdm
from absl import app, flags

import warnings
warnings.filterwarnings("ignore", category=bs4.MarkupResemblesLocatorWarning, module='bs4')

DATA_FN = '../tmp/pol_062016-112019_labeled.ndjson'
OUT_FN = '../tmp/kek.txt'

flags.DEFINE_string('data_fn', DATA_FN, 'data file')
flags.DEFINE_string('out_fn', OUT_FN, 'output file')

FLAGS = flags.FLAGS

# from here: https://gist.github.com/zmwangx/ad0830ba94b1fd98f428
def text_with_newlines(elem):
    text = ''
    for e in elem.descendants:
        if isinstance(e, str):
            # text += e.strip()
            text += e
        elif e.name == 'br' or e.name == 'p':
            text += '\n'
    return text


def parse_line(line):
    data = json.loads(line)
    posts_text = []
    for post in data.get('posts', []):
        try:
            if 'com' in post:
                soup = bs4.BeautifulSoup(post['com'], 'lxml')
                post_text = text_with_newlines(soup).strip()
            else:
                post_text = ''
            post_text = f'--- {post["no"]}\n{post_text}'
            posts_text.append(post_text)
        except Exception:
            logger.exception(f'failed to parse post {post}')
    return '\n'.join(posts_text)


def main(_):
    with open(FLAGS.out_fn, 'w') as out_f:
        with open(FLAGS.data_fn) as in_f:
            with mp.Pool() as pool:
                for parsed_line in pool.imap(parse_line, tqdm.tqdm(in_f)):
                    out_f.write(parsed_line + '\n-----\n')


if __name__ == '__main__':
    app.run(main)
11 src/server/README.md Normal file
@@ -0,0 +1,11 @@
Clone https://github.com/kingoflolz/mesh-transformer-jax and put this code inside it.

The `model` directory is from https://github.com/okbuddyhololive/project-cybertard with slight changes.

(You only need to do the above if you want to run JAX inference; for Hugging Face it is not necessary.)

Then run `uvicorn --host 0.0.0.0 --port 8080 serve_api:app`.

I use Python 3.9.12 and install requirements.txt, then uninstall jax, jaxlib, tensorflow, and tensorflow-cpu,

and install `jax==0.2.12 jaxlib==0.1.67 tensorflow==2.5.0 markupsafe==2.0.1 uvicorn fastapi loguru`.
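Once the server is up, completions are requested from the `/complete` endpoint defined in `serve_api.py`. A minimal client sketch (host, port, API key, and sampling values are placeholders; the key must appear in `valid_api_keys.txt`):

```python
# Minimal client sketch for the /complete endpoint; all values are placeholders.
import requests

resp = requests.post(
    "http://localhost:8080/complete",
    json={
        "prompt": "what do you think about this?",
        "api_key": "my-key",      # must match a line in valid_api_keys.txt
        "length": 128,            # number of new tokens (1..1024)
        "top_p": 0.95,
        "temperature": 0.8,
    },
)
print(resp.json()["response"])
```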
2 src/server/model/__init__.py Normal file
@@ -0,0 +1,2 @@
from .inference import Inference
from .constants import ModelParams, InferConfig
45 src/server/model/constants.py Normal file
@@ -0,0 +1,45 @@
import typing
from dataclasses import dataclass

import optax

BAD_WORDS = []  # Can't be part of config to avoid printing it


@dataclass
class InferConfig:
    name: str = "Holotard"
    prompt_length: int = 65536
    token_length: int = 16

    response_probability: float = 0.02
    top_p: float = 1.0

    min_temperature: float = 0.6
    max_temperature: float = 1.2

    max_same_replies: int = 2
    same_reply_saved_messages: int = 6
    max_response_retries: int = 3


@dataclass
class ModelParams:
    layers: int = 28
    d_model: int = 4096
    n_heads: int = 16
    n_vocab: int = 50400

    norm: str = "layernorm"
    pe: str = "rotary"
    pe_rotary_dims: int = 64

    seq: int = 2048
    cores_per_replica: int = 8
    per_replica_batch: int = 1

    # batch size of 2 needs 200gb, 1 needs <16. wtf
    optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
                                         optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
                                         optax.scale(-1e-5), )
    sampler = None
111 src/server/model/inference.py Normal file
@@ -0,0 +1,111 @@
import random
from typing import Any, Optional
from loguru import logger
import time

import jax
import numpy as np
import transformers
from jax import numpy as jnp
from jax.experimental import maps
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from .constants import ModelParams, InferConfig


def default(value: Any, fallback: Any) -> Any:
    # luke prefers making a function that chooses between `value` and `fallback` so i am gonna keep it
    if value is None:
        return fallback

    return value


_cores_per_replica = ModelParams.cores_per_replica
_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
_devices = np.array(jax.devices()).reshape(_mesh_shape)

# maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))


class Inference:
    _NP_ONE = np.ones((1,))

    def __init__(
        self,
        path: Optional[str] = None,
        parameters: Optional[ModelParams] = None,
        config: Optional[InferConfig] = None,
    ):
        path = "checkpoint_slim/" if path is None else path

        self.params = ModelParams() if parameters is None else parameters
        self.params.sampler = nucleaus_sample
        self.config = InferConfig() if config is None else config

        self.model = CausalTransformer(self.params.__dict__)
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.model.state = read_ckpt_lowmem(
            self.model.state, path, self.params.cores_per_replica, load_opt=False
        )

    def generate_tokens(
        self,
        prompt: np.ndarray,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> np.ndarray:
        length = default(length, self.config.token_length)
        top_p = default(top_p, self.config.top_p)
        new_temp = random.random() * (self.config.max_temperature - self.config.min_temperature)
        new_temp += self.config.min_temperature
        temperature = default(temperature, new_temp)
        # prompt = prompt[:, -2048:]
        # prompt = prompt[:, -length:]

        start_time = time.time()
        source = jnp.array(
            np.pad(
                prompt,
                (
                    (0, 0),
                    (self.params.seq - prompt.shape[1], 0),
                ),
            )
        )
        logger.info(f"creating source took {time.time() - start_time}")
        sampler_options = {
            "top_p": self._NP_ONE * top_p,
            "temp": self._NP_ONE * temperature,
        }

        start_time = time.time()
        # with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
        logger.info(f"creating mesh took {time.time() - start_time}")
        start_time = time.time()
        out = self.model.generate(
            source, self._NP_ONE * prompt.shape[1], length, sampler_options
        )
        logger.info(f"generate took {time.time() - start_time}")

        # import IPython; IPython.embed()
        return out[1][0][0, :, 0]

    def generate(
        self,
        prompt: str,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> str:
        inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
        inp_tokens = inp_tokens["input_ids"][0]
        out_tokens = self.generate_tokens(
            inp_tokens.reshape(1, -1), length, top_p, temperature
        )

        return self.tokenizer.decode(out_tokens)
52 src/server/model/to_slim_weights.py Normal file
@@ -0,0 +1,52 @@
import argparse
import json
import time

import jax
import numpy as np
import optax

from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt, read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer
from smart_open import open

from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
from model.constants import ModelParams


if __name__ == "__main__":
    params = ModelParams().__dict__
    convert_fn = to_bf16

    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with jax.experimental.maps.mesh(devices, ("dp", "mp")):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(
            network.state, f"checkpoint/", devices.shape[1], load_opt=False
        )
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        del network.state["opt_state"]

        network.state["params"] = convert_fn(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        suffix = "_slim"

        for i in range(cores_per_replica):
            write_ckpt(network.state, f"checkpoint_slim/", i)
            print(f"written shard {i}")
199 src/server/serve_api.py Normal file
@@ -0,0 +1,199 @@
#!/usr/bin/env python3

from typing import Optional
import threading
import queue
import time
from loguru import logger
from pathlib import Path
import contextlib

import pydantic
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Settings(pydantic.BaseSettings):
    queue_size: int = 1024
    log_file: str = "logs/serve_api.log"
    api_keys_file: str = 'valid_api_keys.txt'
    hf_model: str = ''
    hf_cuda: bool = False
    pre_prompt_length: int = 512


settings = Settings()

def _check_api_key(key):
    key = key.strip()
    for line in Path(settings.api_keys_file).open():
        if not line:
            continue
        valid_key = line.split()[0]
        if key == valid_key:
            break
    else:
        return False
    return True


request_queue = queue.Queue(maxsize=settings.queue_size)


@contextlib.contextmanager
def jax_generation():
    from model import inference
    import jax
    model = inference.Inference(path="../model_slim/step_88001/")

    def _generate(request):
        response = model.generate(
            prompt=request.prompt,
            length=request.length,
            top_p=request.top_p,
            temperature=request.temperature,
        )
        return response
    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
        yield _generate


@contextlib.contextmanager
def hf_generation():
    from transformers import GPTJForCausalLM, AutoTokenizer
    import torch

    if settings.hf_cuda:
        model = GPTJForCausalLM.from_pretrained(
            settings.hf_model, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
        )
        model.cuda()
    else:
        model = GPTJForCausalLM.from_pretrained(settings.hf_model, torch_dtype=torch.float32)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    def _generate(request: CompleteRequest):
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids

        max_prompt_length = 2048 - request.length
        input_ids = input_ids[:, -max_prompt_length:]

        if request.pre_prompt:
            pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
            pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
            input_ids = input_ids[:, -(max_prompt_length-len(pp_input_ids)):]
            full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_ids = input_ids[:, -max_prompt_length:]

        if settings.hf_cuda:
            input_ids = input_ids.cuda()

        with torch.no_grad():
            gen_tokens = model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                typical_p=request.typical_p,
                max_new_tokens=request.length,
            ).detach().cpu()
        gen_text = tokenizer.batch_decode(gen_tokens)[0]
        prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
        if not gen_text.startswith(prompt_decoded):
            raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
        gen_text = gen_text[len(prompt_decoded):]
        return gen_text
    yield _generate


def worker():
    if settings.hf_model:
        generation = hf_generation
    else:
        generation = jax_generation
    with generation() as generate_fn:
        with open(settings.log_file, "a") as logf:
            while True:
                response_queue = None
                try:
                    start_time = time.time()
                    (request, response_queue) = request_queue.get()
                    logger.info(f"getting request took {time.time() - start_time}")
                    start_time = time.time()
                    response = generate_fn(request)
                    logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
                    start_time = time.time()

                    logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
                    logf.write(f"{request.pre_prompt}\n")
                    logf.write("###\n")
                    logf.write(f"{request.prompt}\n")
                    logf.write("#####\n")
                    logf.write(f"{response}\n\n")
                    logf.flush()

                    logger.info(f"writing log took {time.time() - start_time}")
                    start_time = time.time()
                    response_queue.put(response)
                    logger.info(f"putting response took {time.time() - start_time}")
                except KeyboardInterrupt:
                    logger.info(f"Got KeyboardInterrupt... quitting!")
                    raise
                except Exception:
                    logger.exception(f"Got exception, will continue")
                    if response_queue is not None:
                        response_queue.put("")


@app.get("/")
async def main():
    return {"response": "Hello, world!"}


class CompleteRequest(pydantic.BaseModel):
    prompt: pydantic.constr(min_length=0, max_length=2**14)
    pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = ''
    api_key: pydantic.constr(min_length=1, max_length=128) = "x"*9
    length: pydantic.conint(ge=1, le=1024) = 128
    top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
    temperature: pydantic.confloat(ge=0.0) = 1.0
    typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0


def _enqueue(request: CompleteRequest):
    response_queue = queue.Queue()
    request_queue.put((request, response_queue))
    response = response_queue.get()
    return response


@app.on_event("startup")
def startup():
    threading.Thread(
        target=worker,
        daemon=True,
    ).start()
    _enqueue(CompleteRequest(prompt="hello"))


@app.post("/complete")
def complete(request: CompleteRequest):
    logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}")
    if request_queue.full():
        logger.warning("Request queue full.")
        raise ValueError("Request queue full.")
    if not _check_api_key(request.api_key):
        logger.warning(f"api key not valid: {request.api_key}, discarding...")
        raise ValueError("Invalid API key")
    response = _enqueue(request)
    return {"response": response}
65 src/txt_to_tfrecords.py Executable file
@@ -0,0 +1,65 @@
#!/usr/bin/env python3

from absl import flags, app

from loguru import logger

from pathlib import Path
import shutil
import functools

import tokenizers
import tensorflow as tf
import tqdm

flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt')
flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)')
flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file')
flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size')

FLAGS = flags.FLAGS


@functools.lru_cache(maxsize=1)
def get_tokenizer():
    return tokenizers.Tokenizer.from_pretrained('gpt2')


def make_record_file(record_ids, out_dir, file_no):
    out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord')
    with tf.io.TFRecordWriter(out_fn) as writer:
        feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))}
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())


def read_in_blocks(f):
    while True:
        block = f.read(FLAGS.read_buffer_size)
        if not block:
            break
        yield block


def main(_):
    out_dir = Path(FLAGS.out_dir)
    if out_dir.exists():
        logger.warning(f'clearing {out_dir}')
        shutil.rmtree(out_dir)
    out_dir.mkdir(exist_ok=True)
    tokenizer = get_tokenizer()
    with open(FLAGS.txt_fn) as in_f:
        current_ids = []
        out_file_no = 0
        for block in tqdm.tqdm(read_in_blocks(in_f)):
            current_ids.extend(tokenizer.encode(block).ids)
            while len(current_ids) >= FLAGS.chunk_size:
                record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:]
                make_record_file(record_ids, out_dir, out_file_no)
                out_file_no += 1

        if current_ids:
            make_record_file(current_ids, out_dir, out_file_no)


if __name__ == "__main__":
    app.run(main)