diff --git a/README.md b/README.md
index 0564923..3158f80 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,9 @@
 # gpt-4chan-public
 Code for GPT-4chan
+
+Note: This repository only contains helper code and the small changes I made to other libraries.
+The source code for the actual model is at [https://github.com/kingoflolz/mesh-transformer-jax/](https://github.com/kingoflolz/mesh-transformer-jax/).
+
+- Data: [https://zenodo.org/record/3606810](https://zenodo.org/record/3606810)
+- Model: [https://huggingface.co/ykilcher/gpt-4chan](https://huggingface.co/ykilcher/gpt-4chan)
+- Website: [https://gpt-4chan.com](https://gpt-4chan.com)
diff --git a/src/compute_metrics.py b/src/compute_metrics.py
new file mode 100755
index 0000000..bb9dc12
--- /dev/null
+++ b/src/compute_metrics.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+import json
+from pathlib import Path
+from loguru import logger
+from tabulate import tabulate
+import lm_eval.tasks
+
+m1 = 'GPT-J-6B'
+m2 = 'GPT-4chan'
+
+log_dir = Path('./eval_logs')
+all_tasks = set()
+model_data = {}
+for fn in log_dir.rglob('log_*.stdout.txt'):
+    try:
+        file_text = fn.read_text()
+        data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
+        model = data['config']['model_args'].split('=')[1]
+        model = m2 if 'fp16' in model else m1
+        if model not in model_data:
+            model_data[model] = {}
+        results = data['results']
+        tasks = list(results.keys())
+        assert len(tasks) == 1, 'Only one task supported'
+        task = tasks[0]
+        if task in model_data[model]:
+            raise ValueError(f'Duplicate task {task}')
+        task_version = data['versions'][task]
+        results = results[task]
+        results_data = {}
+        for result_key in results:
+            if result_key.endswith('_stderr'):
+                continue
+            result_value = results[result_key]
+            results_data[result_key] = {'value': result_value}
+            stderr_key = f'{result_key}_stderr'
+            if stderr_key in results:
+                results_data[result_key]['stderr'] = results[stderr_key]
+            else:
+                logger.warning(f'No stderr for {result_key} in {results}')
+        model_data[model][task] = {'version': task_version, 'results': results_data}
+        all_tasks.add(task)
+    except Exception:
+        logger.exception(f'Failed to parse {fn}')
+        continue
+
+all_models = list(sorted(model_data.keys()))
+table_data = []
+for task in all_tasks:
+    try:
+        higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
+    except Exception:
+        logger.warning(f'Failed to get higher_is_better for {task}')
+        continue
+    if any(task not in model_data[model] for model in all_models):
+        logger.warning(f'No results for {task}')
+        continue
+    results = model_data[m1][task]['results']
+    results2 = model_data[m2][task]['results']
+    for metric in results:
+        result_value = results[metric]['value']
+        stderr_value = results[metric].get('stderr', 0.0)
+        result2_value = results2[metric]['value']
+        stderr2_value = results2[metric].get('stderr', 0.0)
+        significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
+        if higher_is_better[metric]:
+            significance *= -1
+        if abs(significance) > 1:
+            significant = '+' if significance > 0 else '-'
+        else:
+            significant = ''
+        table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])
+
+table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
+print(table_str)
+Path('./results.table.txt').write_text(table_str)
diff --git a/src/process_data.py b/src/process_data.py
new file mode 100755
index 0000000..0618004
--- /dev/null
+++ b/src/process_data.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+
+import json
+import bs4
+from loguru import logger
+import multiprocessing as mp
+import tqdm
+from absl import app, flags
+
+import warnings
+warnings.filterwarnings("ignore", category=bs4.MarkupResemblesLocatorWarning, module='bs4')
+
+DATA_FN = '../tmp/pol_062016-112019_labeled.ndjson'
+OUT_FN = '../tmp/kek.txt'
+
+flags.DEFINE_string('data_fn', DATA_FN, 'data file')
+flags.DEFINE_string('out_fn', OUT_FN, 'output file')
+
+FLAGS = flags.FLAGS
+
+# from here: https://gist.github.com/zmwangx/ad0830ba94b1fd98f428
+def text_with_newlines(elem):
+    text = ''
+    for e in elem.descendants:
+        if isinstance(e, str):
+            # text += e.strip()
+            text += e
+        elif e.name == 'br' or e.name == 'p':
+            text += '\n'
+    return text
+
+
+def parse_line(line):
+    data = json.loads(line)
+    posts_text = []
+    for post in data.get('posts', []):
+        try:
+            if 'com' in post:
+                soup = bs4.BeautifulSoup(post['com'], 'lxml')
+                post_text = text_with_newlines(soup).strip()
+            else:
+                post_text = ''
+            post_text = f'--- {post["no"]}\n{post_text}'
+            posts_text.append(post_text)
+        except Exception:
+            logger.exception(f'failed to parse post {post}')
+    return '\n'.join(posts_text)
+
+
+def main(_):
+    with open(FLAGS.out_fn, 'w') as out_f:
+        with open(FLAGS.data_fn) as in_f:
+            with mp.Pool() as pool:
+                for parsed_line in pool.imap(parse_line, tqdm.tqdm(in_f)):
+                    out_f.write(parsed_line + '\n-----\n')
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/src/server/README.md b/src/server/README.md
new file mode 100644
index 0000000..9530c1c
--- /dev/null
+++ b/src/server/README.md
@@ -0,0 +1,11 @@
+Clone https://github.com/kingoflolz/mesh-transformer-jax and put this code inside it.
+
+The `model` package is taken from https://github.com/okbuddyhololive/project-cybertard with slight changes.
+
+(You only need the steps above if you want to run JAX inference; for the Hugging Face backend they are not necessary.)
+
+Then run `uvicorn --host 0.0.0.0 --port 8080 serve_api:app`.
+
+I use Python 3.9.12, install the packages from requirements.txt, then uninstall jax, jaxlib, tensorflow, and tensorflow-cpu.
+
+Finally, install `jax==0.2.12 jaxlib==0.1.67 tensorflow==2.5.0 markupsafe==2.0.1 uvicorn fastapi loguru`.
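For reference, once the server is running, a call to the `/complete` endpoint defined in `serve_api.py` below could look like the following sketch. The host, port, and API key are placeholders; the key has to appear in `valid_api_keys.txt`.

```python
# Sketch of a client call to the /complete endpoint of serve_api.py.
# Host, port, and API key below are placeholders for this example.
import requests

resp = requests.post(
    "http://localhost:8080/complete",
    json={
        "prompt": "What do you think?\n",
        "api_key": "my-secret-key",  # must be listed in valid_api_keys.txt
        "length": 128,               # number of new tokens (1..1024)
        "top_p": 0.95,
        "temperature": 0.8,
    },
)
print(resp.json()["response"])
```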
diff --git a/src/server/model/__init__.py b/src/server/model/__init__.py
new file mode 100644
index 0000000..c1efd3a
--- /dev/null
+++ b/src/server/model/__init__.py
@@ -0,0 +1,2 @@
+from .inference import Inference
+from .constants import ModelParams, InferConfig
diff --git a/src/server/model/constants.py b/src/server/model/constants.py
new file mode 100644
index 0000000..69e22d9
--- /dev/null
+++ b/src/server/model/constants.py
@@ -0,0 +1,45 @@
+import typing
+from dataclasses import dataclass
+
+import optax
+
+BAD_WORDS = []  # Can't be part of config to avoid printing it
+
+
+@dataclass
+class InferConfig:
+    name: str = "Holotard"
+    prompt_length: int = 65536
+    token_length: int = 16
+
+    response_probability: float = 0.02
+    top_p: float = 1.0
+
+    min_temperature: float = 0.6
+    max_temperature: float = 1.2
+
+    max_same_replies: int = 2
+    same_reply_saved_messages: int = 6
+    max_response_retries: int = 3
+
+
+@dataclass
+class ModelParams:
+    layers: int = 28
+    d_model: int = 4096
+    n_heads: int = 16
+    n_vocab: int = 50400
+
+    norm: str = "layernorm"
+    pe: str = "rotary"
+    pe_rotary_dims: int = 64
+
+    seq: int = 2048
+    cores_per_replica: int = 8
+    per_replica_batch: int = 1
+
+    # batch size of 2 needs 200gb, 1 needs <16. wtf
+    optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
+                                         optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
+                                         optax.scale(-1e-5), )
+    sampler = None
diff --git a/src/server/model/inference.py b/src/server/model/inference.py
new file mode 100644
index 0000000..6323ac6
--- /dev/null
+++ b/src/server/model/inference.py
@@ -0,0 +1,111 @@
+import random
+from typing import Any, Optional
+from loguru import logger
+import time
+
+import jax
+import numpy as np
+import transformers
+from jax import numpy as jnp
+from jax.experimental import maps
+from mesh_transformer.checkpoint import read_ckpt_lowmem
+from mesh_transformer.sampling import nucleaus_sample
+from mesh_transformer.transformer_shard import CausalTransformer
+
+from .constants import ModelParams, InferConfig
+
+
+def default(value: Any, fallback: Any) -> Any:
+    # luke prefers a helper function that chooses between `value` and `fallback`, so keeping it
+    if value is None:
+        return fallback
+
+    return value
+
+
+_cores_per_replica = ModelParams.cores_per_replica
+_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
+_devices = np.array(jax.devices()).reshape(_mesh_shape)
+
+#maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
+maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))
+
+
+class Inference:
+    _NP_ONE = np.ones((1,))
+
+    def __init__(
+        self,
+        path: Optional[str] = None,
+        parameters: Optional[ModelParams] = None,
+        config: Optional[InferConfig] = None,
+    ):
+        path = "checkpoint_slim/" if path is None else path
+
+        self.params = ModelParams() if parameters is None else parameters
+        self.params.sampler = nucleaus_sample
+        self.config = InferConfig() if config is None else config
+
+        self.model = CausalTransformer(self.params.__dict__)
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
+        self.model.state = read_ckpt_lowmem(
+            self.model.state, path, self.params.cores_per_replica, load_opt=False
+        )
+
+    def generate_tokens(
+        self,
+        prompt: np.ndarray,
+        length: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+    ) -> np.ndarray:
+        length = default(length, self.config.token_length)
+        top_p = default(top_p, self.config.top_p)
+        new_temp = random.random() * (self.config.max_temperature - self.config.min_temperature)
+        new_temp += self.config.min_temperature
+        temperature = default(temperature, new_temp)
+        #prompt = prompt[:, -2048:]
+        #prompt = prompt[:, -length:]
+
+        start_time = time.time()
+        source = jnp.array(
+            np.pad(
+                prompt,
+                (
+                    (0, 0),
+                    (self.params.seq - prompt.shape[1], 0),
+                ),
+            )
+        )
+        logger.info(f"creating source took {time.time() - start_time}")
+        sampler_options = {
+            "top_p": self._NP_ONE * top_p,
+            "temp": self._NP_ONE * temperature,
+        }
+
+        start_time = time.time()
+        #with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
+        logger.info(f"creating mesh took {time.time() - start_time}")
+        start_time = time.time()
+        out = self.model.generate(
+            source, self._NP_ONE * prompt.shape[1], length, sampler_options
+        )
+        logger.info(f"generate took {time.time() - start_time}")
+
+        #import IPython; IPython.embed()
+        return out[1][0][0, :, 0]
+
+    def generate(
+        self,
+        prompt: str,
+        length: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+    ) -> str:
+        inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
+        inp_tokens = inp_tokens["input_ids"][0]
+        out_tokens = self.generate_tokens(
+            inp_tokens.reshape(1, -1), length, top_p, temperature
+        )
+
+        return self.tokenizer.decode(out_tokens)
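For orientation, a standalone use of the `Inference` wrapper above might look like the sketch below; it mirrors how `serve_api.py` drives it. The checkpoint path is an assumption, and generation happens inside the same device mesh the module sets up.

```python
# Sketch: drive the Inference wrapper directly, mirroring jax_generation()
# in serve_api.py. The checkpoint path is an assumption.
import jax
from model import inference

model = inference.Inference(path="checkpoint_slim/")
with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
    text = model.generate(prompt="--- 12345\nhello\n", length=32, top_p=0.9, temperature=0.8)
    print(text)
```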
diff --git a/src/server/model/to_slim_weights.py b/src/server/model/to_slim_weights.py
new file mode 100644
index 0000000..028626a
--- /dev/null
+++ b/src/server/model/to_slim_weights.py
@@ -0,0 +1,52 @@
+import argparse
+import json
+import time
+
+import jax
+import numpy as np
+import optax
+
+from mesh_transformer import util
+from mesh_transformer.checkpoint import read_ckpt, write_ckpt, read_ckpt_lowmem
+from mesh_transformer.transformer_shard import CausalTransformer
+from smart_open import open
+
+from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
+from model.constants import ModelParams
+
+
+if __name__ == "__main__":
+    params = ModelParams().__dict__
+    convert_fn = to_bf16
+
+    cores_per_replica = params["cores_per_replica"]
+
+    assert cores_per_replica <= 8
+
+    start = time.time()
+    print(f"jax devices: {jax.device_count()}")
+    print(f"jax runtime initialized in {time.time() - start:.06}s")
+
+    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
+    devices = np.array(jax.devices()).reshape(mesh_shape)
+
+    with jax.experimental.maps.mesh(devices, ("dp", "mp")):
+        network = CausalTransformer(params)
+
+        start = time.time()
+        network.state = read_ckpt(
+            network.state, f"checkpoint/", devices.shape[1], load_opt=False
+        )
+        print(f"network loaded in {time.time() - start:.06}s")
+
+        start = time.time()
+        del network.state["opt_state"]
+
+        network.state["params"] = convert_fn(network.state["params"])
+        print(f"network converted in {time.time() - start:.06}s")
+
+        suffix = "_slim"
+
+        for i in range(cores_per_replica):
+            write_ckpt(network.state, f"checkpoint_slim/", i)
+            print(f"written shard {i}")
diff --git a/src/server/serve_api.py b/src/server/serve_api.py
new file mode 100644
index 0000000..d1007f4
--- /dev/null
+++ b/src/server/serve_api.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+
+from typing import Optional
+import threading
+import queue
+import time
+from loguru import logger
+from pathlib import Path
+import contextlib
+
+import pydantic
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+app = FastAPI()
+
+origins = ["*"]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    )
+
+
+class Settings(pydantic.BaseSettings):
+    queue_size: int = 1024
+    log_file: str = "logs/serve_api.log"
+    api_keys_file: str = 'valid_api_keys.txt'
+    hf_model: str = ''
+    hf_cuda: bool = False
+    pre_prompt_length: int = 512
+
+
+settings = Settings()
+
+def _check_api_key(key):
+    key = key.strip()
+    for line in Path(settings.api_keys_file).open():
+        if not line:
+            continue
+        valid_key = line.split()[0]
+        if key == valid_key:
+            break
+    else:
+        return False
+    return True
+
+request_queue = queue.Queue(maxsize=settings.queue_size)
+
+@contextlib.contextmanager
+def jax_generation():
+    from model import inference
+    import jax
+    model = inference.Inference(path="../model_slim/step_88001/")
+
+    def _generate(request):
+        response = model.generate(
+            prompt=request.prompt,
+            length=request.length,
+            top_p=request.top_p,
+            temperature=request.temperature,
+        )
+        return response
+    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
+        yield _generate
+
+@contextlib.contextmanager
+def hf_generation():
+    from transformers import GPTJForCausalLM, AutoTokenizer
+    import torch
+
+    if settings.hf_cuda:
+        model = GPTJForCausalLM.from_pretrained(
+            settings.hf_model, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
+        )
+        model.cuda()
+    else:
+        model = GPTJForCausalLM.from_pretrained(settings.hf_model, torch_dtype=torch.float32)
+    model.eval()
+
+    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+
+    def _generate(request: CompleteRequest):
+        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
+
+        max_prompt_length = 2048 - request.length
+        input_ids = input_ids[:, -max_prompt_length:]
+
+        if request.pre_prompt:
+            pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
+            pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
+            input_ids = input_ids[:, -(max_prompt_length - pp_input_ids.shape[1]):]
+            full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
+            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
+            input_ids = input_ids[:, -max_prompt_length:]
+
+
+        if settings.hf_cuda:
+            input_ids = input_ids.cuda()
+
+        with torch.no_grad():
+            gen_tokens = model.generate(
+                input_ids,
+                do_sample=True,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                typical_p=request.typical_p,
+                max_new_tokens=request.length,
+            ).detach().cpu()
+        gen_text = tokenizer.batch_decode(gen_tokens)[0]
+        prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
+        if not gen_text.startswith(prompt_decoded):
+            raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
+        gen_text = gen_text[len(prompt_decoded):]
+        return gen_text
+    yield _generate
+
+def worker():
+    if settings.hf_model:
+        generation = hf_generation
+    else:
+        generation = jax_generation
+    with generation() as generate_fn:
+        with open(settings.log_file, "a") as logf:
+            while True:
+                response_queue = None
+                try:
+                    start_time = time.time()
+                    (request, response_queue) = request_queue.get()
+                    logger.info(f"getting request took {time.time() - start_time}")
+                    start_time = time.time()
+                    response = generate_fn(request)
+                    logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
+                    start_time = time.time()
+
+                    logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
+                    logf.write(f"{request.pre_prompt}\n")
+                    logf.write("###\n")
+                    logf.write(f"{request.prompt}\n")
+                    logf.write("#####\n")
+                    logf.write(f"{response}\n\n")
+                    logf.flush()
+
+                    logger.info(f"writing log took {time.time() - start_time}")
+                    start_time = time.time()
+                    response_queue.put(response)
+                    logger.info(f"putting response took {time.time() - start_time}")
+                except KeyboardInterrupt:
+                    logger.info(f"Got KeyboardInterrupt... quitting!")
quitting!") + raise + except Exception: + logger.exception(f"Got exception, will continue") + if response_queue is not None: + response_queue.put("") + + + +@app.get("/") +async def main(): + return {"response": "Hello, world!"} + +class CompleteRequest(pydantic.BaseModel): + prompt: pydantic.constr(min_length=0, max_length=2**14) + pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = '' + api_key: pydantic.constr(min_length=1, max_length=128) = "x"*9 + length: pydantic.conint(ge=1, le=1024) = 128 + top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0 + temperature: pydantic.confloat(ge=0.0) = 1.0 + typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0 + +def _enqueue(request: CompleteRequest): + response_queue = queue.Queue() + request_queue.put((request, response_queue)) + response = response_queue.get() + return response + + +@app.on_event("startup") +def startup(): + threading.Thread( + target=worker, + daemon=True, + ).start() + _enqueue(CompleteRequest(prompt="hello")) + +@app.post("/complete") +def complete(request: CompleteRequest): + logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}") + if request_queue.full(): + logger.warning("Request queue full.") + raise ValueError("Request queue full.") + if not _check_api_key(request.api_key): + logger.warning(f"api key not valid: {request.api_key}, discarding...") + raise ValueError("Invalid API key") + response = _enqueue(request) + return {"response": response} diff --git a/src/txt_to_tfrecords.py b/src/txt_to_tfrecords.py new file mode 100755 index 0000000..ef17a77 --- /dev/null +++ b/src/txt_to_tfrecords.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +from absl import flags, app + +from loguru import logger + +from pathlib import Path +import shutil +import functools + +import tokenizers +import tensorflow as tf +import tqdm + +flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt') +flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)') +flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file') +flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size') + +FLAGS = flags.FLAGS + + +@functools.lru_cache(maxsize=1) +def get_tokenizer(): + return tokenizers.Tokenizer.from_pretrained('gpt2') + + +def make_record_file(record_ids, out_dir, file_no): + out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord') + with tf.io.TFRecordWriter(out_fn) as writer: + feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))} + tf_example = tf.train.Example(features=tf.train.Features(feature=feature)) + writer.write(tf_example.SerializeToString()) + + +def read_in_blocks(f): + while True: + block = f.read(FLAGS.read_buffer_size) + if not block: + break + yield block + +def main(_): + out_dir = Path(FLAGS.out_dir) + if out_dir.exists(): + logger.warning(f'clearing {out_dir}') + shutil.rmtree(out_dir) + out_dir.mkdir(exist_ok=True) + tokenizer = get_tokenizer() + with open(FLAGS.txt_fn) as in_f: + current_ids = [] + out_file_no = 0 + for block in tqdm.tqdm(read_in_blocks(in_f)): + current_ids.extend(tokenizer.encode(block).ids) + while len(current_ids) >= FLAGS.chunk_size: + record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:] + make_record_file(record_ids, out_dir, out_file_no) + out_file_no += 1 + + if current_ids: + make_record_file(current_ids, out_dir, out_file_no) + + +if __name__ == "__main__": + app.run(main)