From 2690f65b47a3897dbab46bb2b4f3e9f1e9f051fd Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 19 Nov 2024 19:41:15 -0800 Subject: [PATCH 1/3] Allow OgaShutdown to be called multiple times --- src/generators.cpp | 20 ++++++++++++++++---- test/c_api_tests.cpp | 4 ++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 40148d4ba..8c041bd38 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -61,11 +61,21 @@ struct ValidateShutdown { } }; +static std::unique_ptr g_globals; +static std::mutex g_globals_mutex; +static auto g_validate_shutdown = std::make_unique(); // Must be after the above line so the destructor runs before the above destructor + std::unique_ptr& GetOrtGlobals() { - static auto globals = std::make_unique(); - static auto validate = std::make_unique(); // Must be after the above line so the destructor runs before the above destructor - return globals; + + // Initialize g_globals using the g_globals_mutex + if (!g_globals) { + std::lock_guard lock(g_globals_mutex); + if (!g_globals) // Now that we're in the mutex, double check + g_globals = std::make_unique(); + } + + return g_globals; } // Used by Shutdown() to display the counts and types of any leaked objects @@ -81,7 +91,7 @@ void Shutdown() { std::abort(); } - GetOrtGlobals().reset(); // Delete now because on process exit is too late + g_globals.reset(); } OrtEnv& GetOrtEnv() { @@ -313,6 +323,8 @@ bool Generator::IsSessionTerminated() const { } void Generator::GenerateNextToken() { + if (search_->IsDone()) + throw std::runtime_error("Search is already done, can't generate next token"); ThrowErrorIfSessionTerminated(state_->session_terminated_); if (!computed_logits_) throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 16747bf38..86769a151 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -348,6 +348,10 @@ TEST(CAPITests, SetTerminate) { #endif } +TEST(CAPITests, Shutdown) { + OgaShutdown(); // Shutdown in the middle, and see if the next test works properly +} + #if TEST_PHI2 struct Phi2Test { From d26cf98bde0bf7247c4c9e27506d90d0a61cd46e Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 19 Nov 2024 19:45:15 -0800 Subject: [PATCH 2/3] Lint --- src/generators.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 8c041bd38..16eb088a5 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -67,11 +67,10 @@ static auto g_validate_shutdown = std::make_unique(); // Must std::unique_ptr& GetOrtGlobals() { - // Initialize g_globals using the g_globals_mutex if (!g_globals) { std::lock_guard lock(g_globals_mutex); - if (!g_globals) // Now that we're in the mutex, double check + if (!g_globals) // Now that we're in the mutex, double check g_globals = std::make_unique(); } From 4c583903dcdf6fd108c5dad8bfea61fba6e2e544 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Tue, 19 Nov 2024 20:01:08 -0800 Subject: [PATCH 3/3] Fix ValidateShutdown --- src/generators.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 16eb088a5..fdadc354e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -51,19 +51,18 @@ OrtGlobals::OrtGlobals() env_->CreateAndRegisterAllocator(allocator_cpu.GetInfo(), *arena_config); } +static std::unique_ptr g_globals; +static std::mutex g_globals_mutex; + // Ensure Shutdown() has been called before process exit struct ValidateShutdown { ~ValidateShutdown() { - if (GetOrtGlobals()) { + if (g_globals) { std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; std::abort(); } } -}; - -static std::unique_ptr g_globals; -static std::mutex g_globals_mutex; -static auto g_validate_shutdown = std::make_unique(); // Must be after the above line so the destructor runs before the above destructor +} g_shutdown; // This struct should stay immediately after the g_globals and g_globals_mutex std::unique_ptr& GetOrtGlobals() {