Skip to content

Commit

Permalink
Feature/remove large docs (#1251)
Browse files Browse the repository at this point in the history
* Feature/remove search (#1249)

* removed search subgraph

* remove parametrized retiever from get_root_graph

* removed parametrized retirver

* remove self route

* removed search keyword

* removed some unused functions

* graphs updated

* removed search subgraph

* remove parametrized retiever from get_root_graph

* removed parametrized retirver

* remove self route

* removed search keyword

* removed some unused functions

* graphs updated

* rmove large docs

* removed unused code

* Fixed graph drawing and regenerated architecture diagrams (#1268)

* Fixed graph drawing and regenerated architecture diagrams

* Ruff

---------

Co-authored-by: James Richards <39167172+jamesrichards4@users.noreply.github.com>
  • Loading branch information
gecBurton and jamesrichards4 authored Jan 7, 2025
1 parent 03be96d commit 030243b
Show file tree
Hide file tree
Showing 16 changed files with 21 additions and 507 deletions.
Binary file modified docs/architecture/graph/chat_documents.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/architecture/graph/chat_documents_large.png
Binary file not shown.
5 changes: 1 addition & 4 deletions redbox-core/redbox/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from redbox.graph.root import (
get_chat_with_documents_graph,
get_root_graph,
get_chat_with_documents_large_graph,
)
from redbox.models.chain import RedboxState
from redbox.models.graph import (
Expand Down Expand Up @@ -101,9 +100,7 @@ def draw(self, output_path=None, graph_to_draw: Literal["root", "chat_with_docum
if graph_to_draw == "root":
graph = self.graph.get_graph()
elif graph_to_draw == "chat/documents":
graph = get_chat_with_documents_graph(self.all_chunks_retriever).get_graph()
elif graph_to_draw == "chat/documents/large":
graph = get_chat_with_documents_large_graph().get_graph()
graph = get_chat_with_documents_graph(self.retriever).get_graph()
else:
raise Exception("Invalid graph_to_draw")

Expand Down
42 changes: 1 addition & 41 deletions redbox-core/redbox/graph/edges.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import logging
import re
from typing import Literal

from langchain_core.runnables import Runnable

from redbox.chains.components import get_tokeniser
from redbox.graph.nodes.processes import PromptSet
from redbox.models.chain import RedboxState, get_prompts
from redbox.transform import get_document_token_count

log = logging.getLogger()

Expand Down Expand Up @@ -40,50 +38,12 @@ def _total_tokens_request_handler_conditional(
if total_tokens > max_tokens_allowed:
return "max_exceeded"
elif total_tokens > token_budget_remaining_in_context:
return "context_exceeded"
return "max_exceeded"
else:
return "pass"

return _total_tokens_request_handler_conditional


def build_documents_bigger_than_context_conditional(prompt_set: PromptSet) -> Runnable:
"""Uses a set of prompts to build the correct conditional for exceeding the context window."""

def _documents_bigger_than_context_conditional(state: RedboxState) -> bool:
system_prompt, question_prompt = get_prompts(state, prompt_set)
token_budget = calculate_token_budget(state, system_prompt, question_prompt)

return get_document_token_count(state) > token_budget

return _documents_bigger_than_context_conditional


def documents_bigger_than_n_conditional(state: RedboxState) -> bool:
"""Do the documents meet a hard limit of document token size set in AI Settings."""
token_counts = get_document_token_count(state)
return sum(token_counts) > state.request.ai_settings.max_document_tokens


def documents_selected_conditional(state: RedboxState) -> bool:
return len(state.request.s3_keys) > 0


def multiple_docs_in_group_conditional(state: RedboxState) -> bool:
return any(len(group) > 1 for group in state.documents.groups.values())


def build_strings_end_text_conditional(*strings: str) -> Runnable:
"""Given a list of strings, returns the string if the end of state.last_message.content contains it."""
pattern = "|".join(re.escape(s) for s in strings)
regex = re.compile(pattern, re.IGNORECASE)

def _strings_end_text_conditional(state: RedboxState) -> str:
matches = regex.findall(state.last_message.content[-100:]) # padding for waffle
unique_matches = set(matches)

if len(unique_matches) == 1:
return unique_matches.pop().lower()
return "DEFAULT"

return _strings_end_text_conditional
152 changes: 4 additions & 148 deletions redbox-core/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import json
import logging
import re
import textwrap
from collections.abc import Callable
from functools import reduce
from typing import Any, Iterable
from uuid import uuid4
from typing import Any

from langchain.schema import StrOutputParser
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
from langchain_core.tools import StructuredTool
from langchain_core.vectorstores import VectorStoreRetriever

from redbox.chains.activity import log_activity
from redbox.chains.components import get_chat_llm, get_tokeniser
from redbox.chains.components import get_chat_llm
from redbox.chains.runnables import CannedChatLLM, build_llm_chain
from redbox.models import ChatRoute
from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata
from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG, RedboxActivityEvent, RedboxEventType
from redbox.transform import combine_documents, flatten_document_state
from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG
from redbox.transform import flatten_document_state

log = logging.getLogger(__name__)
re_keyword_pattern = re.compile(r"@(\w+)")
Expand Down Expand Up @@ -72,65 +66,6 @@ def _chat(state: RedboxState) -> dict[str, Any]:
return _chat


def build_merge_pattern(
prompt_set: PromptSet,
tools: list[StructuredTool] | None = None,
final_response_chain: bool = False,
) -> Runnable[RedboxState, dict[str, Any]]:
"""Returns a Runnable that uses state.request and state.documents to return one item in state.documents.
When combined with chunk send, will replace each Document with what's returned from the LLM.
When combined with group send, with combine all Documents and use the metadata of the first.
When used without a send, the first Document receieved defines the metadata.
If tools are supplied, can also set state["tool_calls"].
"""
tokeniser = get_tokeniser()

@RunnableLambda
def _merge(state: RedboxState) -> dict[str, Any]:
llm = get_chat_llm(state.request.ai_settings.chat_backend, tools=tools)

if not state.documents.groups:
return {"documents": None}

flattened_documents = flatten_document_state(state.documents)
merged_document = reduce(lambda left, right: combine_documents(left, right), flattened_documents)

merge_state = RedboxState(
request=state.request,
documents=DocumentState(
groups={merged_document.metadata["uuid"]: {merged_document.metadata["uuid"]: merged_document}}
),
)

merge_response = build_llm_chain(
prompt_set=prompt_set, llm=llm, final_response_chain=final_response_chain
).invoke(merge_state)

merged_document.page_content = merge_response["messages"][-1].content
request_metadata = merge_response["metadata"]
merged_document.metadata["token_count"] = len(tokeniser.encode(merged_document.page_content))

group_uuid = next(iter(state.documents.groups or {}), uuid4())
document_uuid = merged_document.metadata.get("uuid", uuid4())

# Clear old documents, add new one
document_state = state.documents.groups.copy()

for group in document_state:
for document in document_state[group]:
document_state[group][document] = None

document_state[group_uuid][document_uuid] = merged_document

return {"documents": DocumentState(groups=document_state), "metadata": request_metadata}

return _merge


def build_stuff_pattern(
prompt_set: PromptSet,
output_parser: Runnable = None,
Expand Down Expand Up @@ -174,28 +109,6 @@ def _set_route(state: RedboxState) -> dict[str, Any]:
return RunnableLambda(_set_route).with_config(tags=[ROUTE_NAME_TAG])


def build_set_self_route_from_llm_answer(
conditional: Callable[[str], bool],
true_condition_state_update: dict,
false_condition_state_update: dict,
final_route_response: bool = True,
) -> Runnable[RedboxState, dict[str, Any]]:
"""A Runnable which sets the route based on a conditional on state['text']"""

@RunnableLambda
def _set_self_route_from_llm_answer(state: RedboxState):
llm_response = state.last_message.content
if conditional(llm_response):
return true_condition_state_update
else:
return false_condition_state_update

runnable = _set_self_route_from_llm_answer
if final_route_response:
runnable = _set_self_route_from_llm_answer.with_config(tags=[ROUTE_NAME_TAG])
return runnable


def build_passthrough_pattern() -> Runnable[RedboxState, dict[str, Any]]:
"""Returns a Runnable that uses state["request"] to set state["text"]."""

Expand Down Expand Up @@ -258,62 +171,5 @@ def clear_documents_process(state: RedboxState) -> dict[str, Any]:
return {"documents": DocumentState(groups={group_id: None for group_id in documents.groups})}


def report_sources_process(state: RedboxState) -> None:
"""A Runnable which reports the documents in the state as sources."""
if citations_state := state.citations:
dispatch_custom_event(RedboxEventType.on_citations_report, citations_state)
elif document_state := state.documents:
dispatch_custom_event(RedboxEventType.on_source_report, flatten_document_state(document_state))


def empty_process(state: RedboxState) -> None:
return None


def build_log_node(message: str) -> Runnable[RedboxState, dict[str, Any]]:
"""A Runnable which logs the current state in a compact way"""

@RunnableLambda
def _log_node(state: RedboxState):
log.info(
json.dumps(
{
"user_uuid": str(state.request.user_uuid),
"document_metadata": {
group_id: {doc_id: d.metadata for doc_id, d in group_documents.items()}
for group_id, group_documents in state.documents.group
},
"messages": (textwrap.shorten(state.last_message.content, width=32, placeholder="...")),
"route": state.route_name,
"message": message,
}
)
)
return None

return _log_node


def build_activity_log_node(
log_message: RedboxActivityEvent
| Callable[[RedboxState], Iterable[RedboxActivityEvent]]
| Callable[[RedboxState], Iterable[RedboxActivityEvent]],
):
"""
A Runnable which emits activity events based on the state. The message should either be a static message to log, or a function which returns an activity event or an iterator of them
"""

@RunnableLambda
def _activity_log_node(state: RedboxState):
if isinstance(log_message, RedboxActivityEvent):
log_activity(log_message)
else:
response = log_message(state)
if isinstance(response, RedboxActivityEvent):
log_activity(response)
else:
for message in response:
log_activity(message)
return None

return _activity_log_node
43 changes: 0 additions & 43 deletions redbox-core/redbox/graph/nodes/sends.py

This file was deleted.

Loading

0 comments on commit 030243b

Please sign in to comment.