Skip to content

Commit

Permalink
fix: Groups conversations based on the user's messages.
Browse files Browse the repository at this point in the history
Closes: #419

Before we could use the `chat_id` at the output messages as means
to group the messages into conversations. This logic is not working
anymore.

The new logic takes into account the user messages provided as
input to the LLM to map the messages into conversations. Usually LLMs
receive all last user messages. Example:
```
req1 = {messages:[{"role": "user", "content": "hello"}]}
req2 = {messages:[{"role": "user", "content": "hello"}, {"role": "user", "content": "how are you?}]}
```

In this last example, `req1` and `req2` should be mapped together to
form a conversation
  • Loading branch information
aponcedeleonch committed Jan 10, 2025
1 parent ff4a3a7 commit 78c2f2c
Show file tree
Hide file tree
Showing 3 changed files with 400 additions and 148 deletions.
237 changes: 157 additions & 80 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import asyncio
import json
import re
from typing import List, Optional, Tuple, Union
from collections import defaultdict
from typing import List, Optional, Union

import structlog

from codegate.dashboard.request_models import (
AlertConversation,
ChatMessage,
Conversation,
PartialConversation,
PartialQuestionAnswer,
PartialQuestions,
QuestionAnswer,
)
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
Expand Down Expand Up @@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]:
return None

# Only respond with the latest message
return messages[-1]
return messages


async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
async def parse_output(output_str: str) -> Optional[str]:
"""
Parse the output string from the pipeline and return the message and chat_id.
Parse the output string from the pipeline and return the message.
"""
try:
if output_str is None:
return None, None
return None

output = json.loads(output_str)
except Exception as e:
logger.warning(f"Error parsing output: {output_str}. {e}")
return None, None
return None

def _parse_single_output(single_output: dict) -> str:
single_chat_id = single_output.get("id")
single_output_message = ""
for choice in single_output.get("choices", []):
if not isinstance(choice, dict):
continue
content_dict = choice.get("delta", {}) or choice.get("message", {})
single_output_message += content_dict.get("content", "")
return single_output_message, single_chat_id
return single_output_message

full_output_message = ""
chat_id = None
if isinstance(output, list):
for output_chunk in output:
output_message, output_chat_id = "", None
output_message = ""
if isinstance(output_chunk, dict):
output_message, output_chat_id = _parse_single_output(output_chunk)
output_message = _parse_single_output(output_chunk)
elif isinstance(output_chunk, str):
try:
output_decoded = json.loads(output_chunk)
output_message, output_chat_id = _parse_single_output(output_decoded)
output_message = _parse_single_output(output_decoded)
except Exception:
logger.error(f"Error reading chunk: {output_chunk}")
else:
logger.warning(
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
)
chat_id = chat_id or output_chat_id
full_output_message += output_message
elif isinstance(output, dict):
full_output_message, chat_id = _parse_single_output(output)
full_output_message = _parse_single_output(output)

return full_output_message, chat_id
return full_output_message


