initial commit

Yannic Kilcher 2022-06-03 17:20:18 +02:00
parent 0538b0e31e
commit b1089a2d89
10 changed files with 628 additions and 0 deletions

@@ -1,2 +1,9 @@
# gpt-4chan-public
Code for GPT-4chan
Note: This repository only contains helper code and small changes I made to other libraries.
The source code for the actual model is at [https://github.com/kingoflolz/mesh-transformer-jax/](https://github.com/kingoflolz/mesh-transformer-jax/)
Data here: [https://zenodo.org/record/3606810](https://zenodo.org/record/3606810)
Model here: [https://huggingface.co/ykilcher/gpt-4chan](https://huggingface.co/ykilcher/gpt-4chan)
Website here: [https://gpt-4chan.com](https://gpt-4chan.com)

src/compute_metrics.py (Executable file, +77 lines)

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
import json
from pathlib import Path

from loguru import logger
from tabulate import tabulate

import lm_eval.tasks

m1 = 'GPT-J-6B'
m2 = 'GPT-4chan'

log_dir = Path('./eval_logs')

all_tasks = set()
model_data = {}
for fn in log_dir.rglob('log_*.stdout.txt'):
    try:
        file_text = fn.read_text()
        # the results JSON is embedded in the stdout log: grab everything between
        # the first '{' and the last '}'
        data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
        model = data['config']['model_args'].split('=')[1]
        model = m2 if 'fp16' in model else m1
        if model not in model_data:
            model_data[model] = {}
        results = data['results']
        tasks = list(results.keys())
        assert len(tasks) == 1, 'Only one task supported'
        task = tasks[0]
        if task in model_data[model]:
            raise ValueError(f'Duplicate task {task}')
        task_version = data['versions'][task]
        results = results[task]
        results_data = {}
        for result_key in results:
            if result_key.endswith('_stderr'):
                continue
            result_value = results[result_key]
            results_data[result_key] = {'value': result_value}
            stderr_key = f'{result_key}_stderr'
            if stderr_key in results:
                results_data[result_key]['stderr'] = results[stderr_key]
            else:
                logger.warning(f'No stderr for {result_key} in {results}')
        model_data[model][task] = {'version': task_version, 'results': results_data}
        all_tasks.add(task)
    except Exception:
        logger.exception(f'Failed to parse {fn}')
        continue

all_models = list(sorted(model_data.keys()))

table_data = []
for task in all_tasks:
    try:
        higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
    except Exception:
        logger.warning(f'Failed to get higher_is_better for {task}')
        continue
    if any(task not in model_data[model] for model in all_models):
        logger.warning(f'No results for {task}')
        continue
    results = model_data[m1][task]['results']
    results2 = model_data[m2][task]['results']
    for metric in results:
        result_value = results[metric]['value']
        stderr_value = results[metric].get('stderr', 0.0)
        result2_value = results2[metric]['value']
        stderr2_value = results2[metric].get('stderr', 0.0)
        # rough z-like score: difference of the two values over their averaged stderr
        significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
        if higher_is_better[metric]:
            significance *= -1
        # '+' marks metrics where GPT-4chan beats GPT-J-6B by more than ~1 averaged stderr, '-' the opposite
        if abs(significance) > 1:
            significant = '+' if significance > 0 else '-'
        else:
            significant = ''
        table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])

table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
print(table_str)
Path('./results.table.txt').write_text(table_str)
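
For reference, a minimal sketch of the JSON blob this script expects to find embedded in each log_*.stdout.txt file (the field names are the ones the parser above reads; the task name and the numbers are made-up placeholders):

    {
        "config": {"model_args": "pretrained=EleutherAI/gpt-j-6B"},
        "versions": {"lambada": 0},
        "results": {
            "lambada": {
                "ppl": 4.10, "ppl_stderr": 0.08,
                "acc": 0.68, "acc_stderr": 0.01
            }
        }
    }

Note that a log file is attributed to GPT-4chan when the model path in model_args contains 'fp16', and to GPT-J-6B otherwise.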

src/process_data.py (Executable file, +59 lines)

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
import json
import bs4
from loguru import logger
import multiprocessing as mp
import tqdm
from absl import app, flags
import warnings

warnings.filterwarnings("ignore", category=bs4.MarkupResemblesLocatorWarning, module='bs4')

DATA_FN = '../tmp/pol_062016-112019_labeled.ndjson'
OUT_FN = '../tmp/kek.txt'

flags.DEFINE_string('data_fn', DATA_FN, 'data file')
flags.DEFINE_string('out_fn', OUT_FN, 'output file')

FLAGS = flags.FLAGS


# from here: https://gist.github.com/zmwangx/ad0830ba94b1fd98f428
def text_with_newlines(elem):
    text = ''
    for e in elem.descendants:
        if isinstance(e, str):
            # text += e.strip()
            text += e
        elif e.name == 'br' or e.name == 'p':
            text += '\n'
    return text


def parse_line(line):
    data = json.loads(line)
    posts_text = []
    for post in data.get('posts', []):
        try:
            if 'com' in post:
                # post bodies are HTML; strip the markup but keep <br>/<p> line breaks
                soup = bs4.BeautifulSoup(post['com'], 'lxml')
                post_text = text_with_newlines(soup).strip()
            else:
                post_text = ''
            # each post is prefixed with its post number
            post_text = f'--- {post["no"]}\n{post_text}'
            posts_text.append(post_text)
        except Exception:
            logger.exception(f'failed to parse post {post}')
    return '\n'.join(posts_text)


def main(_):
    with open(FLAGS.out_fn, 'w') as out_f:
        with open(FLAGS.data_fn) as in_f:
            with mp.Pool() as pool:
                # one ndjson line per thread; threads are separated by '-----' in the output
                for parsed_line in pool.imap(parse_line, tqdm.tqdm(in_f)):
                    out_f.write(parsed_line + '\n-----\n')


if __name__ == '__main__':
    app.run(main)
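
As a quick sanity check, this is roughly what parse_line produces for a single thread (a sketch, not part of the repo: the two-post thread below is invented but follows the 4chan API shape the script reads, a 'posts' list with 'no' and HTML 'com' fields, and it assumes the script is importable as process_data):

    import json
    from process_data import parse_line

    line = json.dumps({
        "posts": [
            {"no": 1, "com": "first post<br>second line"},
            {"no": 2, "com": "<p>a reply</p>"},
        ]
    })
    print(parse_line(line))
    # --- 1
    # first post
    # second line
    # --- 2
    # a reply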

src/server/README.md (Normal file, +11 lines)

@@ -0,0 +1,11 @@
Clone https://github.com/kingoflolz/mesh-transformer-jax and put this code inside it.

The `model` directory is from https://github.com/okbuddyhololive/project-cybertard, with slight changes.

(You only need the two steps above if you want to run JAX inference; for the Hugging Face model they are not necessary.)

Then run `uvicorn --host 0.0.0.0 --port 8080 serve_api:app`.

I use Python 3.9.12, install requirements.txt, then uninstall jax, jaxlib, tensorflow, and tensorflow-cpu, and instead install `jax==0.2.12 jaxlib==0.1.67 tensorflow==2.5.0 markupsafe==2.0.1 uvicorn fastapi loguru`.
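
Once the server is running, the `/complete` endpoint can be exercised like this (a minimal sketch, not part of the repo: it assumes the `requests` package, a server on localhost:8080 as above, and a placeholder API key that is listed in `valid_api_keys.txt`; the JSON fields mirror the `CompleteRequest` model in `serve_api.py`):

    import requests

    resp = requests.post(
        "http://localhost:8080/complete",
        json={
            "prompt": "hello there",      # text to continue
            "api_key": "my-secret-key",   # placeholder; must appear in valid_api_keys.txt
            "length": 128,                # number of new tokens, 1..1024
            "top_p": 1.0,
            "temperature": 0.8,
        },
    )
    print(resp.json()["response"])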

@@ -0,0 +1,2 @@
from .inference import Inference
from .constants import ModelParams, InferConfig

@@ -0,0 +1,45 @@
import typing
from dataclasses import dataclass

import optax

BAD_WORDS = []  # Can't be part of config to avoid printing it


@dataclass
class InferConfig:
    name: str = "Holotard"
    prompt_length: int = 65536
    token_length: int = 16
    response_probability: float = 0.02
    top_p: float = 1.0
    min_temperature: float = 0.6
    max_temperature: float = 1.2
    max_same_replies: int = 2
    same_reply_saved_messages: int = 6
    max_response_retries: int = 3


@dataclass
class ModelParams:
    layers: int = 28
    d_model: int = 4096
    n_heads: int = 16
    n_vocab: int = 50400
    norm: str = "layernorm"
    pe: str = "rotary"
    pe_rotary_dims: int = 64
    seq: int = 2048
    cores_per_replica: int = 8
    per_replica_batch: int = 1
    # batch size of 2 needs 200gb, 1 needs <16. wtf
    optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
                                         optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
                                         optax.scale(-1e-5), )
    sampler = None

@@ -0,0 +1,111 @@
import random
from typing import Any, Optional

from loguru import logger
import time

import jax
import numpy as np
import transformers
from jax import numpy as jnp
from jax.experimental import maps
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from .constants import ModelParams, InferConfig


def default(value: Any, fallback: Any) -> Any:
    # luke prefers making a function that chooses between `value` and `fallback`, so i am gonna keep it
    if value is None:
        return fallback
    return value


_cores_per_replica = ModelParams.cores_per_replica
_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
_devices = np.array(jax.devices()).reshape(_mesh_shape)

# maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))


class Inference:
    _NP_ONE = np.ones((1,))

    def __init__(
        self,
        path: Optional[str] = None,
        parameters: Optional[ModelParams] = None,
        config: Optional[InferConfig] = None,
    ):
        path = "checkpoint_slim/" if path is None else path
        self.params = ModelParams() if parameters is None else parameters
        self.params.sampler = nucleaus_sample
        self.config = InferConfig() if config is None else config

        self.model = CausalTransformer(self.params.__dict__)
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

        self.model.state = read_ckpt_lowmem(
            self.model.state, path, self.params.cores_per_replica, load_opt=False
        )

    def generate_tokens(
        self,
        prompt: np.ndarray,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> np.ndarray:
        length = default(length, self.config.token_length)
        top_p = default(top_p, self.config.top_p)

        # if no temperature is given, draw one uniformly from [min_temperature, max_temperature]
        new_temp = random.random() * (self.config.max_temperature - self.config.min_temperature)
        new_temp += self.config.min_temperature
        temperature = default(temperature, new_temp)

        # prompt = prompt[:, -2048:]
        # prompt = prompt[:, -length:]

        start_time = time.time()
        # left-pad the prompt up to the full sequence length
        source = jnp.array(
            np.pad(
                prompt,
                (
                    (0, 0),
                    (self.params.seq - prompt.shape[1], 0),
                ),
            )
        )
        logger.info(f"creating source took {time.time() - start_time}")

        sampler_options = {
            "top_p": self._NP_ONE * top_p,
            "temp": self._NP_ONE * temperature,
        }

        start_time = time.time()
        # with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
        logger.info(f"creating mesh took {time.time() - start_time}")
        start_time = time.time()
        out = self.model.generate(
            source, self._NP_ONE * prompt.shape[1], length, sampler_options
        )
        logger.info(f"generate took {time.time() - start_time}")

        # import IPython; IPython.embed()
        return out[1][0][0, :, 0]

    def generate(
        self,
        prompt: str,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> str:
        inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
        inp_tokens = inp_tokens["input_ids"][0]

        out_tokens = self.generate_tokens(
            inp_tokens.reshape(1, -1), length, top_p, temperature
        )
        return self.tokenizer.decode(out_tokens)
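
A minimal usage sketch (not in the repo): it assumes an accelerator environment where mesh-transformer-jax is installed, the slim checkpoint at the default `checkpoint_slim/` path, and that this package is importable as `model`, mirroring how serve_api.py uses it:

    import jax
    from model import inference

    model = inference.Inference()  # loads checkpoint_slim/ via read_ckpt_lowmem
    # the same mesh context manager that serve_api.py wraps its generation calls in
    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
        print(model.generate(prompt="hello", length=32))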

@@ -0,0 +1,52 @@
import argparse
import json
import time

import jax
import numpy as np
import optax

from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt, read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer
from smart_open import open
from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16

from model.constants import ModelParams

if __name__ == "__main__":
    params = ModelParams().__dict__
    convert_fn = to_bf16

    cores_per_replica = params["cores_per_replica"]
    assert cores_per_replica <= 8

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with jax.experimental.maps.mesh(devices, ("dp", "mp")):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(
            network.state, f"checkpoint/", devices.shape[1], load_opt=False
        )
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        del network.state["opt_state"]
        network.state["params"] = convert_fn(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        suffix = "_slim"

        for i in range(cores_per_replica):
            write_ckpt(network.state, f"checkpoint_slim/", i)
            print(f"written shard {i}")

src/server/serve_api.py (Normal file, +199 lines)

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
from typing import Optional
import threading
import queue
import time
from loguru import logger
from pathlib import Path
import contextlib

import pydantic
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Settings(pydantic.BaseSettings):
    queue_size: int = 1024
    log_file: str = "logs/serve_api.log"
    api_keys_file: str = 'valid_api_keys.txt'
    hf_model: str = ''  # if set, serve this Hugging Face model instead of the JAX checkpoint
    hf_cuda: bool = False
    pre_prompt_length: int = 512


settings = Settings()


def _check_api_key(key):
    key = key.strip()
    for line in Path(settings.api_keys_file).open():
        if not line.strip():
            continue
        valid_key = line.split()[0]
        if key == valid_key:
            break
    else:
        return False
    return True


request_queue = queue.Queue(maxsize=settings.queue_size)


@contextlib.contextmanager
def jax_generation():
    from model import inference
    import jax

    model = inference.Inference(path="../model_slim/step_88001/")

    def _generate(request):
        response = model.generate(
            prompt=request.prompt,
            length=request.length,
            top_p=request.top_p,
            temperature=request.temperature,
        )
        return response

    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
        yield _generate


@contextlib.contextmanager
def hf_generation():
    from transformers import GPTJForCausalLM, AutoTokenizer
    import torch

    if settings.hf_cuda:
        model = GPTJForCausalLM.from_pretrained(
            settings.hf_model, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
        )
        model.cuda()
    else:
        model = GPTJForCausalLM.from_pretrained(settings.hf_model, torch_dtype=torch.float32)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    def _generate(request: CompleteRequest):
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
        max_prompt_length = 2048 - request.length
        input_ids = input_ids[:, -max_prompt_length:]
        if request.pre_prompt:
            # prepend a (truncated) pre-prompt, then re-tokenize the combined text
            pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
            pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
            input_ids = input_ids[:, -(max_prompt_length - len(pp_input_ids)):]
            full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_ids = input_ids[:, -max_prompt_length:]
        if settings.hf_cuda:
            input_ids = input_ids.cuda()
        with torch.no_grad():
            gen_tokens = model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                typical_p=request.typical_p,
                max_new_tokens=request.length,
            ).detach().cpu()
        gen_text = tokenizer.batch_decode(gen_tokens)[0]
        prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
        if not gen_text.startswith(prompt_decoded):
            raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
        gen_text = gen_text[len(prompt_decoded):]
        return gen_text

    yield _generate


def worker():
    # single background thread: pulls requests off the queue and answers them one at a time
    if settings.hf_model:
        generation = hf_generation
    else:
        generation = jax_generation

    with generation() as generate_fn:
        with open(settings.log_file, "a") as logf:
            while True:
                response_queue = None
                try:
                    start_time = time.time()
                    (request, response_queue) = request_queue.get()
                    logger.info(f"getting request took {time.time() - start_time}")
                    start_time = time.time()
                    response = generate_fn(request)
                    logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
                    start_time = time.time()
                    logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
                    logf.write(f"{request.pre_prompt}\n")
                    logf.write("###\n")
                    logf.write(f"{request.prompt}\n")
                    logf.write("#####\n")
                    logf.write(f"{response}\n\n")
                    logf.flush()
                    logger.info(f"writing log took {time.time() - start_time}")
                    start_time = time.time()
                    response_queue.put(response)
                    logger.info(f"putting response took {time.time() - start_time}")
                except KeyboardInterrupt:
                    logger.info("Got KeyboardInterrupt... quitting!")
                    raise
                except Exception:
                    logger.exception("Got exception, will continue")
                    if response_queue is not None:
                        response_queue.put("")


@app.get("/")
async def main():
    return {"response": "Hello, world!"}


class CompleteRequest(pydantic.BaseModel):
    prompt: pydantic.constr(min_length=0, max_length=2**14)
    pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = ''
    api_key: pydantic.constr(min_length=1, max_length=128) = "x" * 9
    length: pydantic.conint(ge=1, le=1024) = 128
    top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
    temperature: pydantic.confloat(ge=0.0) = 1.0
    typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0


def _enqueue(request: CompleteRequest):
    response_queue = queue.Queue()
    request_queue.put((request, response_queue))
    response = response_queue.get()
    return response


@app.on_event("startup")
def startup():
    threading.Thread(
        target=worker,
        daemon=True,
    ).start()
    # warm-up request so the model is loaded before real traffic arrives
    _enqueue(CompleteRequest(prompt="hello"))


@app.post("/complete")
def complete(request: CompleteRequest):
    logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}")
    if request_queue.full():
        logger.warning("Request queue full.")
        raise ValueError("Request queue full.")
    if not _check_api_key(request.api_key):
        logger.warning(f"api key not valid: {request.api_key}, discarding...")
        raise ValueError("Invalid API key")
    response = _enqueue(request)
    return {"response": response}

src/txt_to_tfrecords.py (Executable file, +65 lines)

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
from absl import flags, app
from loguru import logger
from pathlib import Path
import shutil
import functools

import tokenizers
import tensorflow as tf
import tqdm

flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt')
flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)')
flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file')
flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size')

FLAGS = flags.FLAGS


@functools.lru_cache(maxsize=1)
def get_tokenizer():
    return tokenizers.Tokenizer.from_pretrained('gpt2')


def make_record_file(record_ids, out_dir, file_no):
    out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord')
    with tf.io.TFRecordWriter(out_fn) as writer:
        feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))}
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())


def read_in_blocks(f):
    while True:
        block = f.read(FLAGS.read_buffer_size)
        if not block:
            break
        yield block


def main(_):
    out_dir = Path(FLAGS.out_dir)
    if out_dir.exists():
        logger.warning(f'clearing {out_dir}')
        shutil.rmtree(out_dir)
    out_dir.mkdir(exist_ok=True)

    tokenizer = get_tokenizer()

    with open(FLAGS.txt_fn) as in_f:
        current_ids = []
        out_file_no = 0
        for block in tqdm.tqdm(read_in_blocks(in_f)):
            current_ids.extend(tokenizer.encode(block).ids)
            while len(current_ids) >= FLAGS.chunk_size:
                record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:]
                make_record_file(record_ids, out_dir, out_file_no)
                out_file_no += 1
        if current_ids:
            make_record_file(current_ids, out_dir, out_file_no)


if __name__ == "__main__":
    app.run(main)
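
To verify a written shard, it can be read back like this (a sketch, not part of the repo; the feature name 'text' and the int64 list layout are taken from make_record_file above, and the file name assumes the first shard in the default output directory):

    import tensorflow as tf

    ds = tf.data.TFRecordDataset('../tmp/tfrecords/tokens-00000.tfrecord')
    for raw in ds.take(1):
        example = tf.train.Example()
        example.ParseFromString(raw.numpy())
        ids = example.features.feature['text'].int64_list.value
        print(len(ids), list(ids[:10]))  # token count in the shard and the first few ids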