Skip to content

Commit

Permalink
Bump ai package to the latest version / use generateText directly / F… (
Browse files Browse the repository at this point in the history
#70)

* Bump ai package to the latest version / use generateText directly / Fix maxTokens value

* Make properties readonly
  • Loading branch information
sakowicz authored Dec 1, 2024
1 parent 43d4306 commit b15c350
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 69 deletions.
14 changes: 7 additions & 7 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/actual-ai.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { ActualAiServiceI, ActualApiServiceI, TransactionServiceI } from './types';

class ActualAiService implements ActualAiServiceI {
private transactionService: TransactionServiceI;
private readonly transactionService: TransactionServiceI;

private actualApiService: ActualApiServiceI;
private readonly actualApiService: ActualApiServiceI;

constructor(
transactionService: TransactionServiceI,
Expand Down
12 changes: 6 additions & 6 deletions src/actual-api-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ class ActualApiService implements ActualApiServiceI {

private fs: typeof import('fs');

private dataDir: string;
private readonly dataDir: string;

private serverURL: string;
private readonly serverURL: string;

private password: string;
private readonly password: string;

private budgetId: string;
private readonly budgetId: string;

private e2ePassword: string;
private readonly e2ePassword: string;

constructor(
actualApiClient: typeof import('@actual-app/api'),
Expand Down Expand Up @@ -77,7 +77,7 @@ class ActualApiService implements ActualApiServiceI {
return this.actualApiClient.getCategoryGroups();
}

public async getCategories(): Promise<(APICategoryEntity|APICategoryGroupEntity)[]> {
public async getCategories(): Promise<(APICategoryEntity | APICategoryGroupEntity)[]> {
return this.actualApiClient.getCategories();
}

Expand Down
26 changes: 12 additions & 14 deletions src/container.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
import * as actualApiClient from '@actual-app/api';
import fs from 'fs';
import { generateText } from 'ai';
import ActualApiService from './actual-api-service';
import TransactionService from './transaction-service';
import LlmModelFactory from './llm-model-factory';
import {
llmProvider,
openaiApiKey,
openaiModel,
openaiBaseURL,
anthropicBaseURL,
anthropicApiKey,
anthropicBaseURL,
anthropicModel,
googleModel,
googleBaseURL,
budgetId,
dataDir,
e2ePassword,
googleApiKey,
ollamaModel,
googleBaseURL,
googleModel,
llmProvider,
ollamaBaseURL,
dataDir,
serverURL,
ollamaModel,
openaiApiKey,
openaiBaseURL,
openaiModel,
password,
budgetId,
e2ePassword,
serverURL,
syncAccountsBeforeClassify,
} from './config';
import ActualAiService from './actual-ai';
Expand Down Expand Up @@ -54,7 +53,6 @@ const actualApiService = new ActualApiService(
);

const llmService = new LlmService(
generateText,
llmModelFactory,
);

Expand Down
24 changes: 12 additions & 12 deletions src/llm-model-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ import { createOllama } from 'ollama-ai-provider';
import { LlmModelFactoryI } from './types';

class LlmModelFactory implements LlmModelFactoryI {
private llmProvider: string;
private readonly llmProvider: string;

private openaiApiKey: string;
private readonly openaiApiKey: string;

private openaiModel: string;
private readonly openaiModel: string;

private openaiBaseURL: string;
private readonly openaiBaseURL: string;

private anthropicBaseURL: string;
private readonly anthropicBaseURL: string;

private anthropicApiKey: string;
private readonly anthropicApiKey: string;

private anthropicModel: string;
private readonly anthropicModel: string;

private googleModel: string;
private readonly googleModel: string;

private googleBaseURL: string;
private readonly googleBaseURL: string;

private googleApiKey: string;
private readonly googleApiKey: string;

private ollamaModel: string;
private readonly ollamaModel: string;

private ollamaBaseURL: string;
private readonly ollamaBaseURL: string;

constructor(
llmProvider: string,
Expand Down
14 changes: 5 additions & 9 deletions src/llm-service.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
import { LanguageModel } from 'ai';
import { GenerateTextFunction, LlmModelFactoryI, LlmServiceI } from './types';
import { generateText, LanguageModel } from 'ai';
import { LlmModelFactoryI, LlmServiceI } from './types';

export default class LlmService implements LlmServiceI {
private generateText: GenerateTextFunction;

private model: LanguageModel;
private readonly model: LanguageModel;

constructor(
generateText: GenerateTextFunction,
llmModelFactory: LlmModelFactoryI,
) {
this.generateText = generateText;
this.model = llmModelFactory.create();
}

public async ask(prompt: string): Promise<string> {
const { text } = await this.generateText({
const { text } = await generateText({
model: this.model,
prompt,
temperature: 0.1,
max_tokens: 50,
maxTokens: 35,
});

return text.replace(/(\r\n|\n|\r|"|')/gm, '');
Expand Down
20 changes: 10 additions & 10 deletions src/transaction-service.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import suppressConsoleLogsAsync from './utils';
import {
TransactionServiceI, PromptGeneratorI, ActualApiServiceI, LlmServiceI,
ActualApiServiceI, LlmServiceI, PromptGeneratorI, TransactionServiceI,
} from './types';

const NOTES_NOT_GUESSED = 'actual-ai could not guess this category';
const NOTES_GUESSED = 'actual-ai guessed this category';

class TransactionService implements TransactionServiceI {
private actualAiService: ActualApiServiceI;
private readonly actualAiService: ActualApiServiceI;

private llmService: LlmServiceI;
private readonly llmService: LlmServiceI;

private promptGenerator: PromptGeneratorI;
private readonly promptGenerator: PromptGeneratorI;

private syncAccountsBeforeClassify: boolean;
private readonly syncAccountsBeforeClassify: boolean;

constructor(
actualApiClient: ActualApiServiceI,
Expand Down Expand Up @@ -54,11 +54,11 @@ class TransactionService implements TransactionServiceI {
const transactions = await this.actualAiService.getTransactions();
const uncategorizedTransactions = transactions.filter(
(transaction) => !transaction.category
&& (transaction.transfer_id === null || transaction.transfer_id === undefined)
&& transaction.starting_balance_flag !== true
&& transaction.imported_payee !== null
&& transaction.imported_payee !== ''
&& (transaction.notes === null || (!transaction.notes?.includes(NOTES_NOT_GUESSED))),
&& (transaction.transfer_id === null || transaction.transfer_id === undefined)
&& transaction.starting_balance_flag !== true
&& transaction.imported_payee !== null
&& transaction.imported_payee !== ''
&& (transaction.notes === null || (!transaction.notes?.includes(NOTES_NOT_GUESSED))),
);

for (let i = 0; i < uncategorizedTransactions.length; i++) {
Expand Down
18 changes: 9 additions & 9 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CoreTool, GenerateTextResult, LanguageModel } from 'ai';
import { LanguageModel } from 'ai';
import {
APICategoryEntity,
APICategoryGroupEntity,
Expand All @@ -12,17 +12,25 @@ export interface LlmModelFactoryI {

export interface ActualApiServiceI {
initializeApi(): Promise<void>;

shutdownApi(): Promise<void>;

getCategoryGroups(): Promise<APICategoryGroupEntity[]>

getCategories(): Promise<(APICategoryEntity | APICategoryGroupEntity)[]>

getPayees(): Promise<APIPayeeEntity[]>

getTransactions(): Promise<TransactionEntity[]>

updateTransactionNotes(id: string, notes: string): Promise<void>

updateTransactionNotesAndCategory(
id: string,
notes: string,
categoryId: string,
): Promise<void>

runBankSync(): Promise<void>
}

Expand All @@ -45,11 +53,3 @@ export interface PromptGeneratorI {
payees: APIPayeeEntity[],
): string
}

// eslint-disable-next-line no-unused-vars
export type GenerateTextFunction = (options: {
model: LanguageModel;
prompt?: string;
temperature?: number;
max_tokens?: number;
}) => Promise<GenerateTextResult<Record<string, CoreTool>>>;

0 comments on commit b15c350

Please sign in to comment.