Skip to content

Commit

Permalink
Added token usage and price
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Feb 22, 2023
1 parent 131f03e commit 40a6e8c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
23 changes: 17 additions & 6 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain.llms import OpenAI
from langchain.llms.base import LLM
from langchain.chains import LLMChain
from langchain.callbacks import get_openai_callback


@dataclass
Expand All @@ -22,6 +23,7 @@ class Answer:
references: str
formatted_answer: str
passages: Dict[str, str]
tokens: int

def __str__(self) -> str:
"""Return the answer as a string."""
Expand Down Expand Up @@ -164,17 +166,24 @@ def query(
):
if k < max_sources:
raise ValueError("k should be greater than max_sources")
context_str, citations = self.get_evidence(query, k=k, max_sources=max_sources)
tokens = 0
with get_openai_callback() as cb:
context_str, citations = self.get_evidence(
query, k=k, max_sources=max_sources
)
tokens += cb.total_tokens
bib = dict()
passages = dict()
if len(context_str) < 10:
answer = "I cannot answer this question due to insufficient information."
else:
answer = self.qa_chain.run(
question=query, context_str=context_str, length=length_prompt
)[1:]
if maybe_is_truncated(answer):
answer = self.edit_chain.run(question=query, answer=answer)
with get_openai_callback() as cb:
answer = self.qa_chain.run(
question=query, context_str=context_str, length=length_prompt
)[1:]
if maybe_is_truncated(answer):
answer = self.edit_chain.run(question=query, answer=answer)
tokens += cb.total_tokens
for key, citation, summary, text in citations:
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
skey = key.split(" ")[0]
Expand All @@ -187,11 +196,13 @@ def query(
formatted_answer = f"Question: {query}\n\n{answer}\n"
if len(bib) > 0:
formatted_answer += f"\nReferences\n\n{bib_str}\n"
formatted_answer += f"\nTokens Used: {tokens} Cost: ${tokens/1000 * 0.02:.2f}"
return Answer(
answer=answer,
question=query,
formatted_answer=formatted_answer,
context=context_str,
references=bib_str,
passages=passages,
tokens=tokens,
)
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.9"
__version__ = "0.0.10"

0 comments on commit 40a6e8c

Please sign in to comment.