Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

['digest'] issue and ['args'] issue fix #20

Merged
merged 4 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ai_engine_sdk/api_models/agents_json_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class TaskSelectionMessage(AgentJsonMessage):
text: str
options: Dict[str, TaskOption]


def get_options_keys(self) -> list[TaskOption]:
return [option for option in self.options]

Expand All @@ -50,7 +49,6 @@ class DataRequestMessage(AgentJsonMessage):
class ConfirmationMessage(AgentJsonMessage):
type: Literal[AgentJsonMessageTypes.CONFIRMATION] = AgentJsonMessageTypes.CONFIRMATION
text: str
model: str
payload: Dict[str, Any]


Expand All @@ -70,14 +68,16 @@ def is_agent_json_confirmation_message(message_type: str) -> bool:

def is_task_selection_message(message_type: str) -> bool:
union_of_type = TaskSelectionTypes
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
allowed_values = [literal for lit in get_args(
tanaygodse marked this conversation as resolved.
Show resolved Hide resolved
union_of_type) for literal in get_args(lit)]
return message_type.upper() in allowed_values


def is_data_request_message(message_type: str) -> bool:
union_of_type = DataRequestTypes
if get_origin(union_of_type) is Union:
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
tanaygodse marked this conversation as resolved.
Show resolved Hide resolved
allowed_values = [literal for lit in get_args(
union_of_type) for literal in get_args(lit)]
elif get_origin(union_of_type) is Literal:
allowed_values = get_args(union_of_type)

Expand Down
50 changes: 32 additions & 18 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ async def make_api_request(
}

async with aiohttp.ClientSession() as session:
logger.debug(f"\n\n 📤 Request triggered : {method} {api_base_url}{endpoint}")
logger.debug(f"\n\n 📤 Request triggered : {
tanaygodse marked this conversation as resolved.
Show resolved Hide resolved
method} {api_base_url}{endpoint}")
logger.debug(f"{body=}")
logger.debug("---------------------------\n\n")
async with session.request(method, f"{api_base_url}{endpoint}", headers=headers, data=body) as response:
if not bool(re.search(pattern="^2..$", string=str(response.status))):
raise Exception(f"Request failed with status {response.status} to {method}: {endpoint}")
raise Exception(f"Request failed with status {
response.status} to {method}: {endpoint}")
return await response.json()


Expand All @@ -113,6 +115,7 @@ class Session:
_messages (List[ApiBaseMessage]): A list to store messages associated with the session.
_message_ids (set[str]): A set to store unique message IDs to prevent duplication.
"""

def __init__(self, api_base_url: str, api_key: str, session_id: str, function_group: str):
"""
Initializes a new session with the given parameters.
Expand Down Expand Up @@ -254,12 +257,14 @@ async def get_messages(self) -> List[ApiBaseMessage]:

Each message type has a different purpose as the name indicates.
"""
queryParams = f"?last_message_id={self._messages[-1]['message_id']}" if self._messages else ""
queryParams = f"?last_message_id={
self._messages[-1]['message_id']}" if self._messages else ""
response = await make_api_request(
api_base_url=self._api_base_url,
api_key=self._api_key,
method='GET',
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}/new-messages{queryParams}"
endpoint=f"/v1beta1/engine/chat/sessions/{
self.session_id}/new-messages{queryParams}"
)

newMessages: List[ApiBaseMessage] = []
Expand All @@ -273,14 +278,15 @@ async def get_messages(self) -> List[ApiBaseMessage]:
agent_json: dict = message['agent_json']
agent_json_type: str = agent_json['type'].upper()
if is_task_selection_message(message_type=agent_json_type):
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(raw_api_response=message)
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(
tanaygodse marked this conversation as resolved.
Show resolved Hide resolved
raw_api_response=message)
newMessages.append(
TaskSelectionMessage.model_validate({
'type': agent_json_type,
'id': message['message_id'],
'timestamp': message['timestamp'],
'text': agent_json['text'],
'options':indexed_task_options
'options': indexed_task_options
})
)
elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']):
Expand All @@ -289,8 +295,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
'id': message['message_id'],
'timestamp': message['timestamp'],
'text': agent_json['text'],
'model': agent_json['context_json']['digest'],
'payload': agent_json['context_json']['args'],
'payload': agent_json['context_json'],
})
)
elif is_data_request_message(message_type=agent_json_type):
Expand Down Expand Up @@ -352,7 +357,7 @@ async def delete(self):
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}"
)

async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
async def execute_function(self, function_ids: list[str], objective: str, context: str | None = None):
await self._submit_message(
payload=ApiUserMessageExecuteFunctions.model_validate({
"functions": function_ids,
Expand All @@ -362,15 +367,17 @@ async def execute_function(self, function_ids: list[str], objective: str, contex
})
)


class AiEngine:
def __init__(self, api_key: str, options: Optional[dict] = None):
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_base_url = options.get(
'api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_key = api_key


####
# Function groups
####

async def get_function_groups(self) -> List[FunctionGroup]:
logger.debug("get_function_groups")
publicGroups, privateGroups = await asyncio.gather(
Expand Down Expand Up @@ -463,6 +470,7 @@ async def get_function_group_by_function(self, function_id: str):
###
# Functions
###

async def get_functions_by_function_group(self, function_group_id: str) -> list[FunctionGroupFunctions]:
raw_response: dict = await make_api_request(
api_base_url=self._api_base_url,
Expand All @@ -474,14 +482,14 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
if "functions" in raw_response:
list(
map(
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
lambda function_name: FunctionGroupFunctions.model_validate(
{"name": function_name}),
raw_response["functions"]
)
)

return result


async def get_functions(self) -> list[Function]:
raw_response: dict = await make_api_request(
api_base_url=self._api_base_url,
Expand All @@ -498,8 +506,10 @@ async def get_functions(self) -> list[Function]:
####
# Model
####

async def get_models(self) -> List[Model]:
pending_credits = [self.get_model_credits(model_id) for model_id in DefaultModelIds]
pending_credits = [self.get_model_credits(
model_id) for model_id in DefaultModelIds]

models = [Model(
id=model_id,
Expand Down Expand Up @@ -534,7 +544,8 @@ async def get_model_credits(self, model: Union[KnownModelId, CustomModel]) -> in
api_base_url=self._api_base_url,
api_key=self._api_key,
method='GET',
endpoint=f"/v1beta1/engine/credit/remaining_tokens?models={model_id}"
endpoint=f"/v1beta1/engine/credit/remaining_tokens?models={
model_id}"
)
return response['model_tokens'].get(model_id, 0)

Expand All @@ -546,7 +557,8 @@ async def create_session(self, function_group: str, opts: Optional[dict] = None)
email=opts.get('email') if opts else "",
functionGroup=function_group,
preferencesEnabled=False,
requestModel=opts.get('model') if opts and 'model' in opts else DefaultModelId
requestModel=opts.get(
'model') if opts and 'model' in opts else DefaultModelId
)
response = await make_api_request(
api_base_url=self._api_base_url,
Expand Down Expand Up @@ -575,8 +587,10 @@ async def share_function_group(
api_base_url=self._api_base_url,
api_key=self._api_key,
method='PUT',
endpoint=f"/v1beta1/function-groups/{function_group_id}/permissions/",
endpoint=f"/v1beta1/function-groups/{
function_group_id}/permissions/",
payload=payload
)
logger.debug(f"FG successfully shared: {function_group_id} with {target_user_email}")
logger.debug(f"FG successfully shared: {
function_group_id} with {target_user_email}")
return raw_response
Loading