Skip to content

Commit

Permalink
added fk from file to chat (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
gecBurton authored Dec 12, 2024
1 parent 298d869 commit b3716d6
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 57 deletions.
19 changes: 19 additions & 0 deletions django_app/redbox_app/redbox_core/migrations/0070_file_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Generated by Django 5.1.4 on 2024-12-12 14:18

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('redbox_core', '0069_alter_citation_source'),
]

operations = [
migrations.AddField(
model_name='file',
name='chat',
field=models.ForeignKey(blank=True, help_text='chat that this document belongs to, which may be nothing for now', null=True, on_delete=django.db.models.deletion.CASCADE, to='redbox_core.chat'),
),
]
115 changes: 61 additions & 54 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,60 @@ def get_initials(self) -> str:
return ""


class Chat(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
name = models.TextField(max_length=1024, null=False, blank=False)
user = models.ForeignKey(User, on_delete=models.CASCADE)
archived = models.BooleanField(default=False, null=True, blank=True)

# Exit feedback - this is separate to the ratings for individual ChatMessages
feedback_achieved = models.BooleanField(
null=True,
blank=True,
help_text="Did Redbox do what you needed it to in this chat?",
)
feedback_saved_time = models.BooleanField(null=True, blank=True, help_text="Did Redbox help save you time?")
feedback_improved_work = models.BooleanField(
null=True, blank=True, help_text="Did Redbox help to improve your work?"
)
feedback_notes = models.TextField(null=True, blank=True, help_text="Do you want to tell us anything further?")

def __str__(self) -> str: # pragma: no cover
return self.name or ""

@override
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
self.name = sanitise_string(self.name)

if self.chat_backend_id is None:
self.chat_backend = self.user.ai_settings.chat_backend

if self.temperature is None:
self.temperature = self.user.ai_settings.temperature

super().save(force_insert, force_update, using, update_fields)

@classmethod
def get_ordered_by_last_message_date(
cls, user: User, exclude_chat_ids: Collection[uuid.UUID] | None = None
) -> Sequence["Chat"]:
"""Returns all chat histories for a given user, ordered by the date of the latest message."""
exclude_chat_ids = exclude_chat_ids or []
return (
cls.objects.filter(user=user, archived=False)
.exclude(id__in=exclude_chat_ids)
.annotate(latest_message_date=Max("chatmessage__created_at"))
.order_by("-latest_message_date")
)

@property
def newest_message_date(self) -> date:
return self.chatmessage_set.aggregate(newest_date=Max("created_at"))["newest_date"].date()

@property
def date_group(self):
return get_date_group(self.newest_message_date)


class InactiveFileError(ValueError):
def __init__(self, file):
super().__init__(f"{file.pk} is inactive, status is {file.status}")
Expand Down Expand Up @@ -574,6 +628,13 @@ class Status(models.TextChoices):
null=True,
help_text="error, if any, encountered during ingest",
)
chat = models.ForeignKey(
Chat,
on_delete=models.CASCADE,
null=True,
blank=True,
help_text="chat that this document belongs to, which may be nothing for now",
)

def __str__(self) -> str: # pragma: no cover
return self.file_name
Expand Down Expand Up @@ -685,60 +746,6 @@ def get_ordered_by_citation_priority(cls, chat_message_id: uuid.UUID) -> Sequenc
)


