gpt-4chan-public/src/txt_to_tfrecords.py

#!/usr/bin/env python3
from absl import flags, app
from loguru import logger
from pathlib import Path
import shutil
import functools

import tokenizers
import tensorflow as tf
import tqdm

flags.DEFINE_string('txt_fn', '../tmp/kek.txt', 'input txt')
flags.DEFINE_string('out_dir', '../tmp/tfrecords/', 'output directory (will be cleared)')
flags.DEFINE_integer('chunk_size', 2**24, 'how many tokens go into one tfrecords file')
flags.DEFINE_integer('read_buffer_size', 2**10, 'input file read buffer size')

FLAGS = flags.FLAGS


@functools.lru_cache(maxsize=1)
def get_tokenizer():
    # Load the pretrained GPT-2 tokenizer once; lru_cache makes repeat calls free.
    return tokenizers.Tokenizer.from_pretrained('gpt2')


def make_record_file(record_ids, out_dir, file_no):
    # Write one chunk of token ids as a single tf.train.Example in its own file.
    out_fn = str(out_dir / f'tokens-{file_no:05d}.tfrecord')
    with tf.io.TFRecordWriter(out_fn) as writer:
        feature = {'text': tf.train.Feature(int64_list=tf.train.Int64List(value=record_ids))}
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())


def read_in_blocks(f):
    # Yield the input file in fixed-size blocks so it never has to fit in memory.
    while True:
        block = f.read(FLAGS.read_buffer_size)
        if not block:
            break
        yield block


def main(_):
    out_dir = Path(FLAGS.out_dir)
    if out_dir.exists():
        logger.warning(f'clearing {out_dir}')
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # parents=True: also create missing parent dirs
    tokenizer = get_tokenizer()
    with open(FLAGS.txt_fn) as in_f:
        current_ids = []
        out_file_no = 0
        for block in tqdm.tqdm(read_in_blocks(in_f)):
            # Caveat: blocks are tokenized independently, so a token that spans a
            # block boundary may be encoded differently than in a single pass.
            current_ids.extend(tokenizer.encode(block).ids)
            # Flush full chunks of exactly chunk_size tokens as they accumulate.
            while len(current_ids) >= FLAGS.chunk_size:
                record_ids, current_ids = current_ids[:FLAGS.chunk_size], current_ids[FLAGS.chunk_size:]
                make_record_file(record_ids, out_dir, out_file_no)
                out_file_no += 1
        if current_ids:
            # Write the leftover (shorter) tail as a final record file.
            make_record_file(current_ids, out_dir, out_file_no)


if __name__ == "__main__":
    app.run(main)
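
# Example invocation (the flag names and defaults are defined above; absl
# parses them in this --flag=value form):
#
#   python3 txt_to_tfrecords.py --txt_fn=../tmp/kek.txt --out_dir=../tmp/tfrecords/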
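
# A minimal sketch (not part of this script) of reading the shards back with
# tf.data. Only the feature name 'text' and the 'tokens-*.tfrecord' file
# pattern come from the writer above; the consumer code itself is an assumption:
#
#   import glob
#   files = sorted(glob.glob('../tmp/tfrecords/tokens-*.tfrecord'))
#   ds = tf.data.TFRecordDataset(files)
#   spec = {'text': tf.io.VarLenFeature(tf.int64)}
#   for raw in ds:
#       ids = tf.sparse.to_dense(tf.io.parse_single_example(raw, spec)['text'])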