Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]add 512 token prefill #7

Open
wants to merge 2 commits into
base: prefill
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 126 additions & 56 deletions src/runner/LLM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct LLMAttrType

// std::string template_prefill_filename_axmodel = "minicpmv/prefill_axmodel/minicpm_p96_l%d.axmodel";
// int prefill_axmodel_num = 40;
int prefill_token_num = 96; // auto calc
int prefill_token_num = 128; // auto calc
int prefill_max_token_num = 512;

std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel";

Expand Down Expand Up @@ -73,7 +74,7 @@ class LLM
std::vector<LLMLayer> llama_layers;
ax_runner_ax650 llama_post;

int prefill_grpid = 1;
// int prefill_grpid = 1;
int decode_grpid = 0;

// ax_runner_ax650 vpm_resampler;
Expand Down Expand Up @@ -235,7 +236,7 @@ class LLM
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
}
}

printf("\n");
{
_attr.max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1;
ALOGI("max_token_len : %d", _attr.max_token_len);
Expand All @@ -250,8 +251,10 @@ class LLM
return false;
}

_attr.prefill_token_num = llama_layers[0].layer.get_input(prefill_grpid, "indices").vShape[1];
_attr.prefill_token_num = llama_layers[0].layer.get_input(1, "indices").vShape[1];
ALOGI("prefill_token_num : %d", _attr.prefill_token_num);
_attr.prefill_max_token_num = llama_layers[0].layer.get_input(llama_layers[0].layer.get_num_input_groups() - 1, "mask").vShape[2];
ALOGI("prefill_max_token_num : %d", _attr.prefill_max_token_num);
}
if (attr.b_dynamic_load_axmodel_layer)
{
Expand Down Expand Up @@ -293,9 +296,9 @@ class LLM
int Encode(std::vector<unsigned short> &out_embed, std::string prompt = "What is in the image?")
{
std::vector<int> input_ids = tokenizer->Encode(prompt, true);
if (input_ids.size() > _attr.prefill_token_num)
if (input_ids.size() > _attr.prefill_max_token_num)
{
ALOGE("input_ids(%d) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num);
ALOGE("input_ids(%d) > prefill_max_token_num(%d)", input_ids.size(), _attr.prefill_max_token_num);
return -1;
}
out_embed.resize(input_ids.size() * _attr.tokens_embed_size);
Expand All @@ -322,23 +325,50 @@ class LLM
b_stop = false;
std::string final_out;

int input_embed_num = test_embed.size() / _attr.tokens_embed_size;
ALOGI("input token num : %d", input_embed_num);
int prefill_split_num = ceil((double)input_embed_num / _attr.prefill_token_num);
// ALOGI("prefill_split_num : %d", prefill_split_num);

bfloat16 bf16 = -65536.f;
std::vector<unsigned short> mask(_attr.kv_cache_num + 1, bf16.data);
std::vector<unsigned short> mask_p(_attr.prefill_token_num * _attr.prefill_token_num, bf16.data);
std::vector<std::vector<unsigned short>> mask_p(prefill_split_num);
std::vector<unsigned short> embed(_attr.tokens_embed_size, 0);

for (size_t i = 0; i < _attr.prefill_token_num; i++)
// for (size_t i = 0; i < _attr.prefill_token_num; i++)
// {
// for (size_t j = 0; j < i + 1; j++)
// {
// mask_p[i * _attr.prefill_token_num + j] = 0;
// }
// }
for (size_t p = 0; p < prefill_split_num; p++)
{
for (size_t j = 0; j < i + 1; j++)
std::vector<unsigned short> &mask_tmp = mask_p[p];
mask_tmp.resize((p + 1) * _attr.prefill_token_num * _attr.prefill_token_num, bf16.data);

size_t i = 0;
for (size_t t = p * _attr.prefill_token_num; t < (p + 1) * _attr.prefill_token_num; t++)
{
mask_p[i * _attr.prefill_token_num + j] = 0;
if (t < input_embed_num)
{
for (size_t j = 0; j < p * _attr.prefill_token_num + i + 1; j++)
mask_tmp[i * ((p + 1) * _attr.prefill_token_num) + j] = 0;
}
i++;
}
// char path[128];
// sprintf(path, "mask_p_%d.bin", p);
// FILE *fp = fopen(path, "wb");
// fwrite(mask_tmp.data(), sizeof(unsigned short), mask_tmp.size(), fp);
// fclose(fp);
}

std::vector<int> cached_token;
std::vector<int> token_ids;
// std::vector<int> token_ids = tokenizer->Encode(input_str);
// int len_of_input = token_ids.size();
int input_embed_num = test_embed.size() / _attr.tokens_embed_size;

// ALOGI("input_embed_num(%d)", input_embed_num);

mask[_attr.kv_cache_num] = 0;
Expand All @@ -350,70 +380,111 @@ class LLM
timer ttft_timer;
ttft_timer.start();

for (unsigned int m = 0; m < _attr.axmodel_num; m++)
for (size_t p = 0; p < prefill_split_num; p++)
{
if (b_stop)
{
break;
}

auto &layer = llama_layers[m];
auto &layer_llama = llama_layers[m];
std::vector<unsigned short> &mask_tmp = mask_p[p];
std::vector<unsigned short> embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0);
if (p == (prefill_split_num - 1))
{
memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short));
}
else
{
memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short));
}
int prefill_grpid = p + 1;

