Skip to content

Commit

Permalink
Add retry logic for embedding
Browse files (browse the repository at this point in the history)
  • Loading branch information
richard-epsilla committed Dec 14, 2023
1 parent c7811f6 commit d77e465
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions engine/services/embedding_service.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,10 @@
#include <iostream>
#include <curl/curl.h>

#include <chrono>
#include <thread>
#include <random>

namespace vectordb {
namespace engine {

Expand Down Expand Up @@ -45,31 +49,43 @@ Status EmbeddingService::denseEmbedDocuments(
size_t end_record,
size_t dimension
) {
try {
auto requestBody = EmbeddingRequestBody::createShared();
requestBody->model = model_name;
// Constructing documents list from attr_column_container
requestBody->documents = oatpp::List<oatpp::String>({});
for (size_t idx = start_record; idx < end_record; ++idx) {
// Assuming attr_column_container[idx] returns a string or can be converted to string
requestBody->documents->push_back(oatpp::String(std::get<std::string>(attr_column_container[idx]).c_str()));
}
auto response = m_client->denseEmbedDocuments("/v1/embeddings", requestBody);
auto responseBody = response->readBodyToString();
vectordb::Json json;
json.LoadFromString(responseBody->c_str());
if (json.GetInt("statusCode") == 200) {
auto embeddings = json.GetArray("result");
for (auto idx = start_record; idx < end_record; ++idx) {
auto embedding = embeddings.GetArrayElement(idx - start_record);
for (auto i = 0; i < dimension; i++) {
vector_table[idx * dimension + i] = static_cast<float>((float)(embedding.GetArrayElement(i).GetDouble()));
int attempt = 0;
while (attempt < EmbeddingRetry) {
try {
auto requestBody = EmbeddingRequestBody::createShared();
requestBody->model = model_name;
// Constructing documents list from attr_column_container
requestBody->documents = oatpp::List<oatpp::String>({});
for (size_t idx = start_record; idx < end_record; ++idx) {
// Assuming attr_column_container[idx] returns a string or can be converted to string
requestBody->documents->push_back(oatpp::String(std::get<std::string>(attr_column_container[idx]).c_str()));
}
auto response = m_client->denseEmbedDocuments("/v1/embeddings", requestBody);
auto responseBody = response->readBodyToString();
// std::cout << "Embedding response: " << responseBody->c_str() << std::endl;
vectordb::Json json;
json.LoadFromString(responseBody->c_str());
if (json.GetInt("statusCode") == 200) {
auto embeddings = json.GetArray("result");
for (auto idx = start_record; idx < end_record; ++idx) {
auto embedding = embeddings.GetArrayElement(idx - start_record);
for (auto i = 0; i < dimension; i++) {
vector_table[idx * dimension + i] = static_cast<float>((float)(embedding.GetArrayElement(i).GetDouble()));
}
}
return Status::OK();
}
return Status::OK();
} catch (const std::exception& e) {
std::cerr << "Exception in embedDocuments: " << e.what() << std::endl;
}
} catch (const std::exception& e) {
std::cerr << "Exception in embedDocuments: " << e.what() << std::endl;
attempt++;
if (attempt >= EmbeddingRetry) {
break;
}
// Exponential backoff logic
int delaySec = EmbeddingBackoffInitialDelaySec * std::pow(EmbeddingBackoffExpBase, attempt);
std::this_thread::sleep_for(std::chrono::seconds(delaySec));
std::cout << "Retry embedding documents." << std::endl;
}
return Status(INFRA_UNEXPECTED_ERROR, "Failed to embbed the documents.");
}
Expand Down

0 comments on commit d77e465

Please sign in to comment.