diff --git a/apiclient_pydantic_generator/parser.py b/apiclient_pydantic_generator/parser.py index efcbb9a..55169b9 100644 --- a/apiclient_pydantic_generator/parser.py +++ b/apiclient_pydantic_generator/parser.py @@ -27,14 +27,13 @@ from datamodel_code_generator.types import DataType, DataTypeManager, StrictTypes from pydantic import BaseModel, validator - RE_APPLICATION_JSON_PATTERN: Pattern[str] = re.compile(r'^application/.*json$') class CachedPropertyModel(BaseModel): class Config: arbitrary_types_allowed = True - keep_untouched = (cached_property, ) + keep_untouched = (cached_property,) class Response(BaseModel): @@ -359,8 +358,8 @@ def parse_request_body( super().parse_request_body(name, request_body, path) arguments: List[Argument] = [] for ( - media_type, - media_obj, + media_type, + media_obj, ) in request_body.content.items(): if isinstance(media_obj.schema_, (JsonSchemaObject, ReferenceObject)): # pragma: no cover if RE_APPLICATION_JSON_PATTERN.match(media_type): @@ -384,18 +383,23 @@ def parse_responses( path: List[str], ) -> Dict[str, Dict[str, DataType]]: data_types = super().parse_responses(name, responses, path) - status_code_200 = data_types.get('200') - if status_code_200: - data_type = list(status_code_200.values())[0] - if data_type: - self.data_types.append(data_type) - type_hint = data_type.type_hint # TODO: change to lazy loading - else: - type_hint = 'None' + type_hint = 'None' + for code in [200, 201, 202]: + response_model = self._get_response(code, data_types) + if response_model: + type_hint = response_model.type_hint # TODO: change to lazy loading + break self._temporary_operation['response'] = type_hint - return data_types + def _get_response(self, status_code: int, data_types: dict[str, dict[str, DataType]]) -> DataType: + response_model = data_types.get(f"{status_code}") + if response_model: + data_type = list(response_model.values())[0] + if data_type: + self.data_types.append(data_type) + return data_type + def parse_operation( self, raw_operation: Dict[str, Any], diff --git a/apiclient_pydantic_generator/templates/client.jinja2 b/apiclient_pydantic_generator/templates/client.jinja2 index c120846..17f40e3 100644 --- a/apiclient_pydantic_generator/templates/client.jinja2 +++ b/apiclient_pydantic_generator/templates/client.jinja2 @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Any from {% if base_cls.from_ %}{{base_cls.from_}}{% else %}.{% endif %} import {{base_cls.import_}}{% if base_cls.alias %} as {{base_cls.alias}}{% endif %} from apiclient_pydantic import serialize_all_methods diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ddb310f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture +def resources_folder(): + return Path(__file__).parent / "resources" diff --git a/tests/resources/openapi.json b/tests/resources/openapi.json new file mode 100644 index 0000000..df810d8 --- /dev/null +++ b/tests/resources/openapi.json @@ -0,0 +1,362 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "Audit API", + "description": "Dummy", + "version": "1.0.0" + }, + "paths": { + "/health": { + "get": { + "summary": "Health", + "operationId": "health", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + } + }, + "/version": { + "get": { + "summary": "Version", + "operationId": "get_version", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + } + }, + "/": { + "get": { + "summary": "Home", + "description": "Home", + "operationId": "home", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + } + } + } + }, + "/v1/query/{query_id}": { + "get": { + "tags": [ + "audit" + ], + "summary": "Get Query", + "operationId": "get_query", + "parameters": [ + { + "name": "query_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Query Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryOut" + } + } + } + }, + "404": { + "description": "Not found" + }, + "400": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenericResponse" + } + } + }, + "description": "Bad Request" + }, + "500": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenericResponse" + } + } + }, + "description": "Internal Server Error" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/query": { + "post": { + "tags": [ + "audit" + ], + "summary": "Post Query", + "operationId": "post_query", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryRequest" + }, + "example": { + "Obj": "C", + "Id": "ae8d0a40-203f-4135-b7df-9b4fab56a1ea", + "FromTimestamp": "2023-11-01T12:00:00", + "ToTimestamp": "2023-11-01T15:00:00" + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryRequestOut" + } + } + } + }, + "404": { + "description": "Not found" + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenericResponse" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenericResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "GenericResponse": { + "properties": { + "RequestId": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Requestid" + }, + "Message": { + "type": "string", + "title": "Message" + } + }, + "type": "object", + "required": [ + "RequestId", + "Message" + ], + "title": "GenericResponse" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "Object": { + "type": "string", + "enum": [ + "A", + "B", + "C" + ], + "title": "Object" + }, + "QueryOut": { + "properties": { + "Status": { + "$ref": "#/components/schemas/Status" + }, + "Url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Url" + } + }, + "type": "object", + "required": [ + "Status", + "Url" + ], + "title": "QueryOut" + }, + "QueryRequest": { + "properties": { + "Obj": { + "$ref": "#/components/schemas/Object" + }, + "Id": { + "anyOf": [ + { + "type": "string", + "format": "uuid" + }, + { + "type": "null" + } + ], + "title": "Id" + }, + "FromTimestamp": { + "type": "string", + "format": "date-time", + "title": "Fromtimestamp" + }, + "ToTimestamp": { + "type": "string", + "format": "date-time", + "title": "Totimestamp" + } + }, + "type": "object", + "required": [ + "Obj", + "FromTimestamp", + "ToTimestamp" + ], + "title": "QueryRequest" + }, + "QueryRequestOut": { + "properties": { + "QueryId": { + "type": "string", + "title": "Queryid" + } + }, + "type": "object", + "required": [ + "QueryId" + ], + "title": "QueryRequestOut" + }, + "Status": { + "type": "string", + "enum": [ + "QUEUED", + "RUNNING", + "SUCCEEDED", + "FAILED", + "CANCELLED" + ], + "title": "Status" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError" + } + } + } +} diff --git a/tests/test_format.py b/tests/test_format.py index 160b2f3..d8e4175 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -1,9 +1,6 @@ from apiclient_pydantic_generator.format import YapfCodeFormatter -from ward import test - -@test('test simple format') -def _(): +def test_reformat(): expect_code = """\ from enum import Enum from pathlib import Path diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..1ed4a7e --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,13 @@ +from apiclient_pydantic_generator.parser import OpenAPIParser + + +def test_parser(resources_folder): + with (resources_folder / "openapi.json").open('r') as f: + input_text = "".join(f.readlines()) + parser = OpenAPIParser(input_text) + parser.parse() + + operations = parser.operations + op = operations.get("#/paths/v1/query/post") + assert op.response == 'QueryRequestOut' +