if (_attr.b_dynamic_load_axmodel_layer)
for (unsigned int m = 0; m < _attr.axmodel_num; m++)
{
int ret;
if (_attr.b_use_mmap_load_layer)
if (b_stop)
{
ret = layer.layer.init((char *)layer.layer_buffer.data(), layer.layer_buffer.size());
break;
}
else

auto &layer = llama_layers[m];
auto &layer_llama = llama_layers[m];

if (_attr.b_dynamic_load_axmodel_layer)
{
ret = layer.layer.init(layer.layer_buffer_vec.data(), layer.layer_buffer_vec.size());
int ret;
if (_attr.b_use_mmap_load_layer)
{
ret = layer.layer.init((char *)layer.layer_buffer.data(), layer.layer_buffer.size());
}
else
{
ret = layer.layer.init(layer.layer_buffer_vec.data(), layer.layer_buffer_vec.size());
}
if (ret != 0)
{
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
}
}
if (ret != 0)

auto &input_indices = layer.layer.get_input(prefill_grpid, "indices");
unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr;

for (unsigned int i = 0; i < _attr.prefill_token_num; i++)
{
ALOGE("init axmodel(%s) failed", layer.filename.c_str());
input_indices_ptr[i] = p * _attr.prefill_token_num + i;
}
}

auto &input_indices = layer.layer.get_input(prefill_grpid, "indices");
unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr;
for (unsigned int i = 0; i < input_embed_num; i++)
{
input_indices_ptr[i] = i;
}
if (p > 0)
{
auto &input_prefill_k_cache = layer.layer.get_input(prefill_grpid, "K_cache");
auto &input_prefill_v_cache = layer.layer.get_input(prefill_grpid, "V_cache");
for (size_t i = 0; i < p; i++)
{
auto &output_k_cache = layer.layer.get_output(i + 1, "K_cache_out");
memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size,
output_k_cache.pVirAddr,
sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);

auto &output_v_cache = layer.layer.get_output(i + 1, "V_cache_out");
memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + i * _attr.prefill_token_num * _attr.kv_cache_size,
output_v_cache.pVirAddr,
sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
}
}

auto &input_mask = layer.layer.get_input(prefill_grpid, "mask");
memcpy(input_mask.pVirAddr, mask_p.data(), mask_p.size() * sizeof(unsigned short));
auto &input_mask = layer.layer.get_input(prefill_grpid, "mask");
memcpy(input_mask.pVirAddr, mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short));

auto &input_input = layer.layer.get_input(prefill_grpid, "input");
memcpy(input_input.pVirAddr, test_embed.data(), test_embed.size() * sizeof(unsigned short));
if (m == 0)
{
test_embed.resize(_attr.prefill_token_num * _attr.tokens_embed_size);
}
auto &input_input = layer.layer.get_input(prefill_grpid, "input");
memcpy(input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short));

