-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy path sample.py
76 lines (52 loc) · 1.93 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from dotenv import load_dotenv
# Load variables from a .env file before the remaining imports run.
# NOTE(review): presumably placed here so settings read at import time
# (e.g. by jax or the project package) are already in os.environ — confirm.
load_dotenv()
import click
import humanize
import jax
# NOTE(review): `nn` and `random` are not referenced anywhere in this script.
from jax import nn, random, jit, tree_util, numpy as np
from haiku import PRNGSequence
from progen_transformer import ProGen
from progen_transformer.data import decode_tokens, encode_tokens
from progen_transformer.utils import sample, set_hardware_rng_
from progen_transformer.checkpoint import get_checkpoint_fns
# speedup rng: project helper that, per its name, switches jax to a hardware
# RNG implementation for faster random number generation — see
# progen_transformer.utils for the exact setting it changes.
set_hardware_rng_(jax)
# main functions

@click.command()
@click.option('--seed', default = 42, help = 'Seed for the sampling PRNG sequence.')
@click.option('--checkpoint_path', default = './ckpts', help = 'Directory containing saved checkpoints.')
@click.option('--prime', default = '', help = 'Text to condition the sample on.')
def main(
    seed,
    checkpoint_path,
    prime,
):
    """Sample from the most recent ProGen checkpoint and print the result.

    Loads the latest checkpoint under ``checkpoint_path``, rebuilds the model
    from the checkpoint's stored ``model_config``, prints a short summary
    (parameter count, sequence length, sequences trained on), then samples a
    continuation of ``prime`` and prints it.

    Args:
        seed: Seed for the ``PRNGSequence`` used while sampling.
        checkpoint_path: Directory handed to ``get_checkpoint_fns``.
        prime: Text to condition on; encoded with ``encode_tokens``.

    Raises:
        SystemExit: with a message when no checkpoint is found.
    """
    # load the latest checkpoint; only the "get last" accessor is needed here
    _, get_last_checkpoint, _ = get_checkpoint_fns(checkpoint_path)
    last_checkpoint = get_last_checkpoint()

    if last_checkpoint is None:
        # SystemExit is a builtin; the bare exit() helper is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(f'no checkpoints found at {checkpoint_path}')

    params = last_checkpoint['params']
    num_seqs = max(last_checkpoint['next_seq_index'], 0)

    # rebuild the model from the configuration saved alongside the params
    model_kwargs = last_checkpoint['model_config']
    model = ProGen(**model_kwargs)
    # jit once here and pass this function straight to `sample`; the original
    # wrapped it in a second, redundant jit at the call site
    model_apply = jit(model.apply)
    rng = PRNGSequence(seed)

    seq_len = model_kwargs['seq_len']
    # sum of `.size` over every leaf of the params pytree
    num_params = tree_util.tree_reduce(lambda acc, el: acc + el.size, params, 0)
    # NOTE(review): naturalsize formats as a byte size (e.g. "1.2 MB"), so the
    # parameter count is displayed with byte units.
    num_params_readable = humanize.naturalsize(num_params)

    print(f'params: {num_params_readable}')
    print(f'sequence length: {seq_len}')
    print(f'trained for {num_seqs} sequences')

    # sample with prime
    prime_tokens = encode_tokens(prime)
    # +1 skips the BOS token requested via `add_bos = True`; assumes `sample`
    # returns BOS + prime ahead of the generated tokens — confirm in utils.
    prime_length = len(prime_tokens) + 1
    prime_tensor = np.array(prime_tokens, dtype = np.uint16)

    sampled = sample(rng, model_apply, params, prime_tensor, seq_len, top_k = 25, add_bos = True)
    sampled_str = decode_tokens(sampled[prime_length:])

    print("\n", prime, "\n", "*" * 40, "\n", sampled_str)
# Invoke the click command only when this file is run as a script,
# not when it is imported.
if __name__ == '__main__':
    main()