Add model chat example #1197

Merged: 15 commits, Jan 22, 2025
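This PR adds an interactive chat example, examples/python/model-chat.py, along with a new get_model_type accessor threaded from the C++ Config through the C API and the Python bindings; the example uses it to pick a default chat template for known model families. Judging from the script's required arguments, a typical invocation would be python model-chat.py -m path/to/model_folder -e cpu (the model folder path here is a placeholder).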
127 changes: 127 additions & 0 deletions examples/python/model-chat.py
@@ -0,0 +1,127 @@
import onnxruntime_genai as og
import argparse
import time

def main(args):
if args.verbose: print("Loading model...")
if args.timings:
started_timestamp = 0
first_token_timestamp = 0

config = og.Config(args.model_path)
config.clear_providers()
if args.execution_provider != "cpu":
if args.verbose: print(f"Setting model to {args.execution_provider}")
config.append_provider(args.execution_provider)
model = og.Model(config)

if args.verbose: print("Model loaded")

tokenizer = og.Tokenizer(model)
tokenizer_stream = tokenizer.create_stream()
if args.verbose: print("Tokenizer created")
if args.verbose: print()

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
search_options['batch_size'] = 1

if args.verbose: print(search_options)

model_type = config.get_model_type()
if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
print("Error, chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)
else:
if model_type.startswith("phi"):
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
elif model_type.startswith("llama"):
args.chat_template = '<|start_header_id|>user<|end_header_id|>{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
else:
print("Chat Template is unknown for model type:", model_type, "and it can result in erroneous results, please specify --chat_template flag for better output, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)

    if args.verbose:
        print("Model type is:", model_type)
        print("Chat Template is:", args.chat_template)

    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)
    generator = og.Generator(model, params)
    if args.verbose: print("Generator created")

    # Set system prompt
    system_prompt = args.system_prompt
    system_tokens = tokenizer.encode(system_prompt)
    generator.append_tokens(system_tokens)
    system_prompt_length = len(system_tokens)

    # Keep asking for input prompts in a loop
    while True:
        text = input("Input: ")
        if not text:
            print("Error: input cannot be empty")
            continue

        if args.timings: started_timestamp = time.time()

        # If there is a chat template, use it
        prompt = text
        if args.chat_template:
            prompt = args.chat_template.format(input=text)

        input_tokens = tokenizer.encode(prompt)

        generator.append_tokens(input_tokens)

if args.verbose: print("Running generation loop ...")
if args.timings:
first = True
new_tokens = []

print()
print("Output: ", end='', flush=True)

try:
while not generator.is_done():
generator.generate_next_token()
if args.timings:
if first:
first_token_timestamp = time.time()
first = False

new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end='', flush=True)
if args.timings: new_tokens.append(new_token)
except KeyboardInterrupt:
print(" --control+c pressed, aborting generation--")
print()
print()

if args.timings:
prompt_time = first_token_timestamp - started_timestamp
run_time = time.time() - first_token_timestamp
print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")

# Rewind the generator to the system prompt
if args.rewind:
generator.rewind_to(system_prompt_length)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
    parser.add_argument('-m', '--model_path', type=str, required=True, help='Onnx model folder path (must contain genai_config.json and model.onnx)')
    parser.add_argument('-e', '--execution_provider', type=str, required=True, choices=["cpu", "cuda", "dml"], help="Execution provider to run ONNX model with")
    parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
    parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
    parser.add_argument('-ds', '--do_random_sampling', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
    parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
    parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
    parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
    parser.add_argument('-re', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output. Defaults to false')
    parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
    parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}')
    parser.add_argument('-s', '--system_prompt', type=str, default='You are a helpful assistant.', help='System prompt to use.')
    parser.add_argument('-r', '--rewind', action='store_true', default=False, help='Rewind to the system prompt after each generation. Defaults to false')
    args = parser.parse_args()
    main(args)
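Taken together, the example boils down to a short core loop. Below is a minimal, non-interactive sketch of the same flow using only the onnxruntime_genai calls seen above; the model folder path is a placeholder, and the template is the phi-style default the script selects:

import onnxruntime_genai as og

config = og.Config("models/my-phi-model")  # hypothetical folder containing genai_config.json
config.clear_providers()                   # run with the default CPU provider
model = og.Model(config)
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=256, batch_size=1)
generator = og.Generator(model, params)

# The phi-style chat template the example picks for model types starting with "phi"
prompt = '<|user|>\nWhat is an ONNX model? <|end|>\n<|assistant|>'
generator.append_tokens(tokenizer.encode(prompt))

# Stream tokens to stdout as they are produced
stream = tokenizer.create_stream()
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end='', flush=True)
print()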
4 changes: 4 additions & 0 deletions src/config.cpp
@@ -621,6 +621,10 @@ void ClearProviders(Config& config) {
  config.model.decoder.session_options.provider_options.clear();
}

std::string GetModelType(Config& config) {
  return config.model.type;
}

void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value) {
  std::ostringstream json;
  json << R"({")" << provider_name << R"(":{)";
1 change: 1 addition & 0 deletions src/config.h
@@ -187,6 +187,7 @@ struct Config {
void SetSearchNumber(Config::Search& search, std::string_view name, double value);
void SetSearchBool(Config::Search& search, std::string_view name, bool value);
void ClearProviders(Config& config);
std::string GetModelType(Config& config);
void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value);
bool IsCudaGraphEnabled(Config::SessionOptions& session_options);

6 changes: 6 additions & 0 deletions src/ort_genai.h
@@ -92,6 +92,12 @@ struct OgaConfig : OgaAbstract {
    OgaCheckResult(OgaConfigClearProviders(this));
  }

  std::string GetModelType() {
    // The C API hands back a heap-allocated C string; release it with OgaDestroyString
    // (a std::string& cannot cross the C API boundary)
    const char* model_type;
    OgaCheckResult(OgaConfigGetModelType(this, &model_type));
    std::string name{model_type};
    OgaDestroyString(model_type);
    return name;
  }

  void AppendProvider(const char* provider) {
    OgaCheckResult(OgaConfigAppendProvider(this, provider));
  }
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
@@ -166,6 +166,13 @@ OgaResult* OGA_API_CALL OgaConfigClearProviders(OgaConfig* config) {
  OGA_CATCH
}

OgaResult* OGA_API_CALL OgaConfigGetModelType(OgaConfig* config, const char** out) {
  OGA_TRY
  auto model_type = Generators::GetModelType(*reinterpret_cast<Generators::Config*>(config));
  // Copy into a caller-owned C string, released with OgaDestroyString,
  // mirroring the other string-returning APIs in this file
  auto buffer = std::make_unique<char[]>(model_type.length() + 1);
  memcpy(buffer.get(), model_type.c_str(), model_type.length() + 1);
  *out = buffer.release();
  return nullptr;
  OGA_CATCH
}

OgaResult* OGA_API_CALL OgaConfigAppendProvider(OgaConfig* config, const char* provider) {
  OGA_TRY
  Generators::SetProviderOption(*reinterpret_cast<Generators::Config*>(config), provider, {}, {});
7 changes: 7 additions & 0 deletions src/ort_genai_c.h
@@ -200,6 +200,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateConfig(const char* config_path, OgaC
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigClearProviders(OgaConfig* config);

/**
 * \brief Get the model type from the given config.
 * \param[in] config The config to get the model type from.
 * \param[out] out The model type string. Must be released with OgaDestroyString.
 * \return OgaResult containing the error message if getting the model type failed.
 */
OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigGetModelType(OgaConfig* config, const char** out);

/**
 * \brief Append the provider to the list of providers in the given config if it is not already present;
 * if it already exists, this does nothing.
3 changes: 2 additions & 1 deletion src/python/python.cpp
@@ -433,7 +433,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
      .def(pybind11::init([](const std::string& config_path) { return OgaConfig::Create(config_path.c_str()); }))
      .def("append_provider", &OgaConfig::AppendProvider)
      .def("set_provider_option", &OgaConfig::SetProviderOption)
      .def("clear_providers", &OgaConfig::ClearProviders)
      .def("get_model_type", &OgaConfig::GetModelType);

  pybind11::class_<Model, std::shared_ptr<Model>>(m, "Model")
      .def(pybind11::init([](const OgaConfig& config) {
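With the binding above in place, the model type is queryable from Python before the model is constructed. A small sketch (the config path is a placeholder):

import onnxruntime_genai as og

config = og.Config("models/my-phi-model")  # hypothetical folder containing genai_config.json
print(config.get_model_type())             # the "type" field from genai_config.json, e.g. "phi3"

This is what model-chat.py relies on to choose a default chat template when --chat_template is not given.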
1 change: 1 addition & 0 deletions test/python/test_onnxruntime_genai_api.py
@@ -38,6 +38,7 @@
config.set_provider_option("cuda", "infinite_clock", "1")
config.set_provider_option("quantum", "break_universe", "true")
config.append_provider("slide rule")
assert(config.get_model_type(), "gpt2")

@pytest.mark.parametrize(
    "relative_model_path",