layer.layer.inference(prefill_grpid);
layer.layer.inference(prefill_grpid);

auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out");
AX_SYS_MinvalidateCache(output_k_cache.phyAddr, output_k_cache.pVirAddr, output_k_cache.nSize);
auto &input_k_cache = layer_llama.layer.get_input(decode_grpid, "K_cache");
memcpy(input_k_cache.pVirAddr, output_k_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out");
AX_SYS_MinvalidateCache(output_k_cache.phyAddr, output_k_cache.pVirAddr, output_k_cache.nSize);
auto &input_k_cache = layer_llama.layer.get_input(decode_grpid, "K_cache");
memcpy((unsigned short *)input_k_cache.pVirAddr + p * _attr.prefill_token_num * _attr.kv_cache_size, output_k_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);

auto &output_v_cache = layer.layer.get_output(prefill_grpid, "V_cache_out");
AX_SYS_MinvalidateCache(output_v_cache.phyAddr, output_v_cache.pVirAddr, output_v_cache.nSize);
auto &input_v_cache = layer_llama.layer.get_input(decode_grpid, "V_cache");
memcpy(input_v_cache.pVirAddr, output_v_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);
auto &output_v_cache = layer.layer.get_output(prefill_grpid, "V_cache_out");
AX_SYS_MinvalidateCache(output_v_cache.phyAddr, output_v_cache.pVirAddr, output_v_cache.nSize);
auto &input_v_cache = layer_llama.layer.get_input(decode_grpid, "V_cache");
memcpy((unsigned short *)input_v_cache.pVirAddr + p * _attr.prefill_token_num * _attr.kv_cache_size, output_v_cache.pVirAddr, sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size);

auto &output = layer.layer.get_output(prefill_grpid, "output");
AX_SYS_MinvalidateCache(output.phyAddr, output.pVirAddr, output.nSize);
memcpy(test_embed.data(), output.pVirAddr, test_embed.size() * sizeof(unsigned short));
if (_attr.b_dynamic_load_axmodel_layer)
auto &output = layer.layer.get_output(prefill_grpid, "output");
AX_SYS_MinvalidateCache(output.phyAddr, output.pVirAddr, output.nSize);
memcpy(embed_tmp.data(), output.pVirAddr, embed_tmp.size() * sizeof(unsigned short));
if (_attr.b_dynamic_load_axmodel_layer)
{
layer.layer.deinit();
}
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
}
if (p == (prefill_split_num - 1))
{
layer.layer.deinit();
memcpy(embed.data(),
embed_tmp.data() + (input_embed_num - p * _attr.prefill_token_num - 1) * _attr.tokens_embed_size,
_attr.tokens_embed_size * sizeof(unsigned short));
}
// ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
}

// ALOGI("prefill time cost: %.2f s", t_cost.cost() / 1000);
Expand All @@ -428,11 +499,10 @@ class LLM

int next_token = -1;
t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32);
std::vector<unsigned short> embed(_attr.tokens_embed_size, 0);

memcpy(embed.data(),
test_embed.data() + (input_embed_num - 1) * _attr.tokens_embed_size,
_attr.tokens_embed_size * sizeof(unsigned short));
// memcpy(embed.data(),
// test_embed.data() + (input_embed_num - 1) * _attr.tokens_embed_size,
// _attr.tokens_embed_size * sizeof(unsigned short));

{

Expand Down
3 changes: 3 additions & 0 deletions src/runner/ax_model_runner/ax_model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class ax_runner_base
int get_num_inputs() { return minput_tensors.size(); };
int get_num_outputs() { return moutput_tensors.size(); };

int get_num_input_groups() { return mgroup_input_tensors.size(); };
int get_num_output_groups() { return mgroup_output_tensors.size(); };

const ax_runner_tensor_t &get_input(int idx) { return minput_tensors[idx]; }
const ax_runner_tensor_t *get_inputs_ptr() { return minput_tensors.data(); }
const ax_runner_tensor_t &get_input(std::string name)
Expand Down