diff --git a/docs/architecture/graph/chat_documents.png b/docs/architecture/graph/chat_documents.png index 234930a93..d8eaf5d72 100644 Binary files a/docs/architecture/graph/chat_documents.png and b/docs/architecture/graph/chat_documents.png differ diff --git a/docs/architecture/graph/chat_documents_large.png b/docs/architecture/graph/chat_documents_large.png deleted file mode 100644 index e591e8129..000000000 Binary files a/docs/architecture/graph/chat_documents_large.png and /dev/null differ diff --git a/redbox-core/redbox/app.py b/redbox-core/redbox/app.py index f6596af4d..c79ff8f17 100644 --- a/redbox-core/redbox/app.py +++ b/redbox-core/redbox/app.py @@ -6,7 +6,6 @@ from redbox.graph.root import ( get_chat_with_documents_graph, get_root_graph, - get_chat_with_documents_large_graph, ) from redbox.models.chain import RedboxState from redbox.models.graph import ( @@ -101,9 +100,7 @@ def draw(self, output_path=None, graph_to_draw: Literal["root", "chat_with_docum if graph_to_draw == "root": graph = self.graph.get_graph() elif graph_to_draw == "chat/documents": - graph = get_chat_with_documents_graph(self.all_chunks_retriever).get_graph() - elif graph_to_draw == "chat/documents/large": - graph = get_chat_with_documents_large_graph().get_graph() + graph = get_chat_with_documents_graph(self.retriever).get_graph() else: raise Exception("Invalid graph_to_draw") diff --git a/redbox-core/redbox/graph/edges.py b/redbox-core/redbox/graph/edges.py index 679c47153..2837dc117 100644 --- a/redbox-core/redbox/graph/edges.py +++ b/redbox-core/redbox/graph/edges.py @@ -1,5 +1,4 @@ import logging -import re from typing import Literal from langchain_core.runnables import Runnable @@ -7,7 +6,6 @@ from redbox.chains.components import get_tokeniser from redbox.graph.nodes.processes import PromptSet from redbox.models.chain import RedboxState, get_prompts -from redbox.transform import get_document_token_count log = logging.getLogger() @@ -40,50 +38,12 @@ def _total_tokens_request_handler_conditional( if total_tokens > max_tokens_allowed: return "max_exceeded" elif total_tokens > token_budget_remaining_in_context: - return "context_exceeded" + return "max_exceeded" else: return "pass" return _total_tokens_request_handler_conditional -def build_documents_bigger_than_context_conditional(prompt_set: PromptSet) -> Runnable: - """Uses a set of prompts to build the correct conditional for exceeding the context window.""" - - def _documents_bigger_than_context_conditional(state: RedboxState) -> bool: - system_prompt, question_prompt = get_prompts(state, prompt_set) - token_budget = calculate_token_budget(state, system_prompt, question_prompt) - - return get_document_token_count(state) > token_budget - - return _documents_bigger_than_context_conditional - - -def documents_bigger_than_n_conditional(state: RedboxState) -> bool: - """Do the documents meet a hard limit of document token size set in AI Settings.""" - token_counts = get_document_token_count(state) - return sum(token_counts) > state.request.ai_settings.max_document_tokens - - def documents_selected_conditional(state: RedboxState) -> bool: return len(state.request.s3_keys) > 0 - - -def multiple_docs_in_group_conditional(state: RedboxState) -> bool: - return any(len(group) > 1 for group in state.documents.groups.values()) - - -def build_strings_end_text_conditional(*strings: str) -> Runnable: - """Given a list of strings, returns the string if the end of state.last_message.content contains it.""" - pattern = "|".join(re.escape(s) for s in strings) - regex = re.compile(pattern, re.IGNORECASE) - - def _strings_end_text_conditional(state: RedboxState) -> str: - matches = regex.findall(state.last_message.content[-100:]) # padding for waffle - unique_matches = set(matches) - - if len(unique_matches) == 1: - return unique_matches.pop().lower() - return "DEFAULT" - - return _strings_end_text_conditional diff --git a/redbox-core/redbox/graph/nodes/processes.py b/redbox-core/redbox/graph/nodes/processes.py index 866643795..366788ba8 100644 --- a/redbox-core/redbox/graph/nodes/processes.py +++ b/redbox-core/redbox/graph/nodes/processes.py @@ -1,27 +1,21 @@ -import json import logging import re -import textwrap from collections.abc import Callable -from functools import reduce -from typing import Any, Iterable -from uuid import uuid4 +from typing import Any from langchain.schema import StrOutputParser -from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.documents import Document from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel from langchain_core.tools import StructuredTool from langchain_core.vectorstores import VectorStoreRetriever -from redbox.chains.activity import log_activity -from redbox.chains.components import get_chat_llm, get_tokeniser +from redbox.chains.components import get_chat_llm from redbox.chains.runnables import CannedChatLLM, build_llm_chain from redbox.models import ChatRoute from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata -from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG, RedboxActivityEvent, RedboxEventType -from redbox.transform import combine_documents, flatten_document_state +from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG +from redbox.transform import flatten_document_state log = logging.getLogger(__name__) re_keyword_pattern = re.compile(r"@(\w+)") @@ -72,65 +66,6 @@ def _chat(state: RedboxState) -> dict[str, Any]: return _chat -def build_merge_pattern( - prompt_set: PromptSet, - tools: list[StructuredTool] | None = None, - final_response_chain: bool = False, -) -> Runnable[RedboxState, dict[str, Any]]: - """Returns a Runnable that uses state.request and state.documents to return one item in state.documents. - - When combined with chunk send, will replace each Document with what's returned from the LLM. - - When combined with group send, with combine all Documents and use the metadata of the first. - - When used without a send, the first Document receieved defines the metadata. - - If tools are supplied, can also set state["tool_calls"]. - """ - tokeniser = get_tokeniser() - - @RunnableLambda - def _merge(state: RedboxState) -> dict[str, Any]: - llm = get_chat_llm(state.request.ai_settings.chat_backend, tools=tools) - - if not state.documents.groups: - return {"documents": None} - - flattened_documents = flatten_document_state(state.documents) - merged_document = reduce(lambda left, right: combine_documents(left, right), flattened_documents) - - merge_state = RedboxState( - request=state.request, - documents=DocumentState( - groups={merged_document.metadata["uuid"]: {merged_document.metadata["uuid"]: merged_document}} - ), - ) - - merge_response = build_llm_chain( - prompt_set=prompt_set, llm=llm, final_response_chain=final_response_chain - ).invoke(merge_state) - - merged_document.page_content = merge_response["messages"][-1].content - request_metadata = merge_response["metadata"] - merged_document.metadata["token_count"] = len(tokeniser.encode(merged_document.page_content)) - - group_uuid = next(iter(state.documents.groups or {}), uuid4()) - document_uuid = merged_document.metadata.get("uuid", uuid4()) - - # Clear old documents, add new one - document_state = state.documents.groups.copy() - - for group in document_state: - for document in document_state[group]: - document_state[group][document] = None - - document_state[group_uuid][document_uuid] = merged_document - - return {"documents": DocumentState(groups=document_state), "metadata": request_metadata} - - return _merge - - def build_stuff_pattern( prompt_set: PromptSet, output_parser: Runnable = None, @@ -174,28 +109,6 @@ def _set_route(state: RedboxState) -> dict[str, Any]: return RunnableLambda(_set_route).with_config(tags=[ROUTE_NAME_TAG]) -def build_set_self_route_from_llm_answer( - conditional: Callable[[str], bool], - true_condition_state_update: dict, - false_condition_state_update: dict, - final_route_response: bool = True, -) -> Runnable[RedboxState, dict[str, Any]]: - """A Runnable which sets the route based on a conditional on state['text']""" - - @RunnableLambda - def _set_self_route_from_llm_answer(state: RedboxState): - llm_response = state.last_message.content - if conditional(llm_response): - return true_condition_state_update - else: - return false_condition_state_update - - runnable = _set_self_route_from_llm_answer - if final_route_response: - runnable = _set_self_route_from_llm_answer.with_config(tags=[ROUTE_NAME_TAG]) - return runnable - - def build_passthrough_pattern() -> Runnable[RedboxState, dict[str, Any]]: """Returns a Runnable that uses state["request"] to set state["text"].""" @@ -258,62 +171,5 @@ def clear_documents_process(state: RedboxState) -> dict[str, Any]: return {"documents": DocumentState(groups={group_id: None for group_id in documents.groups})} -def report_sources_process(state: RedboxState) -> None: - """A Runnable which reports the documents in the state as sources.""" - if citations_state := state.citations: - dispatch_custom_event(RedboxEventType.on_citations_report, citations_state) - elif document_state := state.documents: - dispatch_custom_event(RedboxEventType.on_source_report, flatten_document_state(document_state)) - - def empty_process(state: RedboxState) -> None: return None - - -def build_log_node(message: str) -> Runnable[RedboxState, dict[str, Any]]: - """A Runnable which logs the current state in a compact way""" - - @RunnableLambda - def _log_node(state: RedboxState): - log.info( - json.dumps( - { - "user_uuid": str(state.request.user_uuid), - "document_metadata": { - group_id: {doc_id: d.metadata for doc_id, d in group_documents.items()} - for group_id, group_documents in state.documents.group - }, - "messages": (textwrap.shorten(state.last_message.content, width=32, placeholder="...")), - "route": state.route_name, - "message": message, - } - ) - ) - return None - - return _log_node - - -def build_activity_log_node( - log_message: RedboxActivityEvent - | Callable[[RedboxState], Iterable[RedboxActivityEvent]] - | Callable[[RedboxState], Iterable[RedboxActivityEvent]], -): - """ - A Runnable which emits activity events based on the state. The message should either be a static message to log, or a function which returns an activity event or an iterator of them - """ - - @RunnableLambda - def _activity_log_node(state: RedboxState): - if isinstance(log_message, RedboxActivityEvent): - log_activity(log_message) - else: - response = log_message(state) - if isinstance(response, RedboxActivityEvent): - log_activity(response) - else: - for message in response: - log_activity(message) - return None - - return _activity_log_node diff --git a/redbox-core/redbox/graph/nodes/sends.py b/redbox-core/redbox/graph/nodes/sends.py deleted file mode 100644 index 363aefd81..000000000 --- a/redbox-core/redbox/graph/nodes/sends.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Callable - -from langgraph.constants import Send - -from redbox.models.chain import DocumentState, RedboxState - - -def _copy_state(state: RedboxState, **updates) -> RedboxState: - updated_model = state.model_copy(update=updates, deep=True) - return updated_model - - -def build_document_group_send(target: str) -> Callable[[RedboxState], list[Send]]: - """Builds Sends per document group.""" - - def _group_send(state: RedboxState) -> list[Send]: - group_send_states: list[RedboxState] = [ - _copy_state( - state, - documents=DocumentState(groups={document_group_key: document_group}), - ) - for document_group_key, document_group in state.documents.groups.items() - ] - return [Send(node=target, arg=state) for state in group_send_states] - - return _group_send - - -def build_document_chunk_send(target: str) -> Callable[[RedboxState], list[Send]]: - """Builds Sends per individual document""" - - def _chunk_send(state: RedboxState) -> list[Send]: - chunk_send_states: list[RedboxState] = [ - _copy_state( - state, - documents=DocumentState(groups={document_group_key: {document_key: document}}), - ) - for document_group_key, document_group in state.documents.groups.items() - for document_key, document in document_group.items() - ] - return [Send(node=target, arg=state) for state in chunk_send_states] - - return _chunk_send diff --git a/redbox-core/redbox/graph/root.py b/redbox-core/redbox/graph/root.py index 666f890f3..ee03409f1 100644 --- a/redbox-core/redbox/graph/root.py +++ b/redbox-core/redbox/graph/root.py @@ -3,16 +3,13 @@ from langgraph.graph.graph import CompiledGraph from redbox.graph.edges import ( - build_documents_bigger_than_context_conditional, build_total_tokens_request_handler_conditional, documents_selected_conditional, - multiple_docs_in_group_conditional, ) from redbox.graph.nodes.processes import ( PromptSet, build_chat_pattern, build_error_pattern, - build_merge_pattern, build_passthrough_pattern, build_retrieve_pattern, build_set_metadata_pattern, @@ -21,7 +18,6 @@ clear_documents_process, empty_process, ) -from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send from redbox.models.chain import RedboxState from redbox.models.chat import ChatRoute, ErrorRoute from redbox.transform import structure_documents_by_file_name @@ -48,91 +44,6 @@ def get_chat_graph( return builder.compile(debug=debug) -def get_chat_with_documents_large_graph(): - """a subgraph for get_chat_with_documents_graph""" - builder = StateGraph(RedboxState) - - # Sends - builder.add_node("s_chunk", empty_process) - builder.add_node("s_group_1", empty_process) - builder.add_node("s_group_2", empty_process) - builder.add_node("p_too_large_error", empty_process) - - builder.add_node( - "p_summarise_each_document", - build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce), - ) - builder.add_node( - "p_summarise_document_by_document", - build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce), - ) - builder.add_node( - "p_summarise", - build_stuff_pattern( - prompt_set=PromptSet.ChatwithDocs, - final_response_chain=True, - ), - ) - builder.add_node("p_clear_documents", clear_documents_process) - - builder.add_node("d_groups_have_multiple_docs", empty_process) - builder.add_node("d_doc_summaries_bigger_than_context", empty_process) - builder.add_node("d_single_doc_summaries_bigger_than_context", empty_process) - - # Edges - builder.add_edge(START, "s_chunk") - builder.add_conditional_edges( - "s_chunk", - build_document_chunk_send("p_summarise_each_document"), - path_map=["p_summarise_each_document"], - ) - builder.add_edge("p_summarise_each_document", "d_groups_have_multiple_docs") - - builder.add_conditional_edges( - "d_groups_have_multiple_docs", - multiple_docs_in_group_conditional, - { - True: "s_group_1", - False: "d_doc_summaries_bigger_than_context", - }, - ) - - builder.add_conditional_edges( - "s_group_1", - build_document_group_send("d_single_doc_summaries_bigger_than_context"), - path_map=["d_single_doc_summaries_bigger_than_context"], - ) - builder.add_conditional_edges( - "d_single_doc_summaries_bigger_than_context", - build_documents_bigger_than_context_conditional(PromptSet.ChatwithDocsMapReduce), - { - True: "p_too_large_error", - False: "s_group_2", - }, - ) - builder.add_conditional_edges( - "s_group_2", - build_document_group_send("p_summarise_document_by_document"), - path_map=["p_summarise_document_by_document"], - ) - builder.add_edge("p_summarise_document_by_document", "d_doc_summaries_bigger_than_context") - - builder.add_conditional_edges( - "d_doc_summaries_bigger_than_context", - build_documents_bigger_than_context_conditional(PromptSet.ChatwithDocs), - { - True: "p_too_large_error", - False: "p_summarise", - }, - ) - builder.add_edge("p_summarise", "p_clear_documents") - - builder.add_edge("p_too_large_error", END) - builder.add_edge("p_clear_documents", END) - - return builder.compile() - - def get_chat_with_documents_graph( retriever: VectorStoreRetriever, debug: bool = False, @@ -143,10 +54,6 @@ def get_chat_with_documents_graph( # Processes builder.add_node("p_pass_question_to_text", build_passthrough_pattern()) builder.add_node("p_set_chat_docs_route", build_set_route_pattern(route=ChatRoute.chat_with_docs)) - builder.add_node( - "p_set_chat_docs_map_reduce_route", - build_set_route_pattern(route=ChatRoute.chat_with_docs_map_reduce), - ) builder.add_node( "p_summarise", build_stuff_pattern( @@ -171,8 +78,6 @@ def get_chat_with_documents_graph( ), ) - builder.add_node("chat_with_documents_large", get_chat_with_documents_large_graph()) - # Decisions builder.add_node("d_request_handler_from_total_tokens", empty_process) @@ -184,24 +89,15 @@ def get_chat_with_documents_graph( build_total_tokens_request_handler_conditional(PromptSet.ChatwithDocsMapReduce), { "max_exceeded": "p_too_large_error", - "context_exceeded": "p_set_chat_docs_map_reduce_route", + "context_exceeded": "p_too_large_error", "pass": "p_set_chat_docs_route", }, ) builder.add_edge("p_set_chat_docs_route", "p_retrieve_all_chunks") - builder.add_edge("p_set_chat_docs_map_reduce_route", "p_retrieve_all_chunks") - builder.add_conditional_edges( - "p_retrieve_all_chunks", - lambda s: s.route_name, - { - ChatRoute.chat_with_docs: "p_summarise", - ChatRoute.chat_with_docs_map_reduce: "chat_with_documents_large", - }, - ) + builder.add_edge("p_retrieve_all_chunks", "p_summarise") builder.add_edge("p_summarise", "p_clear_documents") builder.add_edge("p_clear_documents", END) builder.add_edge("p_too_large_error", END) - builder.add_edge("chat_with_documents_large", END) return builder.compile(debug=debug) diff --git a/redbox-core/redbox/loader/ingester.py b/redbox-core/redbox/loader/ingester.py index 5c04c0bc4..4db8232ec 100644 --- a/redbox-core/redbox/loader/ingester.py +++ b/redbox-core/redbox/loader/ingester.py @@ -1,16 +1,10 @@ from io import BytesIO -from typing import TYPE_CHECKING from redbox.loader.loaders import UnstructuredChunkLoader from redbox.models.settings import get_settings from redbox.models.file import ChunkResolution, UploadedFileMetadata -if TYPE_CHECKING: - from mypy_boto3_s3.client import S3Client -else: - S3Client = object - env = get_settings() diff --git a/redbox-core/redbox/loader/loaders.py b/redbox-core/redbox/loader/loaders.py index 767c04dcc..38b1fe4a9 100644 --- a/redbox-core/redbox/loader/loaders.py +++ b/redbox-core/redbox/loader/loaders.py @@ -1,7 +1,6 @@ import logging from datetime import datetime from io import BytesIO -from typing import TYPE_CHECKING import requests import tiktoken @@ -13,11 +12,6 @@ encoding = tiktoken.get_encoding("cl100k_base") -if TYPE_CHECKING: - from mypy_boto3_s3.client import S3Client -else: - S3Client = object - class UnstructuredChunkLoader: """ diff --git a/redbox-core/redbox/models/chain.py b/redbox-core/redbox/models/chain.py index d32531b98..299f3e25f 100644 --- a/redbox-core/redbox/models/chain.py +++ b/redbox-core/redbox/models/chain.py @@ -11,7 +11,6 @@ from langchain_core.documents import Document from langchain_core.messages import AnyMessage from langgraph.graph.message import add_messages -from langgraph.managed.is_last_step import RemainingStepsManager from pydantic import BaseModel, Field from redbox.models import prompts @@ -235,7 +234,6 @@ class RedboxState(BaseModel): route_name: str | None = None metadata: Annotated[RequestMetadata | None, metadata_reducer] = None citations: list[Citation] | None = None - steps_left: Annotated[int | None, RemainingStepsManager] = None messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list) @property diff --git a/redbox-core/redbox/models/chat.py b/redbox-core/redbox/models/chat.py index 74c5ae046..de5341c76 100644 --- a/redbox-core/redbox/models/chat.py +++ b/redbox-core/redbox/models/chat.py @@ -4,7 +4,6 @@ class ChatRoute(StrEnum): chat = "chat" chat_with_docs = "chat/documents" - chat_with_docs_map_reduce = "chat/documents/large" class ErrorRoute(StrEnum): diff --git a/redbox-core/tests/graph/nodes/__init__.py b/redbox-core/tests/graph/nodes/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/redbox-core/tests/graph/nodes/test_sends.py b/redbox-core/tests/graph/nodes/test_sends.py deleted file mode 100644 index e94b5f373..000000000 --- a/redbox-core/tests/graph/nodes/test_sends.py +++ /dev/null @@ -1,71 +0,0 @@ -from uuid import uuid4 - -from langchain_core.documents import Document -from langgraph.constants import Send - -from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send -from redbox.models.chain import DocumentState, RedboxQuery, RedboxState - - -def test_build_document_group_send(): - target = "my-target" - request = RedboxQuery(question="what colour is the sky?", user_uuid=uuid4(), chat_history=[]) - documents = DocumentState( - groups={ - uuid4(): { - uuid4(): Document(page_content="Hello, world!"), - uuid4(): Document(page_content="Goodbye, world!"), - } - } - ) - - document_group_send = build_document_group_send("my-target") - state = RedboxState( - request=request, - documents=documents, - text=None, - route_name=None, - ) - actual = document_group_send(state) - expected = [Send(node=target, arg=state)] - assert expected == actual - - -def test_build_document_chunk_send(): - target = "my-target" - request = RedboxQuery(question="what colour is the sky?", user_uuid=uuid4(), chat_history=[]) - - uuid_1 = uuid4() - doc_1 = Document(page_content="Hello, world!") - uuid_2 = uuid4() - doc_2 = Document(page_content="Goodbye, world!") - - document_chunk_send = build_document_chunk_send("my-target") - state = RedboxState( - request=request, - documents=DocumentState(groups={uuid_1: {uuid_1: doc_1}, uuid_2: {uuid_2: doc_2}}), - text=None, - route_name=None, - ) - actual = document_chunk_send(state) - expected = [ - Send( - node=target, - arg=RedboxState( - request=request, - documents=DocumentState(groups={uuid_1: {uuid_1: doc_1}}), - text=None, - route_name=None, - ), - ), - Send( - node=target, - arg=RedboxState( - request=request, - documents=DocumentState(groups={uuid_2: {uuid_2: doc_2}}), - text=None, - route_name=None, - ), - ), - ] - assert expected == actual diff --git a/redbox-core/tests/graph/test_app.py b/redbox-core/tests/graph/test_app.py index 56160d8de..2ca649ae8 100644 --- a/redbox-core/tests/graph/test_app.py +++ b/redbox-core/tests/graph/test_app.py @@ -109,38 +109,21 @@ def assert_number_of_events(num_of_events: int): RedboxTestData( number_of_docs=2, tokens_in_all_docs=140_000, - llm_responses=["Map Step Response"] * 2 + ["Testing Response 1"], - expected_route=ChatRoute.chat_with_docs_map_reduce, + llm_responses=["These documents are too large to work with."] * 2 + + ["These documents are too large to work with."], + expected_route=ErrorRoute.files_too_large, ), RedboxTestData( number_of_docs=4, tokens_in_all_docs=140_000, llm_responses=["Map Step Response"] * 4 + ["Merge Per Document Response"] * 2 - + ["Testing Response 1"], - expected_route=ChatRoute.chat_with_docs_map_reduce, + + ["These documents are too large to work with."], + expected_route=ErrorRoute.files_too_large, ), ], test_id="Chat with multiple docs", ), - generate_test_cases( - query=RedboxQuery( - question="What is AI?", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData( - number_of_docs=2, - tokens_in_all_docs=200_000, - llm_responses=["Map Step Response"] * 2 + ["Merge Per Document Response"] + ["Testing Response 1"], - expected_route=ChatRoute.chat_with_docs_map_reduce, - ), - ], - test_id="Chat with large doc", - ), generate_test_cases( query=RedboxQuery( question="What is AI?", diff --git a/redbox-core/tests/graph/test_patterns.py b/redbox-core/tests/graph/test_patterns.py index cb4aca0b7..09ee5fb80 100644 --- a/redbox-core/tests/graph/test_patterns.py +++ b/redbox-core/tests/graph/test_patterns.py @@ -11,7 +11,6 @@ from redbox.chains.runnables import CannedChatLLM, build_chat_prompt_from_messages_runnable, build_llm_chain from redbox.graph.nodes.processes import ( build_chat_pattern, - build_merge_pattern, build_passthrough_pattern, build_retrieve_pattern, build_set_route_pattern, @@ -30,7 +29,7 @@ mock_retriever, mock_parameterised_retriever, ) -from redbox.transform import flatten_document_state, structure_documents_by_file_name +from redbox.transform import structure_documents_by_file_name LANGGRAPH_DEBUG = True @@ -228,56 +227,6 @@ def test_build_retrieve_pattern(test_case: RedboxChatTestCase, mock_retriever: B assert final_state.documents == structure_documents_by_file_name(test_case.docs) -MERGE_TEST_CASES = generate_test_cases( - query=RedboxQuery( - question="What is AI?", - s3_keys=["s3_key_1", "s3_key_2"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key_1", "s3_key_2"], - ), - test_data=[ - RedboxTestData( - number_of_docs=2, - tokens_in_all_docs=40_000, - llm_responses=["Testing Response 1"], - expected_route=ChatRoute.chat_with_docs, - ), - RedboxTestData( - number_of_docs=4, - tokens_in_all_docs=40_000, - llm_responses=["Testing Response 2"], - expected_route=ChatRoute.chat_with_docs, - ), - ], - test_id="Merge pattern", -) - - -@pytest.mark.parametrize(("test_case"), MERGE_TEST_CASES, ids=[t.test_id for t in MERGE_TEST_CASES]) -def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture): - """Tests a given state["request"] and state["documents"] correctly changes state["documents"].""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs)) - - merge = build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce, final_response_chain=True) - - mocker.patch("redbox.graph.nodes.processes.get_chat_llm", return_value=llm) - response = merge.invoke(state) - final_state = RedboxState(**response, request=test_case.query) - - response_documents = [doc for doc in flatten_document_state(final_state.documents) if doc is not None] - noned_documents = sum(1 for doc in final_state.documents.groups.values() for v in doc.values() if v is None) - - test_case_content = test_case.test_data.llm_responses[-1].content - - assert len(response_documents) == 1 - assert noned_documents == len(test_case.docs) - 1 - assert ( - response_documents[0].page_content == test_case_content - ), f"Expected document content: '{test_case_content}'. Received '{response_documents[0].page_content}'" - - STUFF_TEST_CASES = generate_test_cases( query=RedboxQuery( question="What is AI?", @@ -381,7 +330,7 @@ def test_empty_process(): ), documents=structure_documents_by_file_name([doc for doc in generate_docs(s3_key="s3_key")]), messages=[HumanMessage(content="Foo")], - route_name=ChatRoute.chat_with_docs_map_reduce, + route_name=ChatRoute.chat_with_docs, ) builder = StateGraph(RedboxState) @@ -403,7 +352,7 @@ def test_empty_process(): ), documents=structure_documents_by_file_name([doc for doc in generate_docs(s3_key="s3_key")]), messages=[HumanMessage(content="Foo")], - route_name=ChatRoute.chat_with_docs_map_reduce, + route_name=ChatRoute.chat_with_docs, ), RedboxState( request=RedboxQuery( @@ -411,7 +360,7 @@ def test_empty_process(): ), documents={}, messages=[HumanMessage(content="Foo")], - route_name=ChatRoute.chat_with_docs_map_reduce, + route_name=ChatRoute.chat_with_docs, ), ] diff --git a/utilities/draw_graph.py b/utilities/draw_graph.py index 5e49c4674..1bcad0e32 100644 --- a/utilities/draw_graph.py +++ b/utilities/draw_graph.py @@ -1,6 +1,8 @@ +from langchain_core.runnables import RunnablePassthrough + from redbox.app import Redbox -app = Redbox() +app = Redbox(retriever=RunnablePassthrough()) -for g in ["root", "chat/documents", "chat/documents/large"]: +for g in ["root", "chat/documents"]: app.draw(graph_to_draw=g, output_path=f"../docs/architecture/graph/{g.replace('/', '_')}.png")