async def _get_question_answer(
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
) -> Tuple[Optional[QuestionAnswer], Optional[str]]:
) -> Optional[PartialQuestionAnswer]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
Expand All @@ -137,17 +136,19 @@ async def _get_question_answer(
request_task = tg.create_task(parse_request(row.request))
output_task = tg.create_task(parse_output(row.output))

request_msg_str = request_task.result()
output_msg_str, chat_id = output_task.result()
request_user_msgs = request_task.result()
output_msg_str = output_task.result()

# If we couldn't parse the request or output, return None
if not request_msg_str:
return None, None
# If we couldn't parse the request, return None
if not request_user_msgs:
return None

request_message = ChatMessage(
message=request_msg_str,
request_message = PartialQuestions(
messages=request_user_msgs,
timestamp=row.timestamp,
message_id=row.id,
provider=row.provider,
type=row.type,
)
if output_msg_str:
output_message = ChatMessage(
Expand All @@ -157,28 +158,7 @@ async def _get_question_answer(
)
else:
output_message = None
chat_id = row.id
return QuestionAnswer(question=request_message, answer=output_message), chat_id


async def parse_get_prompt_with_output(
row: GetPromptWithOutputsRow,
) -> Optional[PartialConversation]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
The row contains the raw request and output strings from the pipeline.
"""
question_answer, chat_id = await _get_question_answer(row)
if not question_answer or not chat_id:
return None
return PartialConversation(
question_answer=question_answer,
provider=row.provider,
type=row.type,
chat_id=chat_id,
request_timestamp=row.timestamp,
)
return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)


def parse_question_answer(input_text: str) -> str:
Expand All @@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str:
return input_text


def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
"""
A PartialQuestion is an object that contains several user messages provided from a
chat conversation. Example:
- PartialQuestion(messages=["Hello"], timestamp=2022-01-01T00:00:00Z)
- PartialQuestion(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z)
In the above example both PartialQuestions are part of the same conversation and should be
matched together.
Group PartialQuestions objects such that:
- If one PartialQuestion (pq) is a subset of another pq's messages, group them together.
- If multiple subsets exist for the same superset, choose only the one
closest in timestamp to the superset.
- Leave any unpaired pq by itself.
- Finally, sort the resulting groups by the earliest timestamp in each group.
"""
# 1) Sort by length of messages descending (largest/most-complete first),
# then by timestamp ascending for stable processing.
pq_list_sorted = sorted(pq_list, key=lambda x: (-len(x.messages), x.timestamp))

used = set()
groups = []

# 2) Iterate in order of "largest messages first"
for sup in pq_list_sorted:
if sup.message_id in used:
continue # Already grouped

# Find all potential subsets of 'sup' that are not yet used
# (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
possible_subsets = []
for sub in pq_list_sorted:
if sub.message_id == sup.message_id:
continue
if sub.message_id in used:
continue
if (
set(sub.messages).issubset(set(sup.messages))
and sub.provider == sup.provider
and set(sub.messages) != set(sup.messages)
):
possible_subsets.append(sub)

# 3) If there are no subsets, this sup stands alone
if not possible_subsets:
groups.append([sup])
used.add(sup.message_id)
else:
# 4) Group subsets by messages to discard duplicates e.g.: 2 subsets with single 'hello'
subs_group_by_messages = defaultdict(list)
for q in possible_subsets:
subs_group_by_messages[tuple(q.messages)].append(q)

new_group = [sup]
used.add(sup.message_id)
for subs_same_message in subs_group_by_messages.values():
# If more than one pick the one subset closest in time to sup
closest_subset = min(
subs_same_message, key=lambda s: abs(s.timestamp - sup.timestamp)
)
new_group.append(closest_subset)
used.add(closest_subset.message_id)
groups.append(new_group)

# 5) Sort the groups by the earliest timestamp within each group
groups.sort(key=lambda g: min(pq.timestamp for pq in g))
return groups


def _get_question_answer_from_partial(
partial_question_answer: PartialQuestionAnswer,
) -> QuestionAnswer:
"""
Get a QuestionAnswer object from a PartialQuestionAnswer object.
"""
# Get the last user message as the question
question = ChatMessage(
message=partial_question_answer.partial_questions.messages[-1],
timestamp=partial_question_answer.partial_questions.timestamp,
message_id=partial_question_answer.partial_questions.message_id,
)

return QuestionAnswer(question=question, answer=partial_question_answer.answer)


async def match_conversations(
partial_conversations: List[Optional[PartialConversation]],
partial_question_answers: List[Optional[PartialQuestionAnswer]],
) -> List[Conversation]:
"""
Match partial conversations to form a complete conversation.
"""
convers = {}
for partial_conversation in partial_conversations:
if not partial_conversation:
continue

# Group by chat_id
if partial_conversation.chat_id not in convers:
convers[partial_conversation.chat_id] = []
convers[partial_conversation.chat_id].append(partial_conversation)
valid_partial_qas = [
partial_qas for partial_qas in partial_question_answers if partial_qas is not None
]
grouped_partial_questions = _group_partial_messages(
[partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas]
)

# Sort by timestamp
sorted_convers = {
chat_id: sorted(conversations, key=lambda x: x.request_timestamp)
for chat_id, conversations in convers.items()
}
# Create the conversation objects
conversations = []
for chat_id, sorted_convers in sorted_convers.items():
for group in grouped_partial_questions:
questions_answers = []
first_partial_conversation = None
for partial_conversation in sorted_convers:
first_partial_qa = None
for partial_question in sorted(group, key=lambda x: x.timestamp):
# Partial questions don't contain the answer, so we need to find the corresponding
selected_partial_qa = None
for partial_qa in valid_partial_qas:
if partial_question.message_id == partial_qa.partial_questions.message_id:
selected_partial_qa = partial_qa
break

# check if we have an answer, otherwise do not add it
if partial_conversation.question_answer.answer is not None:
first_partial_conversation = partial_conversation
partial_conversation.question_answer.question.message = parse_question_answer(
partial_conversation.question_answer.question.message
if selected_partial_qa.answer is not None:
# if we don't have a first question, set it
first_partial_qa = first_partial_qa or selected_partial_qa
question_answer = _get_question_answer_from_partial(selected_partial_qa)
question_answer.question.message = parse_question_answer(
question_answer.question.message
)
questions_answers.append(partial_conversation.question_answer)
questions_answers.append(question_answer)

# only add conversation if we have some answers
if len(questions_answers) > 0 and first_partial_conversation is not None:
if len(questions_answers) > 0 and first_partial_qa is not None:
conversations.append(
Conversation(
question_answers=questions_answers,
provider=first_partial_conversation.provider,
type=first_partial_conversation.type,
chat_id=chat_id,
conversation_timestamp=sorted_convers[0].request_timestamp,
provider=first_partial_qa.partial_questions.provider,
type=first_partial_qa.partial_questions.type,
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
)
)

Expand All @@ -254,10 +319,10 @@ async def parse_messages_in_conversations(

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
partial_conversations = [task.result() for task in tasks]
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
partial_question_answers = [task.result() for task in tasks]

conversations = await match_conversations(partial_conversations)
conversations = await match_conversations(partial_question_answers)
return conversations


Expand All @@ -269,15 +334,17 @@ async def parse_row_alert_conversation(
The row contains the raw request and output strings from the pipeline.
"""
question_answer, chat_id = await _get_question_answer(row)
if not question_answer or not chat_id:
partial_qa = await _get_question_answer(row)
if not partial_qa:
return None

question_answer = _get_question_answer_from_partial(partial_qa)

conversation = Conversation(
question_answers=[question_answer],
provider=row.provider,
type=row.type,
chat_id=chat_id or "chat-id-not-found",
chat_id=row.id,
conversation_timestamp=row.timestamp,
)
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
Expand Down Expand Up @@ -311,3 +378,13 @@ async def parse_get_alert_conversation(
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
return [task.result() for task in tasks if task.result() is not None]


if __name__ == "__main__":
from codegate.db.connection import DbReader

db_reader = DbReader()
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output())

parsed_messages = asyncio.run(parse_messages_in_conversations(prompts_outputs))
print(parsed_messages)
19 changes: 14 additions & 5 deletions src/codegate/dashboard/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,25 @@ class QuestionAnswer(BaseModel):
answer: Optional[ChatMessage]


class PartialConversation(BaseModel):
class PartialQuestions(BaseModel):
"""
Represents a partial conversation obtained from a DB row.
Represents all user messages obtained from a DB row.
"""

question_answer: QuestionAnswer
messages: List[str]
timestamp: datetime.datetime
message_id: str
provider: Optional[str]
type: str
chat_id: str
request_timestamp: datetime.datetime


class PartialQuestionAnswer(BaseModel):
"""
Represents a partial conversation.
"""

partial_questions: PartialQuestions
answer: Optional[ChatMessage]


class Conversation(BaseModel):
Expand Down
Loading

0 comments on commit 78c2f2c

Please sign in to comment.