class Chat(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
name = models.TextField(max_length=1024, null=False, blank=False)
user = models.ForeignKey(User, on_delete=models.CASCADE)
archived = models.BooleanField(default=False, null=True, blank=True)

# Exit feedback - this is separate to the ratings for individual ChatMessages
feedback_achieved = models.BooleanField(
null=True,
blank=True,
help_text="Did Redbox do what you needed it to in this chat?",
)
feedback_saved_time = models.BooleanField(null=True, blank=True, help_text="Did Redbox help save you time?")
feedback_improved_work = models.BooleanField(
null=True, blank=True, help_text="Did Redbox help to improve your work?"
)
feedback_notes = models.TextField(null=True, blank=True, help_text="Do you want to tell us anything further?")

def __str__(self) -> str: # pragma: no cover
return self.name or ""

@override
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
self.name = sanitise_string(self.name)

if self.chat_backend_id is None:
self.chat_backend = self.user.ai_settings.chat_backend

if self.temperature is None:
self.temperature = self.user.ai_settings.temperature

super().save(force_insert, force_update, using, update_fields)

@classmethod
def get_ordered_by_last_message_date(
cls, user: User, exclude_chat_ids: Collection[uuid.UUID] | None = None
) -> Sequence["Chat"]:
"""Returns all chat histories for a given user, ordered by the date of the latest message."""
exclude_chat_ids = exclude_chat_ids or []
return (
cls.objects.filter(user=user, archived=False)
.exclude(id__in=exclude_chat_ids)
.annotate(latest_message_date=Max("chatmessage__created_at"))
.order_by("-latest_message_date")
)

@property
def newest_message_date(self) -> date:
return self.chatmessage_set.aggregate(newest_date=Max("created_at"))["newest_date"].date()

@property
def date_group(self):
return get_date_group(self.newest_message_date)


class Citation(UUIDPrimaryKeyBase, TimeStampedModel):
class Origin(models.TextChoices):
WIKIPEDIA = "Wikipedia", _("wikipedia")
Expand Down
7 changes: 4 additions & 3 deletions django_app/redbox_app/redbox_core/views/document_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get(self, request: HttpRequest) -> HttpResponse:
return self.build_response(request)

@method_decorator(login_required)
def post(self, request: HttpRequest) -> HttpResponse:
def post(self, request: HttpRequest, chat_id: uuid.UUID | None = None) -> HttpResponse:
errors: MutableSequence[str] = []

uploaded_files: MutableSequence[UploadedFile] = request.FILES.getlist("uploadDocs")
Expand All @@ -88,7 +88,7 @@ def post(self, request: HttpRequest) -> HttpResponse:
if not errors:
for uploaded_file in uploaded_files:
# ingest errors are handled differently, as the other documents have started uploading by this point
request.session["ingest_errors"] = self.ingest_file(uploaded_file, request.user)
request.session["ingest_errors"] = self.ingest_file(uploaded_file, request.user, chat_id)
return redirect(reverse("documents"))

return self.build_response(request, errors)
Expand Down Expand Up @@ -125,13 +125,14 @@ def validate_uploaded_file(uploaded_file: UploadedFile) -> Sequence[str]:
return errors

@staticmethod
def ingest_file(uploaded_file: UploadedFile, user: User) -> Sequence[str]:
def ingest_file(uploaded_file: UploadedFile, user: User, chat_id: uuid.UUID | None = None) -> Sequence[str]:
try:
logger.info("getting file from s3")
file = File.objects.create(
status=File.Status.processing.value,
user=user,
original_file=uploaded_file,
chat_id=chat_id,
)
except (ValueError, FieldError, ValidationError) as e:
logger.exception("Error creating File model object for %s.", uploaded_file, exc_info=e)
Expand Down
1 change: 1 addition & 0 deletions django_app/redbox_app/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
file_urlpatterns = [
path("documents/", views.DocumentView.as_view(), name="documents"),
path("upload/", views.UploadView.as_view(), name="upload"),
path("upload/<uuid:chat_id>/", views.UploadView.as_view(), name="upload"),
path("remove-doc/<uuid:doc_id>", views.remove_doc_view, name="remove-doc"),
]

Expand Down
32 changes: 32 additions & 0 deletions django_app/tests/views/test_document_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,38 @@ def test_upload_view(alice, client, file_pdf_path: Path, s3_client):
assert response.status_code == HTTPStatus.FOUND
assert response.url == "/documents/"

file = File.objects.get(original_file=file_name.replace(" ", "_"))
assert file.chat is None


@pytest.mark.django_db()
def test_upload_view_with_chat(chat, client, file_pdf_path: Path, s3_client):
"""
Given that the object store does not have a file with our test file in it
When we POST our test file to /upload/
We Expect to see this file in the object store
"""
file_name = f"{chat.user.email}/{file_pdf_path.name}"

# we begin by removing any file in minio that has this key
s3_client.delete_object(Bucket=settings.BUCKET_NAME, Key=file_name.replace(" ", "_"))

assert not file_exists(s3_client, file_name)

client.force_login(chat.user)

with file_pdf_path.open("rb") as f:
url = reverse("upload", args=(chat.pk,))

response = client.post(url, {"uploadDocs": f})

assert file_exists(s3_client, file_name)
assert response.status_code == HTTPStatus.FOUND
assert response.url == "/documents/"

file = File.objects.get(original_file=file_name.replace(" ", "_"))
assert file.chat == chat


@pytest.mark.django_db()
def test_document_upload_status(client, alice, file_pdf_path: Path, s3_client):
Expand Down

0 comments on commit b3716d6

Please sign in to comment.