# gpt-4chan-public/src/server/model/inference.py

import random
import time
from typing import Any, Optional

import jax
import numpy as np
import transformers
from jax import numpy as jnp
from jax.experimental import maps
from loguru import logger
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

from .constants import ModelParams, InferConfig


def default(value: Any, fallback: Any) -> Any:
    # Luke prefers a helper that chooses between `value` and `fallback`,
    # so it stays.
    if value is None:
        return fallback
    return value


# Arrange all local devices into a (data-parallel, model-parallel) mesh and
# register it globally, so model.generate() below runs without an explicit
# mesh context.
_cores_per_replica = ModelParams.cores_per_replica
_mesh_shape = (jax.device_count() // _cores_per_replica, _cores_per_replica)
_devices = np.array(jax.devices()).reshape(_mesh_shape)
# Some jax versions require a second argument here:
# maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")), ())
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(_devices, ("dp", "mp")))
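
# Worked example (an assumption based on the GPT-J-6B defaults, not something
# this file pins down): on a TPU v3-8, jax.device_count() == 8 and
# cores_per_replica == 8, so _mesh_shape == (1, 8): a single data-parallel
# replica ("dp") whose weights are sharded across eight cores ("mp").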


class Inference:
    _NP_ONE = np.ones((1,))

    def __init__(
        self,
        path: Optional[str] = None,
        parameters: Optional[ModelParams] = None,
        config: Optional[InferConfig] = None,
    ):
        path = "checkpoint_slim/" if path is None else path
        self.params = ModelParams() if parameters is None else parameters
        self.params.sampler = nucleaus_sample
        self.config = InferConfig() if config is None else config
        self.model = CausalTransformer(self.params.__dict__)
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        # Load the slim (optimizer-free) checkpoint into the sharded model state.
        self.model.state = read_ckpt_lowmem(
            self.model.state, path, self.params.cores_per_replica, load_opt=False
        )

    def generate_tokens(
        self,
        prompt: np.ndarray,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> np.ndarray:
        length = default(length, self.config.token_length)
        top_p = default(top_p, self.config.top_p)
        # When no temperature is given, draw one uniformly from
        # [min_temperature, max_temperature].
        new_temp = random.random() * (
            self.config.max_temperature - self.config.min_temperature
        )
        new_temp += self.config.min_temperature
        temperature = default(temperature, new_temp)
        # prompt = prompt[:, -2048:]
        # prompt = prompt[:, -length:]
        start_time = time.time()
        # Left-pad the prompt with zeros up to the model's full context length.
        source = jnp.array(
            np.pad(
                prompt,
                (
                    (0, 0),
                    (self.params.seq - prompt.shape[1], 0),
                ),
            )
        )
        logger.info(f"creating source took {time.time() - start_time:.3f}s")
        sampler_options = {
            "top_p": self._NP_ONE * top_p,
            "temp": self._NP_ONE * temperature,
        }
        # No explicit mesh context is needed: the global thread_resources.env
        # set at module load already provides it.
        # with jax.experimental.maps.mesh(_devices, ("dp", "mp")):
        start_time = time.time()
        out = self.model.generate(
            source, self._NP_ONE * prompt.shape[1], length, sampler_options
        )
        logger.info(f"generate took {time.time() - start_time:.3f}s")
        # out[1][0] holds the sampled token ids with shape (batch, length, 1);
        # return the ids for the single batch element.
        return out[1][0][0, :, 0]

    def generate(
        self,
        prompt: str,
        length: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> str:
        # Tokenize the prompt, generate a continuation, and decode it back
        # to text.
        inp_tokens = self.tokenizer([prompt], verbose=False, return_tensors="np")
        inp_tokens = inp_tokens["input_ids"][0]
        out_tokens = self.generate_tokens(
            inp_tokens.reshape(1, -1), length, top_p, temperature
        )
        return self.tokenizer.decode(out_tokens)
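

# A minimal usage sketch (hypothetical, not part of the original module).
# It assumes the default "checkpoint_slim/" checkpoint is present and that
# the InferConfig defaults are acceptable. Constructing Inference is
# expensive (it builds the sharded model and loads the checkpoint), so do
# it once and reuse the instance.
if __name__ == "__main__":
    inference = Inference()
    print(inference.generate("What do you think of functional programming?"))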