diff --git a/src/db.rs b/src/db.rs
index 423ece3..a819f4d 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -34,12 +34,15 @@ pub enum Error {
 
 #[derive(Debug, serde::Serialize, serde::Deserialize)]
 pub struct Db {
+    /// Collections in the database
     pub collections: HashMap<String, Collection>,
 }
 
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
 pub struct SimilarityResult {
+    /// Similarity score
     score: f32,
+    /// Matching embedding
     embedding: Embedding,
 }
 
@@ -55,7 +58,33 @@ pub struct Collection {
 }
 
 impl Collection {
-    pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec<SimilarityResult> {
+    /// Returns the identifiers of all embeddings in the collection.
+    pub fn list(&self) -> Vec<String> {
+        self.embeddings.iter().map(|e| e.id.clone()).collect()
+    }
+
+    /// Returns the first embedding whose identifier equals `id`, if any.
+    pub fn get(&self, id: &str) -> Option<&Embedding> {
+        self.embeddings.iter().find(|e| e.id == id)
+    }
+
+    /// Returns clones of the first `k` embeddings matching `filter`.
+    pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
+        self.embeddings
+            .iter()
+            .filter(|embedding| match_embedding(embedding, filter))
+            .take(k)
+            .cloned()
+            .collect()
+    }
+
+    /// Similarity search restricted to the embeddings matching `filter`.
+    pub fn get_by_metadata_and_similarity(
+        &self,
+        filter: &[HashMap<String, String>],
+        query: &[f32],
+        k: usize,
+    ) -> Vec<SimilarityResult> {
         let memo_attr = get_cache_attr(self.distance, query);
         let distance_fn = get_distance_fn(self.distance);
 
@@ -63,9 +92,13 @@ impl Collection {
             .embeddings
             .par_iter()
             .enumerate()
-            .map(|(index, embedding)| {
-                let score = distance_fn(&embedding.vector, query, memo_attr);
-                ScoreIndex { score, index }
+            .filter_map(|(index, embedding)| {
+                if match_embedding(embedding, filter) {
+                    let score = distance_fn(&embedding.vector, query, memo_attr);
+                    Some(ScoreIndex { score, index })
+                } else {
+                    None
+                }
             })
             .collect::<Vec<_>>();
 
@@ -88,12 +121,61 @@ impl Collection {
             })
             .collect()
     }
+
+    /// Removes the first embedding whose identifier equals `id`.
+    ///
+    /// Returns `true` when an embedding was removed and `false` when none matched.
+    pub fn delete(&mut self, id: &str) -> bool {
+        match self.embeddings.iter().position(|e| e.id == id) {
+            Some(index) => {
+                self.embeddings.remove(index);
+                true
+            }
+            None => false,
+        }
+    }
+
+    /// Removes every embedding matching `filter`.
+    ///
+    /// An empty `filter` matches every embedding, so it empties the collection.
+    pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
+        // Filter in place with `retain`. Removing by precomputed indexes is
+        // incorrect: every `Vec::remove` shifts the elements behind it, so the
+        // remaining indexes would hit the wrong embeddings or panic.
+        self.embeddings.retain(|embedding| !match_embedding(embedding, filter));
+    }
+}
+
+/// Tells whether `embedding` satisfies `filter`.
+///
+/// `filter` is a list of criteria combined with OR semantics. The entries of one
+/// criterion are combined with AND semantics: each key must be present in the
+/// embedding metadata with an equal value.
+fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
+    // an empty filter matches any embedding
+    if filter.is_empty() {
+        return true;
+    }
+
+    match &embedding.metadata {
+        // an embedding without metadata cannot match a non-empty filter
+        None => false,
+        // OR across criteria, AND across the entries of each criterion
+        Some(metadata) => filter.iter().any(|criteria| {
+            criteria
+                .iter()
+                .all(|(key, expected)| metadata.get(key) == Some(expected))
+        }),
+    }
 }
 
 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
 pub struct Embedding {
+    /// Unique identifier
     pub id: String,
+    /// Vector computed from a text chunk
     pub vector: Vec<f32>,
+    /// Metadata about the source text
     pub metadata: Option<HashMap<String, String>>,
 }
 
@@ -171,6 +253,16 @@ impl Db {
         self.collections.get(name)
     }
 
+    /// Returns a mutable reference to the collection named `name`, if it exists.
+    pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
+        self.collections.get_mut(name)
+    }
+
+    /// Returns the names of all collections in the database.
+    pub fn list(&self) -> Vec<String> {
+        self.collections.keys().cloned().collect()
+    }
+
     fn load_from_store() -> anyhow::Result<Self> {
         if !STORE_PATH.exists() {
             tracing::debug!("Creating database store");
diff --git a/src/routes/collection.rs b/src/routes/collection.rs
index c0f8f47..888b604 100644
--- a/src/routes/collection.rs
+++ b/src/routes/collection.rs
@@ -5,7 +5,10 @@ use aide::axum::{
 use axum::{extract::Path, http::StatusCode, Extension};
 use axum_jsonschema::Json;
 use schemars::JsonSchema;
-use std::time::Instant;
+use std::{
+    collections::HashMap,
+    time::Instant,
+};
 
 use crate::{
     db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult},
@@ -17,14 +20,33 @@ pub fn handler() -> ApiRouter {
     ApiRouter::new().nest(
         "/collections",
         ApiRouter::new()
+            .api_route("/", get(get_collections))
             .api_route("/:collection_name", put(create_collection))
             .api_route("/:collection_name", post(query_collection))
             .api_route("/:collection_name", get(get_collection_info))
             .api_route("/:collection_name", delete(delete_collection))
-            .api_route("/:collection_name/insert", post(insert_into_collection)),
+            .api_route("/:collection_name/embeddings", get(get_embeddings))
+            .api_route("/:collection_name/embeddings", post(query_embeddings))
+            .api_route("/:collection_name/embeddings", delete(delete_embeddings))
+            .api_route("/:collection_name/embeddings/:embedding_id", put(insert_into_collection))
+            .api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
+            .api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
     )
 }
 
+/// Get collection names
+async fn get_collections(
+    Extension(db): DbExtension,
+) -> Result<Json<Vec<String>>, HTTPError> {
+    tracing::trace!("Getting collection names");
+
+    let db = db.read().await;
+
+    let results = db.list();
+
+    Ok(Json(results))
+}
+
 /// Create a new collection
 async fn create_collection(
     Path(collection_name): Path<String>,
@@ -54,6 +76,8 @@ async fn create_collection(
 struct QueryCollectionQuery {
     /// Vector to query with
     query: Vec<f32>,
+    /// Metadata to filter with
+    filter: Option<Vec<HashMap<String, String>>>,
     /// Number of results to return
     k: Option<usize>,
 }
@@ -77,7 +101,7 @@ async fn query_collection(
     }
 
     let instant = Instant::now();
-    let results = collection.get_similarity(&req.query, req.k.unwrap_or(1));
+    let results = collection.get_by_metadata_and_similarity(&req.filter.unwrap_or_default(), &req.query, req.k.unwrap_or(1));
     drop(db);
 
     tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed());
@@ -138,16 +162,29 @@ async fn delete_collection(
     }
 }
 
+#[derive(Debug, serde::Deserialize, JsonSchema)]
+struct EmbeddingData {
+    /// Vector computed from a text chunk
+    vector: Vec<f32>,
+    /// Metadata about the source text
+    metadata: Option<HashMap<String, String>>,
+}
+
 /// Insert a vector into a collection
 async fn insert_into_collection(
-    Path(collection_name): Path<String>,
+    Path((collection_name, embedding_id)): Path<(String, String)>,
     Extension(db): DbExtension,
-    Json(embedding): Json<Embedding>,
+    Json(embedding_data): Json<EmbeddingData>,
 ) -> Result<StatusCode, HTTPError> {
     tracing::trace!("Inserting into collection {collection_name}");
 
     let mut db = db.write().await;
 
+    let embedding = Embedding {
+        id: embedding_id,
+        vector: embedding_data.vector,
+        metadata: embedding_data.metadata,
+    };
     let insert_result = db.insert_into_collection(&collection_name, embedding);
     drop(db);
 
@@ -165,3 +202,109 @@ async fn insert_into_collection(
             .with_status(StatusCode::BAD_REQUEST)),
     }
 }
+
+/// List the identifiers of the embeddings in a collection
+async fn get_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+) -> Result<Json<Vec<String>>, HTTPError> {
+    tracing::trace!("Querying embeddings from collection {collection_name}");
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let results = collection.list();
+    drop(db);
+
+    Ok(Json(results))
+}
+
+#[derive(Debug, serde::Deserialize, JsonSchema)]
+struct EmbeddingsQuery {
+    /// Metadata to filter with
+    filter: Vec<HashMap<String, String>>,
+    /// Number of results to return
+    k: Option<usize>,
+}
+
+/// Query embeddings in a collection
+async fn query_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+    Json(req): Json<EmbeddingsQuery>,
+) -> Result<Json<Vec<Embedding>>, HTTPError> {
+    tracing::trace!("Querying embeddings from collection {collection_name}");
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let instant = Instant::now();
+    let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1));
+    drop(db);
+
+    tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed());
+    Ok(Json(results))
+}
+
+/// Delete embeddings in a collection
+async fn delete_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+    Json(req): Json<EmbeddingsQuery>,
+) -> Result<StatusCode, HTTPError> {
+    tracing::trace!("Deleting embeddings from collection {collection_name}");
+
+    let mut db = db.write().await;
+    let collection = db
+        .get_collection_mut(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    collection.delete_by_metadata(&req.filter);
+    drop(db);
+
+    Ok(StatusCode::NO_CONTENT)
+}
+
+/// Get an embedding from a collection
+async fn get_embedding(
+    Path((collection_name, embedding_id)): Path<(String, String)>,
+    Extension(db): DbExtension,
+) -> Result<Json<Embedding>, HTTPError> {
+    tracing::trace!("Getting {embedding_id} from collection {collection_name}");
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let embedding = collection
+        .get(&embedding_id)
+        .ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;
+
+    Ok(Json(embedding.to_owned()))
+}
+
+/// Delete an embedding from a collection
+async fn delete_embedding(
+    Path((collection_name, embedding_id)): Path<(String, String)>,
+    Extension(db): DbExtension,
+) -> Result<StatusCode, HTTPError> {
+    tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");
+
+    let mut db = db.write().await;
+    let collection = db
+        .get_collection_mut(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let delete_result = collection.delete(&embedding_id);
+    drop(db);
+
+    match delete_result {
+        true => Ok(StatusCode::NO_CONTENT),
+        false => Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND)),
+    }
+}