46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
import typing
|
|
from dataclasses import dataclass
|
|
|
|
import optax
|
|
|
|
BAD_WORDS = [] # Can't be part of config to avoid printing it
|
|
|
|
|
|
@dataclass
|
|
class InferConfig:
|
|
name: str = "Holotard"
|
|
prompt_length: int = 65536
|
|
token_length: int = 16
|
|
|
|
response_probability: float = 0.02
|
|
top_p: float = 1.0
|
|
|
|
min_temperature: float = 0.6
|
|
max_temperature: float = 1.2
|
|
|
|
max_same_replies: int = 2
|
|
same_reply_saved_messages: int = 6
|
|
max_response_retries: int = 3
|
|
|
|
|
|
@dataclass
|
|
class ModelParams:
|
|
layers: int = 28
|
|
d_model: int = 4096
|
|
n_heads: int = 16
|
|
n_vocab: int = 50400
|
|
|
|
norm: str = "layernorm"
|
|
pe: str = "rotary"
|
|
pe_rotary_dims: int = 64
|
|
|
|
seq: int = 2048
|
|
cores_per_replica: int = 8
|
|
per_replica_batch: int = 1
|
|
|
|
# batch size of 2 needs 200gb, 1 needs <16. wtf
|
|
optimizer: optax.chain = optax.chain(optax.adaptive_grad_clip(0.001), optax.centralize(),
|
|
optax.scale_by_adam(0.99, 0.999), optax.additive_weight_decay(1e-3),
|
|
optax.scale(-1e-5), )
|
|
sampler = None
|