Mirror of https://github.com/yk/gpt-4chan-public.git (synced 2024-12-21 18:30:04 +00:00)

initial commit
parent 0538b0e31e
commit b1089a2d89

README.md
@@ -1,2 +1,9 @@
# gpt-4chan-public

Code for GPT-4chan

Note: This repository only contains helper code and small changes I made to other libraries.
The source code for the actual model is at [https://github.com/kingoflolz/mesh-transformer-jax/](https://github.com/kingoflolz/mesh-transformer-jax/).

Data here: [https://zenodo.org/record/3606810](https://zenodo.org/record/3606810)
Model here: [https://huggingface.co/ykilcher/gpt-4chan](https://huggingface.co/ykilcher/gpt-4chan)
Website here: [https://gpt-4chan.com](https://gpt-4chan.com)

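For reference, a minimal sketch of loading the released checkpoint with Hugging Face `transformers`. This is not part of the repository; it assumes the hub checkpoint linked above is accessible and GPT-J-compatible, reuses the GPT-J tokenizer the way `src/server/serve_api.py` does, and formats the prompt the way the training data from `src/process_data.py` is formatted.

```python
# Hypothetical usage sketch (not from this repo): load the checkpoint via transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")  # GPT-J tokenizer, as in serve_api.py
model = AutoModelForCausalLM.from_pretrained("ykilcher/gpt-4chan", torch_dtype=torch.float32)
model.eval()

# Posts in the training data are prefixed with "--- <post number>" (see src/process_data.py).
prompt = "--- 12345\nwhat do you all think about"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(input_ids, do_sample=True, top_p=0.95,
                         temperature=0.8, max_new_tokens=64)
print(tokenizer.decode(out[0][input_ids.shape[1]:]))
```
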
src/compute_metrics.py (Executable file, 77 lines)
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Collect lm-evaluation-harness logs from ./eval_logs and tabulate GPT-J-6B vs. GPT-4chan results.

import json
from pathlib import Path
from loguru import logger
from tabulate import tabulate
import lm_eval.tasks

m1 = 'GPT-J-6B'
m2 = 'GPT-4chan'

log_dir = Path('./eval_logs')
all_tasks = set()
model_data = {}
for fn in log_dir.rglob('log_*.stdout.txt'):
    try:
        file_text = fn.read_text()
        data = json.loads('{' + file_text.split('{', 1)[1].rsplit('}', 1)[0] + '}')
        model = data['config']['model_args'].split('=')[1]
        # logs for the fp16 checkpoint belong to GPT-4chan; everything else is the GPT-J-6B baseline
        model = m2 if 'fp16' in model else m1
        if model not in model_data:
            model_data[model] = {}
        results = data['results']
        tasks = list(results.keys())
        assert len(tasks) == 1, 'Only one task supported'
        task = tasks[0]
        if task in model_data[model]:
            raise ValueError(f'Duplicate task {task}')
        task_version = data['versions'][task]
        results = results[task]
        results_data = {}
        for result_key in results:
            if result_key.endswith('_stderr'):
                continue
            result_value = results[result_key]
            results_data[result_key] = {'value': result_value}
            stderr_key = f'{result_key}_stderr'
            if stderr_key in results:
                results_data[result_key]['stderr'] = results[stderr_key]
            else:
                logger.warning(f'No stderr for {result_key} in {results}')
        model_data[model][task] = {'version': task_version, 'results': results_data}
        all_tasks.add(task)
    except Exception:
        logger.exception(f'Failed to parse {fn}')
        continue

all_models = list(sorted(model_data.keys()))
table_data = []
for task in all_tasks:
    try:
        higher_is_better = lm_eval.tasks.get_task(task).higher_is_better(None)
    except Exception:
        logger.warning(f'Failed to get higher_is_better for {task}')
        continue
    if any(task not in model_data[model] for model in all_models):
        logger.warning(f'No results for {task}')
        continue
    results = model_data[m1][task]['results']
    results2 = model_data[m2][task]['results']
    for metric in results:
        result_value = results[metric]['value']
        stderr_value = results[metric].get('stderr', 0.0)
        result2_value = results2[metric]['value']
        stderr2_value = results2[metric].get('stderr', 0.0)
        # rough significance score: difference of the two results over the average of their standard errors
        significance = (result_value - result2_value) / ((stderr_value + stderr2_value + 1e-8) / 2)
        # flip the sign so that a positive score always means GPT-4chan (m2) does better
        if higher_is_better[metric]:
            significance *= -1
        if abs(significance) > 1:
            significant = '+' if significance > 0 else '-'
        else:
            significant = ''
        table_data.append([task, metric, result_value, stderr_value, result2_value, stderr2_value, significant])

table_str = tabulate(table_data, headers=['Task', 'Metric', m1, 'stderr', m2, 'stderr', 'Significant'], tablefmt='pipe')
print(table_str)
Path('./results.table.txt').write_text(table_str)

src/process_data.py (Executable file, 59 lines)
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import json
import bs4
from loguru import logger
import multiprocessing as mp
import tqdm
from absl import app, flags

import warnings
warnings.filterwarnings("ignore", category=bs4.MarkupResemblesLocatorWarning, module='bs4')

DATA_FN = '../tmp/pol_062016-112019_labeled.ndjson'
OUT_FN = '../tmp/kek.txt'

flags.DEFINE_string('data_fn', DATA_FN, 'data file')
flags.DEFINE_string('out_fn', OUT_FN, 'output file')

FLAGS = flags.FLAGS

# from here: https://gist.github.com/zmwangx/ad0830ba94b1fd98f428
def text_with_newlines(elem):
    text = ''
    for e in elem.descendants:
        if isinstance(e, str):
            # text += e.strip()
            text += e
        elif e.name == 'br' or e.name == 'p':
            text += '\n'
    return text


def parse_line(line):
    data = json.loads(line)
    posts_text = []
    for post in data.get('posts', []):
        try:
            if 'com' in post:
                soup = bs4.BeautifulSoup(post['com'], 'lxml')
                post_text = text_with_newlines(soup).strip()
            else:
                post_text = ''
            post_text = f'--- {post["no"]}\n{post_text}'
            posts_text.append(post_text)
        except Exception:
            logger.exception(f'failed to parse post {post}')
    return '\n'.join(posts_text)


def main(_):
    with open(FLAGS.out_fn, 'w') as out_f:
        with open(FLAGS.data_fn) as in_f:
            with mp.Pool() as pool:
                for parsed_line in pool.imap(parse_line, tqdm.tqdm(in_f)):
                    out_f.write(parsed_line + '\n-----\n')


if __name__ == '__main__':
    app.run(main)

src/server/README.md (Normal file, 11 lines)
@@ -0,0 +1,11 @@
Clone https://github.com/kingoflolz/mesh-transformer-jax and put this code inside it.

The `model` package is from https://github.com/okbuddyhololive/project-cybertard, with slight changes.

(You only need the steps above if you want to run JAX inference; for Hugging Face inference they are not necessary.)

Then run `uvicorn --host 0.0.0.0 --port 8080 serve_api:app`.

I use Python 3.9.12 and install requirements.txt, then uninstall jax, jaxlib, tensorflow, and tensorflow-cpu,
and install `jax==0.2.12 jaxlib==0.1.67 tensorflow==2.5.0 markupsafe==2.0.1 uvicorn fastapi loguru`.

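Once the server is running, the `/complete` endpoint defined in `serve_api.py` accepts a JSON body matching its `CompleteRequest` model and returns `{"response": ...}`; a request blocks until the single worker thread has produced the completion. A minimal client sketch (host, port, and API key are placeholders; the key must appear in `valid_api_keys.txt`):

```python
# Hypothetical client for the /complete endpoint of serve_api.py.
import requests

payload = {
    "prompt": "what do you all think about",
    "api_key": "my-secret-key",  # placeholder; must be listed in valid_api_keys.txt
    "length": 64,                # new tokens to generate (1..1024)
    "top_p": 0.95,
    "temperature": 0.8,
}
resp = requests.post("http://localhost:8080/complete", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])
```
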
src/server/model/__init__.py (Normal file, 2 lines)
@@ -0,0 +1,2 @@
from .inference import Inference
from .constants import ModelParams, InferConfig

src/server/model/constants.py (Normal file, 45 lines)
@@ -0,0 +1,45 @@
import typing
from dataclasses import dataclass

import optax

BAD_WORDS = []  # Can't be part of config to avoid printing it


@dataclass
class InferConfig:
    name: str = "Holotard"
    prompt_length: int = 65536
    token_length: int = 16

    response_probability: float = 0.02
    top_p: float = 1.0

    min_temperature: float = 0.6
    max_temperature: float = 1.2

    max_same_replies: int = 2
    same_reply_saved_messages: int = 6
    max_response_retries: int = 3


@dataclass
class ModelParams:
    layers: int = 28
    d_model: int = 4096
    n_heads: int = 16
    n_vocab: int = 50400

    norm: str = "layernorm"
    pe: str = "rotary"
    pe_rotary_dims: int = 64

    seq: int = 2048
    cores_per_replica: int = 8
    per_replica_batch: int = 1

    # batch size of 2 needs 200gb, 1 needs <16. wtf
    optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
                                         optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
                                         optax.scale(-1e-5), )
    sampler = None

src/server/model/inference.py (Normal file, 111 lines)
@@ -0,0 +1,111 @@
import random
from typing import Any, Optional
from loguru import logger
import time

import jax
import numpy as np
import transformers
from jax import numpy as jnp
from jax.experimental import maps
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from .constants import ModelParams, InferConfig


def default(value: Any, fallback: Any) -> Any:
    # luke prefers making a function that chooses between `value` and `fallback` so i am gonna keep it
    if value is None:
        return fallback

    return value


_cores_per_replica = ModelParams.cores_per_replica
_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
_devices = np.array(jax.devices()).reshape(_mesh_shape)

#maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))


class Inference:
    _NP_ONE = np.ones((1,))

    def __init__(
        self,
        path: Optional[str] = None,
        parameters: Optional[ModelParams] = None,
        config: Optional[InferConfig] = None,
    ):
        path = "checkpoint_slim/" if path is None else path

        self.params = ModelParams() if parameters is None else parameters
        self.params.sampler = nucleaus_sample
        self.config = InferConfig() if config is None else config

        self.model = CausalTransformer(self.params.__dict__)
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.model.state = read_ckpt_lowmem(
            self.model.state, path, self.params.cores_per_replica, load_opt=False
        )

    def generate_tokens(
        self,
        prompt: np.ndarray,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> np.ndarray:
        length = default(length, self.config.token_length)
        top_p = default(top_p, self.config.top_p)
        new_temp = random.random() * (self.config.max_temperature - self.config.min_temperature)
        new_temp += self.config.min_temperature
        temperature = default(temperature, new_temp)
        #prompt = prompt[:, -2048:]
        #prompt = prompt[:, -length:]

        start_time = time.time()
        # left-pad the prompt with zeros up to the model's full sequence length
        source = jnp.array(
            np.pad(
                prompt,
                (
                    (0, 0),
                    (self.params.seq - prompt.shape[1], 0),
                ),
            )
        )
        logger.info(f"creating source took {time.time() - start_time}")
        sampler_options = {
            "top_p": self._NP_ONE * top_p,
            "temp": self._NP_ONE * temperature,
        }

        start_time = time.time()
        #with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
        logger.info(f"creating mesh took {time.time() - start_time}")
        start_time = time.time()
        out = self.model.generate(
            source, self._NP_ONE * prompt.shape[1], length, sampler_options
        )
        logger.info(f"generate took {time.time() - start_time}")

        #import IPython; IPython.embed()
        return out[1][0][0, :, 0]

    def generate(
        self,
        prompt: str,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> str:
        inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
        inp_tokens = inp_tokens["input_ids"][0]
        out_tokens = self.generate_tokens(
            inp_tokens.reshape(1, -1), length, top_p, temperature
        )

        return self.tokenizer.decode(out_tokens)

src/server/model/to_slim_weights.py (Normal file, 52 lines)
@@ -0,0 +1,52 @@
import argparse
import json
import time

import jax
import numpy as np
import optax

from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt, read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer
from smart_open import open

from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
from model.constants import ModelParams


if __name__ == "__main__":
    params = ModelParams().__dict__
    convert_fn = to_bf16

    cores_per_replica = params["cores_per_replica"]

    assert cores_per_replica <= 8

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with jax.experimental.maps.mesh(devices, ("dp", "mp")):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(
            network.state, f"checkpoint/", devices.shape[1], load_opt=False
        )
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        del network.state["opt_state"]

        network.state["params"] = convert_fn(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        suffix = "_slim"

        for i in range(cores_per_replica):
            write_ckpt(network.state, f"checkpoint_slim/", i)
            print(f"written shard {i}")

src/server/serve_api.py (Normal file, 199 lines)
@@ -0,0 +1,199 @@
#!/usr/bin/env python3

from typing import Optional
import threading
import queue
import time
from loguru import logger
from pathlib import Path
import contextlib

import pydantic
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Settings(pydantic.BaseSettings):
    queue_size: int = 1024
    log_file: str = "logs/serve_api.log"
    api_keys_file: str = 'valid_api_keys.txt'
    hf_model: str = ''
    hf_cuda: bool = False
    pre_prompt_length: int = 512


settings = Settings()


def _check_api_key(key):
    key = key.strip()
    for line in Path(settings.api_keys_file).open():
        if not line:
            continue
        valid_key = line.split()[0]
        if key == valid_key:
            break
    else:
        return False
    return True


request_queue = queue.Queue(maxsize=settings.queue_size)


@contextlib.contextmanager
def jax_generation():
    from model import inference
    import jax
    model = inference.Inference(path="../model_slim/step_88001/")

    def _generate(request):
        response = model.generate(
            prompt=request.prompt,
            length=request.length,
            top_p=request.top_p,
            temperature=request.temperature,
        )
        return response

    with jax.experimental.maps.mesh(inference._devices, ("dp", "mp")):
        yield _generate


@contextlib.contextmanager
def hf_generation():
    from transformers import GPTJForCausalLM, AutoTokenizer
    import torch

    if settings.hf_cuda:
        model = GPTJForCausalLM.from_pretrained(
            settings.hf_model, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
        )
        model.cuda()
    else:
        model = GPTJForCausalLM.from_pretrained(settings.hf_model, torch_dtype=torch.float32)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    def _generate(request: CompleteRequest):
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids

        max_prompt_length = 2048 - request.length
        input_ids = input_ids[:, -max_prompt_length:]

        if request.pre_prompt:
            pp_input_ids = tokenizer(request.pre_prompt, return_tensors="pt").input_ids
            pp_input_ids = pp_input_ids[:, :settings.pre_prompt_length]
            input_ids = input_ids[:, -(max_prompt_length - len(pp_input_ids)):]
            full_prompt = tokenizer.batch_decode(pp_input_ids)[0] + tokenizer.batch_decode(input_ids)[0]
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_ids = input_ids[:, -max_prompt_length:]

        if settings.hf_cuda:
            input_ids = input_ids.cuda()

        with torch.no_grad():
            gen_tokens = model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                typical_p=request.typical_p,
                max_new_tokens=request.length,
            ).detach().cpu()
        gen_text = tokenizer.batch_decode(gen_tokens)[0]
        prompt_decoded = tokenizer.batch_decode(input_ids.detach().cpu())[0]
        if not gen_text.startswith(prompt_decoded):
            raise Exception(f"Generated text does not start with prompt: {gen_text}\n(prompt was {prompt_decoded})")
        gen_text = gen_text[len(prompt_decoded):]
        return gen_text

    yield _generate


def worker():
    if settings.hf_model:
        generation = hf_generation
    else:
        generation = jax_generation
    with generation() as generate_fn:
        with open(settings.log_file, "a") as logf:
            while True:
                response_queue = None
                try:
                    start_time = time.time()
                    (request, response_queue) = request_queue.get()
                    logger.info(f"getting request took {time.time() - start_time}")
                    start_time = time.time()
                    response = generate_fn(request)
                    logger.info(f"generate took {time.time() - start_time}, response length: {len(response)}")
                    start_time = time.time()

                    logf.write(f"##### {request.api_key} ##### {time.time()} #####\n")
                    logf.write(f"{request.pre_prompt}\n")
                    logf.write("###\n")
                    logf.write(f"{request.prompt}\n")
                    logf.write("#####\n")
                    logf.write(f"{response}\n\n")
                    logf.flush()

                    logger.info(f"writing log took {time.time() - start_time}")
                    start_time = time.time()
                    response_queue.put(response)
                    logger.info(f"putting response took {time.time() - start_time}")
                except KeyboardInterrupt:
                    logger.info(f"Got KeyboardInterrupt... quitting!")
                    raise
                except Exception:
                    logger.exception(f"Got exception, will continue")
                    if response_queue is not None:
                        response_queue.put("")


@app.get("/")
async def main():
    return {"response": "Hello, world!"}


class CompleteRequest(pydantic.BaseModel):
    prompt: pydantic.constr(min_length=0, max_length=2**14)
    pre_prompt: pydantic.constr(min_length=0, max_length=2**14) = ''
    api_key: pydantic.constr(min_length=1, max_length=128) = "x"*9
    length: pydantic.conint(ge=1, le=1024) = 128
    top_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0
    temperature: pydantic.confloat(ge=0.0) = 1.0
    typical_p: pydantic.confloat(ge=0.0, le=1.0) = 1.0


def _enqueue(request: CompleteRequest):
    response_queue = queue.Queue()
    request_queue.put((request, response_queue))
    response = response_queue.get()
    return response


@app.on_event("startup")
def startup():
    threading.Thread(
        target=worker,
        daemon=True,
    ).start()
    _enqueue(CompleteRequest(prompt="hello"))


@app.post("/complete")
def complete(request: CompleteRequest):
    logger.info(f"Received request from key {request.api_key}. Queue size is {request_queue.qsize()}")
    if request_queue.full():
        logger.warning("Request queue full.")
        raise ValueError("Request queue full.")
    if not _check_api_key(request.api_key):
        logger.warning(f"api key not valid: {request.api_key}, discarding...")
        raise ValueError("Invalid API key")
    response = _enqueue(request)
    return {"response": response}

src/txt_to_tfrecords.py (Executable file, 65 lines)
@@ -0,0 +1,65 @@
#!/usr/bin/env python3

from absl import flags, app

from loguru import logger

from pathlib import Path
import shutil
import functools

import tokenizers
import tensorflow as tf
import tqdm

flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt')
flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)')
flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file')
flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size')

FLAGS = flags.FLAGS


@functools.lru_cache(maxsize=1)
def get_tokenizer():
    return tokenizers.Tokenizer.from_pretrained('gpt2')


def make_record_file(record_ids, out_dir, file_no):
    out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord')
    with tf.io.TFRecordWriter(out_fn) as writer:
        feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))}
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())


def read_in_blocks(f):
    while True:
        block = f.read(FLAGS.read_buffer_size)
        if not block:
            break
        yield block


def main(_):
    out_dir = Path(FLAGS.out_dir)
    if out_dir.exists():
        logger.warning(f'clearing {out_dir}')
        shutil.rmtree(out_dir)
    out_dir.mkdir(exist_ok=True)
    tokenizer = get_tokenizer()
    with open(FLAGS.txt_fn) as in_f:
        current_ids = []
        out_file_no = 0
        for block in tqdm.tqdm(read_in_blocks(in_f)):
            current_ids.extend(tokenizer.encode(block).ids)
            while len(current_ids) >= FLAGS.chunk_size:
                record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:]
                make_record_file(record_ids, out_dir, out_file_no)
                out_file_no += 1

    if current_ids:
        make_record_file(current_ids, out_dir, out_file_no)


if __name__ == "__main__":
    app.run(main)
