From 4ac18c2bfc8396fe7c11c7c0b47b200e24469d1d Mon Sep 17 00:00:00 2001 From: vincbeck Date: Tue, 7 Jan 2025 14:19:25 -0500 Subject: [PATCH] Do not use core Airflow Flask related resources in FAB provider (tests of `www`) --- .../auth_manager/security_manager/override.py | 2 +- .../src/airflow/providers/fab/www/app.py | 46 +- .../src/airflow/providers/fab/www/auth.py | 125 +++++ .../fab/www/extensions/init_appbuilder.py | 128 +++++- .../fab/www/extensions/init_views.py | 108 ++++- .../airflow/no_roles_permissions.html | 42 ++ .../src/airflow/providers/fab/www/views.py | 79 ++++ .../api/auth/backend/test_basic_auth.py | 2 +- .../api/auth/backend/test_session.py | 2 +- .../api_endpoints/test_asset_endpoint.py | 325 ------------- .../auth_manager/api_endpoints/test_auth.py | 44 +- .../auth_manager/api_endpoints/test_cors.py | 6 +- .../api_endpoints/test_dag_endpoint.py | 226 ---------- .../api_endpoints/test_dag_run_endpoint.py | 267 ----------- .../api_endpoints/test_dag_source_endpoint.py | 132 ------ .../test_dag_warning_endpoint.py | 84 ---- .../api_endpoints/test_event_log_endpoint.py | 151 ------- .../test_import_error_endpoint.py | 222 --------- .../test_task_instance_endpoint.py | 426 ------------------ .../api_endpoints/test_variable_endpoint.py | 88 ---- .../api_endpoints/test_xcom_endpoint.py | 268 ----------- providers/tests/fab/auth_manager/conftest.py | 2 +- .../fab/auth_manager/test_fab_auth_manager.py | 2 +- .../tests/fab/auth_manager/test_security.py | 8 +- .../auth_manager/views/test_permissions.py | 2 +- .../fab/auth_manager/views/test_roles_list.py | 2 +- .../tests/fab/auth_manager/views/test_user.py | 2 +- .../fab/auth_manager/views/test_user_edit.py | 2 +- .../fab/auth_manager/views/test_user_stats.py | 2 +- 29 files changed, 521 insertions(+), 2274 deletions(-) create mode 100644 providers/src/airflow/providers/fab/www/auth.py create mode 100644 providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py delete mode 100644 providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py diff --git a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py index d0e00b0977ce8..5c8b58c2139d6 100644 --- a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -110,9 +110,9 @@ from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 from airflow.providers.fab.www.session import ( + AirflowDatabaseSessionInterface, AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface, ) -from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: from airflow.providers.fab.www.security.permissions import RESOURCE_ASSET diff --git a/providers/src/airflow/providers/fab/www/app.py b/providers/src/airflow/providers/fab/www/app.py index 0414fc5e408b5..ee96e4a218ec3 100644 --- a/providers/src/airflow/providers/fab/www/app.py +++ b/providers/src/airflow/providers/fab/www/app.py @@ -17,22 +17,26 @@ # under the License. from __future__ import annotations -from os.path import isabs - from flask import Flask from flask_appbuilder import SQLA from flask_wtf.csrf import CSRFProtect -from sqlalchemy.engine.url import make_url from airflow import settings from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException from airflow.logging_config import configure_logging from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder from airflow.providers.fab.www.extensions.init_jinja_globals import init_jinja_globals from airflow.providers.fab.www.extensions.init_manifest_files import configure_manifest_files from airflow.providers.fab.www.extensions.init_security import init_api_auth, init_xframe_protection -from airflow.providers.fab.www.extensions.init_views import init_error_handlers, init_plugins +from airflow.providers.fab.www.extensions.init_views import ( + init_api_auth_provider, + init_api_connexion, + init_api_error_handlers, + init_appbuilder_views, + init_error_handlers, + init_plugins, +) +from airflow.utils.json import AirflowJsonProvider app: Flask | None = None @@ -41,44 +45,56 @@ csrf = CSRFProtect() -def create_app(): +def create_app(config=None, testing=False): """Create a new instance of Airflow WWW app.""" flask_app = Flask(__name__) flask_app.secret_key = conf.get("webserver", "SECRET_KEY") + webserver_config = conf.get_mandatory_value("webserver", "config_file") + # Enable customizations in webserver_config.py to be applied via Flask.current_app. + with flask_app.app_context(): + flask_app.config.from_pyfile(webserver_config, silent=True) + + flask_app.config["TESTING"] = testing flask_app.config["SQLALCHEMY_DATABASE_URI"] = conf.get("database", "SQL_ALCHEMY_CONN") - url = make_url(flask_app.config["SQLALCHEMY_DATABASE_URI"]) - if url.drivername == "sqlite" and url.database and not isabs(url.database): - raise AirflowConfigException( - f'Cannot use relative path: `{conf.get("database", "SQL_ALCHEMY_CONN")}` to connect to sqlite. ' - "Please use absolute path such as `sqlite:////tmp/airflow.db`." - ) + if config: + flask_app.config.from_mapping(config) if "SQLALCHEMY_ENGINE_OPTIONS" not in flask_app.config: flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args() + # Configure the JSON encoder used by `|tojson` filter from Flask + flask_app.json_provider_class = AirflowJsonProvider + flask_app.json = AirflowJsonProvider(flask_app) + + csrf.init_app(flask_app) + db = SQLA() db.session = settings.Session db.init_app(flask_app) + init_api_auth(flask_app) configure_logging() configure_manifest_files(flask_app) - init_api_auth(flask_app) with flask_app.app_context(): init_appbuilder(flask_app) + init_appbuilder_views(flask_app) init_plugins(flask_app) + init_api_auth_provider(flask_app) init_error_handlers(flask_app) + init_api_connexion(flask_app) + init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first init_jinja_globals(flask_app) init_xframe_protection(flask_app) return flask_app -def cached_app(): +def cached_app(config=None, testing=False): """Return cached instance of Airflow WWW app.""" global app if not app: - app = create_app() + app = create_app(config=config, testing=testing) return app diff --git a/providers/src/airflow/providers/fab/www/auth.py b/providers/src/airflow/providers/fab/www/auth.py new file mode 100644 index 0000000000000..198acb29f9a69 --- /dev/null +++ b/providers/src/airflow/providers/fab/www/auth.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from functools import wraps +from typing import TYPE_CHECKING, Callable, TypeVar, cast + +from flask import flash, redirect, render_template, request, url_for + +from airflow.api_fastapi.app import get_auth_manager +from airflow.auth.managers.models.resource_details import ( + AccessView, + DagAccessEntity, + DagDetails, +) +from airflow.configuration import conf +from airflow.utils.net import get_hostname + +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod + +T = TypeVar("T", bound=Callable) + +log = logging.getLogger(__name__) + + +def get_access_denied_message(): + return conf.get("webserver", "access_denied_message") + + +def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): + """ + Define the behavior whether the user is authorized to access the resource. + + :param is_authorized: whether the user is authorized to access the resource + :param func: the function to call if the user is authorized + :param args: the arguments of ``func`` + :param kwargs: the keyword arguments ``func`` + + :meta private: + """ + if is_authorized: + return func(*args, **kwargs) + elif get_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( + access_view=AccessView.WEBSITE + ): + return ( + render_template( + "airflow/no_roles_permissions.html", + hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "", + logout_url=get_auth_manager().get_url_logout(), + ), + 403, + ) + elif not get_auth_manager().is_logged_in(): + return redirect(get_auth_manager().get_url_login(next_url=request.url)) + else: + access_denied = get_access_denied_message() + flash(access_denied, "danger") + return redirect(url_for("Airflow.index")) + + +def has_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable[[T], T]: + def has_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + dag_id_kwargs = kwargs.get("dag_id") + dag_id_args = request.args.get("dag_id") + dag_id_form = request.form.get("dag_id") + dag_id_json = request.json.get("dag_id") if request.is_json else None + all_dag_ids = [dag_id_kwargs, dag_id_args, dag_id_form, dag_id_json] + unique_dag_ids = set(dag_id for dag_id in all_dag_ids if dag_id is not None) + + if len(unique_dag_ids) > 1: + log.warning( + "There are different dag_ids passed in the request: %s. Returning 403.", unique_dag_ids + ) + log.warning( + "kwargs: %s, args: %s, form: %s, json: %s", + dag_id_kwargs, + dag_id_args, + dag_id_form, + dag_id_json, + ) + return ( + render_template( + "airflow/no_roles_permissions.html", + hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "", + logout_url=get_auth_manager().get_url_logout(), + ), + 403, + ) + dag_id = unique_dag_ids.pop() if unique_dag_ids else None + + is_authorized = get_auth_manager().is_authorized_dag( + method=method, + access_entity=access_entity, + details=None if not dag_id else DagDetails(id=dag_id), + ) + + return _has_access( + is_authorized=is_authorized, + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return has_access_decorator diff --git a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py index 9cf353490c3ac..d353b137c59a9 100644 --- a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -36,10 +36,10 @@ ) from flask_appbuilder.filters import TemplateFilters from flask_appbuilder.menu import Menu -from flask_appbuilder.views import IndexView +from flask_appbuilder.views import IndexView, UtilView from airflow import settings -from airflow.api_fastapi.app import create_auth_manager +from airflow.api_fastapi.app import create_auth_manager, get_auth_manager from airflow.configuration import conf from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 @@ -76,7 +76,30 @@ def dynamic_class_import(class_path): class AirflowAppBuilder: - """This is the base class for all the framework.""" + """ + This is the base class for all the framework. + + This is where you will register all your views + and create the menu structure. + Will hold your flask app object, all your views, and security classes. + Initialize your application like this for SQLAlchemy:: + from flask import Flask + from flask_appbuilder import SQLA, AppBuilder + app = Flask(__name__) + app.config.from_object('config') + db = SQLA(app) + appbuilder = AppBuilder(app, db.session) + When using MongoEngine:: + from flask import Flask + from flask_appbuilder import AppBuilder + from flask_appbuilder.security.mongoengine.manager import SecurityManager + from flask_mongoengine import MongoEngine + app = Flask(__name__) + app.config.from_object('config') + dbmongo = MongoEngine(app) + appbuilder = AppBuilder(app) + You can also create everything as an application factory. + """ baseviews: list[BaseView | Session] = [] # Flask app @@ -158,6 +181,7 @@ def init_app(self, app, session): app.config.setdefault("LANGUAGES", {"en": {"flag": "gb", "name": "English"}}) app.config.setdefault("ADDON_MANAGERS", []) app.config.setdefault("RATELIMIT_ENABLED", self.auth_rate_limited) + app.config.setdefault("FAB_API_MAX_PAGE_SIZE", 100) app.config.setdefault("FAB_BASE_TEMPLATE", self.base_template) app.config.setdefault("FAB_STATIC_FOLDER", self.static_folder) app.config.setdefault("FAB_STATIC_URL_PATH", self.static_url_path) @@ -180,6 +204,9 @@ def init_app(self, app, session): else: self.menu = self.menu or Menu() + if self.update_perms: # default is True, if False takes precedence from config + self.update_perms = app.config.get("FAB_UPDATE_PERMS", True) + self._addon_managers = app.config["ADDON_MANAGERS"] self.session = session auth_manager = create_auth_manager() @@ -195,7 +222,12 @@ def init_app(self, app, session): app.before_request(self.sm.before_request) self._add_admin_views() self._add_addon_views() + if self.app: + self._add_menu_permissions() + else: + self.post_init() self._init_extension(app) + self._swap_url_filter() def _init_extension(self, app): app.appbuilder = self @@ -203,6 +235,24 @@ def _init_extension(self, app): app.extensions = {} app.extensions["appbuilder"] = self + def _swap_url_filter(self): + """Use our url filtering util function so there is consistency between FAB and Airflow routes.""" + from flask_appbuilder.security import views as fab_sec_views + + from airflow.www.views import get_safe_url + + fab_sec_views.get_safe_redirect = get_safe_url + + def post_init(self): + for baseview in self.baseviews: + # instantiate the views and add session + self._check_and_init(baseview) + # Register the views has blueprints + if baseview.__class__.__name__ not in self.get_app.blueprints.keys(): + self.register_blueprint(baseview) + # Add missing permissions where needed + self.add_permissions() + @property def get_app(self): """ @@ -233,6 +283,20 @@ def app_name(self): """ return self.get_app.config["APP_NAME"] + @property + def require_confirmation_dag_change(self): + """ + Get the value of the require_confirmation_dag_change configuration. + + The logic is: + - return True, in page dag.html, when user trigger/pause the dag from UI. + Once confirmation box will be shown before triggering the dag. + - Default value is False. + + :return: Boolean + """ + return self.get_app.config["REQUIRE_CONFIRMATION_DAG_CHANGE"] + @property def app_theme(self): """ @@ -282,6 +346,10 @@ def _add_admin_views(self): """Register indexview, utilview (back function), babel views and Security views.""" self.indexview = self._check_and_init(self.indexview) self.add_view_no_menu(self.indexview) + self.add_view_no_menu(UtilView()) + self.bm.register_views() + + get_auth_manager().register_views() def _add_addon_views(self): """Register declared addons.""" @@ -328,9 +396,9 @@ def add_view( :param name: The string name that identifies the menu. :param href: - Override the generated link for the menu. + Override the generated href for the menu. You can use an url string or an endpoint name - if non provided default_view from view will be set as link. + if non provided default_view from view will be set as href. :param icon: Font-Awesome icon name, optional. :param label: @@ -423,7 +491,7 @@ def add_link( :param name: The string name that identifies the menu. :param href: - Override the generated link for the menu. + Override the generated href for the menu. You can use an url string or an endpoint name :param icon: Font-Awesome icon name, optional. @@ -498,9 +566,49 @@ def add_view_no_menu(self, baseview, endpoint=None, static_folder=None): log.warning(LOGMSG_WAR_FAB_VIEW_EXISTS, baseview.__class__.__name__) return baseview + def security_cleanup(self): + """ + Clean up security. + + This method is useful if you have changed the name of your menus or + classes. Changing them leaves behind permissions that are not associated + with anything. You can use it always or just sometimes to perform a + security cleanup. + + .. warning:: + + This deletes any permission that is no longer part of any registered + view or menu. Only invoke AFTER YOU HAVE REGISTERED ALL VIEWS. + """ + if not hasattr(self.sm, "security_cleanup"): + raise NotImplementedError("The auth manager used does not support security_cleanup method.") + self.sm.security_cleanup(self.baseviews, self.menu) + + def security_converge(self, dry=False) -> dict: + """ + Migrates all permissions to the new names on all the Roles. + + This method is useful when you use: + + - ``class_permission_name`` + - ``previous_class_permission_name`` + - ``method_permission_name`` + - ``previous_method_permission_name`` + + :param dry: If True will not change DB + :return: Dict with all computed necessary operations + """ + return self.sm.security_converge(self.baseviews, self.menu, dry) + + def get_url_for_login_with(self, next_url: str | None = None) -> str: + return get_auth_manager().get_url_login(next_url=next_url) + + @property + def get_url_for_login(self): + return get_auth_manager().get_url_login() + @property def get_url_for_index(self): - # TODO: Return the fast api application homepage return url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}") def get_url_for_locale(self, lang): @@ -513,6 +621,12 @@ def add_limits(self, baseview) -> None: if hasattr(baseview, "limits"): self.sm.add_limit_view(baseview) + def add_permissions(self, update_perms=False): + if self.update_perms or update_perms: + for baseview in self.baseviews: + self._add_permission(baseview, update_perms=update_perms) + self._add_menu_permissions(update_perms=update_perms) + def _add_permission(self, baseview, update_perms=False): if self.update_perms or update_perms: try: diff --git a/providers/src/airflow/providers/fab/www/extensions/init_views.py b/providers/src/airflow/providers/fab/www/extensions/init_views.py index 382bcaf9ca748..301b00b5eadaa 100644 --- a/providers/src/airflow/providers/fab/www/extensions/init_views.py +++ b/providers/src/airflow/providers/fab/www/extensions/init_views.py @@ -18,17 +18,58 @@ import logging from functools import cached_property +from pathlib import Path from typing import TYPE_CHECKING -from connexion import Resolver +from connexion import FlaskApi, Resolver from connexion.decorators.validation import RequestBodyValidator -from connexion.exceptions import BadRequestProblem +from connexion.exceptions import BadRequestProblem, ProblemException +from flask import request + +from airflow.api_connexion.exceptions import common_error_handler +from airflow.api_fastapi.app import get_auth_manager +from airflow.configuration import conf +from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED +from airflow.utils.yaml import safe_load if TYPE_CHECKING: from flask import Flask log = logging.getLogger(__name__) +# providers/src/airflow/providers/fab/www/extensions/init_views.py => airflow/ +ROOT_APP_DIR = Path(__file__).parents[7].joinpath("airflow").resolve() + + +def init_appbuilder_views(app): + """Initialize Web UI views.""" + from airflow.www import views + + appbuilder = app.appbuilder + + appbuilder.session.remove() + appbuilder.add_view_no_menu(views.AutocompleteView()) + appbuilder.add_view_no_menu(views.Airflow()) + + +def set_cors_headers_on_response(response): + """Add response headers.""" + allow_headers = conf.get("api", "access_control_allow_headers") + allow_methods = conf.get("api", "access_control_allow_methods") + allow_origins = conf.get("api", "access_control_allow_origins") + if allow_headers: + response.headers["Access-Control-Allow-Headers"] = allow_headers + if allow_methods: + response.headers["Access-Control-Allow-Methods"] = allow_methods + if allow_origins == "*": + response.headers["Access-Control-Allow-Origin"] = "*" + elif allow_origins: + allowed_origins = allow_origins.split(" ") + origin = request.environ.get("HTTP_ORIGIN", allowed_origins[0]) + if origin in allowed_origins: + response.headers["Access-Control-Allow-Origin"] = origin + return response + class _LazyResolution: """ @@ -78,6 +119,59 @@ def validate_schema(self, data, url): return super().validate_schema(data, url) +base_paths: list[str] = [] # contains the list of base paths that have api endpoints + + +def init_api_error_handlers(app: Flask) -> None: + """Add error handlers for 404 and 405 errors for existing API paths.""" + + @app.errorhandler(404) + def _handle_api_not_found(ex): + if any([request.path.startswith(p) for p in base_paths]): + # 404 errors are never handled on the blueprint level + # unless raised from a view func so actual 404 errors, + # i.e. "no route for it" defined, need to be handled + # here on the application level + return common_error_handler(ex) + else: + from airflow.providers.fab.www.views import not_found + + return not_found(ex) + + @app.errorhandler(405) + def _handle_method_not_allowed(ex): + if any([request.path.startswith(p) for p in base_paths]): + return common_error_handler(ex) + else: + from airflow.providers.fab.www.views import method_not_allowed + + return method_not_allowed(ex) + + app.register_error_handler(ProblemException, common_error_handler) + + +def init_api_connexion(app: Flask) -> None: + """Initialize Stable API.""" + base_path = "/api/v1" + base_paths.append(base_path) + + with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f: + specification = safe_load(f) + api_bp = FlaskApi( + specification=specification, + resolver=_LazyResolver(), + base_path=base_path, + options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + strict_validation=True, + validate_responses=True, + validator_map={"body": _CustomErrorRequestBodyValidator}, + ).blueprint + api_bp.after_request(set_cors_headers_on_response) + + app.register_blueprint(api_bp) + app.extensions["csrf"].exempt(api_bp) + + def init_plugins(app): """Integrate Flask and FAB with plugins.""" from airflow import plugins_manager @@ -118,3 +212,13 @@ def init_error_handlers(app: Flask): app.register_error_handler(500, views.show_traceback) app.register_error_handler(404, views.not_found) + + +def init_api_auth_provider(app): + """Initialize the API offered by the auth manager.""" + auth_mgr = get_auth_manager() + blueprint = auth_mgr.get_api_endpoints() + if blueprint: + base_paths.append(blueprint.url_prefix) + app.register_blueprint(blueprint) + app.extensions["csrf"].exempt(blueprint) diff --git a/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html b/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html new file mode 100644 index 0000000000000..fa619c403c030 --- /dev/null +++ b/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html @@ -0,0 +1,42 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} + + + + + Airflow + + + +
+ pin-logo +

Your user has no roles and/or permissions!

+

Unfortunately your user has no roles, and therefore you cannot use Airflow.

+

Please contact your Airflow administrator + (authentication + may be misconfigured) or +

+ + log out to try again. +
+

+

{{ hostname }}

+
+ + diff --git a/providers/src/airflow/providers/fab/www/views.py b/providers/src/airflow/providers/fab/www/views.py index 48bf0bfddffaf..2ecb87daf8995 100644 --- a/providers/src/airflow/providers/fab/www/views.py +++ b/providers/src/airflow/providers/fab/www/views.py @@ -20,15 +20,28 @@ import sys import traceback +import lazy_object_proxy from flask import ( + current_app, render_template, ) +from flask_appbuilder import BaseView, expose +from itsdangerous import URLSafeSerializer +from airflow import settings from airflow.api_fastapi.app import get_auth_manager +from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import conf +from airflow.executors.executor_loader import ExecutorLoader +from airflow.jobs.scheduler_job_runner import SchedulerJobRunner +from airflow.jobs.triggerer_job_runner import TriggererJobRunner +from airflow.utils.docs import get_docs_url from airflow.utils.net import get_hostname from airflow.version import version +FILTER_TAGS_COOKIE = "tags_filter" +FILTER_LASTRUN_COOKIE = "last_run_filter" + def not_found(error): """Show Not Found on screen for any error in the Webserver.""" @@ -43,6 +56,19 @@ def not_found(error): ) +def method_not_allowed(error): + """Show Method Not Allowed on screen for any error in the Webserver.""" + return ( + render_template( + "airflow/error.html", + hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "", + status_code=405, + error_message="Received an invalid request.", + ), + 405, + ) + + def show_traceback(error): """Show Traceback for a given error.""" is_logged_in = get_auth_manager().is_logged_in() @@ -64,3 +90,56 @@ def show_traceback(error): ), 500, ) + + +class AirflowBaseView(BaseView): + """Base View to set Airflow related properties.""" + + from airflow import macros + + route_base = "" + + extra_args = { + # Make our macros available to our UI templates too. + "macros": macros, + "get_docs_url": get_docs_url, + } + + if not conf.getboolean("core", "unit_test_mode"): + executor, _ = ExecutorLoader.import_default_executor_cls() + extra_args["sqlite_warning"] = settings.engine and (settings.engine.dialect.name == "sqlite") + if not executor.is_production: + extra_args["production_executor_warning"] = executor.__name__ + extra_args["otel_metrics_on"] = conf.getboolean("metrics", "otel_on") + extra_args["otel_traces_on"] = conf.getboolean("traces", "otel_on") + + line_chart_attr = { + "legend.maxKeyLength": 200, + } + + def render_template(self, *args, **kwargs): + # Add triggerer_job only if we need it + if TriggererJobRunner.is_needed(): + kwargs["triggerer_job"] = lazy_object_proxy.Proxy(TriggererJobRunner.most_recent_job) + + if "dag" in kwargs: + kwargs["can_edit_dag"] = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=kwargs["dag"].dag_id) + ) + url_serializer = URLSafeSerializer(current_app.config["SECRET_KEY"]) + kwargs["dag_file_token"] = url_serializer.dumps(kwargs["dag"].fileloc) + + return super().render_template( + *args, + # Cache this at most once per request, not for the lifetime of the view instance + scheduler_job=lazy_object_proxy.Proxy(SchedulerJobRunner.most_recent_job), + **kwargs, + ) + + +class Airflow(AirflowBaseView): + """Main Airflow application.""" + + @expose("/home") + def index(self): + return self.render_template("airflow/main.html") diff --git a/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py b/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py index af893b87c8b60..c4a03486de938 100644 --- a/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py +++ b/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py @@ -23,7 +23,7 @@ from flask_appbuilder.const import AUTH_LDAP from airflow.providers.fab.auth_manager.api.auth.backend.basic_auth import requires_authentication -from airflow.www import app as application +from airflow.providers.fab.www import app as application @pytest.fixture diff --git a/providers/tests/fab/auth_manager/api/auth/backend/test_session.py b/providers/tests/fab/auth_manager/api/auth/backend/test_session.py index ed7cd2bf45869..0b1d4f512ec8d 100644 --- a/providers/tests/fab/auth_manager/api/auth/backend/test_session.py +++ b/providers/tests/fab/auth_manager/api/auth/backend/test_session.py @@ -22,7 +22,7 @@ from flask import Response from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication -from airflow.www import app as application +from airflow.providers.fab.www import app as application @pytest.fixture diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py deleted file mode 100644 index 29eb85acb0675..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py +++ /dev/null @@ -1,325 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from collections.abc import Generator - -import pytest -import time_machine - -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_assets, clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import _check_last_log - -try: - from airflow.models.asset import AssetDagRunQueue, AssetModel -except ImportError: - if AIRFLOW_V_3_0_PLUS: - raise - else: - pass - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_queued_event", - role_name="TestQueuedEvent", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), - ], - ) - - yield app - - delete_user(app, username="test_queued_event") - - -class TestAssetEndpoint: - default_time = "2020-06-11T18:00:00+00:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() - clear_db_assets() - clear_db_runs() - - def teardown_method(self) -> None: - clear_db_assets() - clear_db_runs() - - def _create_asset(self, session): - asset_model = AssetModel( - id=1, - uri="s3://bucket/key", - extra={"foo": "bar"}, - created_at=timezone.parse(self.default_time), - updated_at=timezone.parse(self.default_time), - ) - session.add(asset_model) - session.commit() - return asset_model - - -class TestQueuedEventEndpoint(TestAssetEndpoint): - @pytest.fixture - def time_freezer(self) -> Generator: - freezer = time_machine.travel(self.default_time, tick=False) - freezer.start() - - yield - - freezer.stop() - - def _create_asset_dag_run_queues(self, dag_id, asset_id, session): - ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) - session.add(ddrq) - session.commit() - return ddrq - - -class TestGetDagAssetQueuedEvent(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - - def test_should_respond_404(self): - dag_id = "not_exists" - asset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDagAssetQueuedEvent(TestAssetEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_uri = "s3://bucket/key" - asset_id = self._create_asset(session).id - - ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) - session.add(ddrq) - session.commit() - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 1 - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=dag_id, event="api.delete_dag_asset_queued_event", logical_date=None) - - def test_should_respond_404(self): - dag_id = "not_exists" - asset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - asset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.delete( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=None, event="api.delete_asset_queued_events", logical_date=None) - - def test_should_respond_404(self): - asset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_auth.py b/providers/tests/fab/auth_manager/api_endpoints/test_auth.py index 4f8bc12702ff4..63a381b080271 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_auth.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_auth.py @@ -25,7 +25,6 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_pools from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import client_with_login pytestmark = [ pytest.mark.db_test, @@ -55,7 +54,7 @@ def set_attrs(self, minimal_app_for_auth_api): class TestBasicAuth(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth + from airflow.providers.fab.www.extensions.init_security import init_api_auth old_auth = getattr(minimal_app_for_auth_api, "api_auth") @@ -132,44 +131,3 @@ def test_invalid_auth_header(self, token): assert response.headers["Content-Type"] == "application/problem+json" assert response.headers["WWW-Authenticate"] == "Basic" assert_401(response) - - -class TestSessionWithBasicAuthFallback(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_auth_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.providers.fab.auth_manager.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" - } - ): - init_api_auth(minimal_app_for_auth_api) - yield - finally: - setattr(minimal_app_for_auth_api, "api_auth", old_auth) - - def test_basic_auth_fallback(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - # request uses session - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - - # request uses basic auth - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - - # request without session or basic auth header - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools") - assert response.status_code == 401 diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_cors.py b/providers/tests/fab/auth_manager/api_endpoints/test_cors.py index b8947925b1ec5..ca2ec12c0422f 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_cors.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_cors.py @@ -52,7 +52,7 @@ def set_attrs(self, minimal_app_for_auth_api): class TestEmptyCors(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth + from airflow.providers.fab.www.extensions.init_security import init_api_auth old_auth = getattr(minimal_app_for_auth_api, "api_auth") @@ -80,7 +80,7 @@ def test_empty_cors_headers(self): class TestCorsOrigin(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth + from airflow.providers.fab.www.extensions.init_security import init_api_auth old_auth = getattr(minimal_app_for_auth_api, "api_auth") @@ -124,7 +124,7 @@ def test_cors_origin_reflection(self): class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth + from airflow.providers.fab.www.extensions.init_security import init_api_auth old_auth = getattr(minimal_app_for_auth_api, "api_auth") diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py deleted file mode 100644 index 9b7ba51f25b08..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py +++ /dev/null @@ -1,226 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pendulum -import pytest - -from airflow.models import DagModel -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import provide_session - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import _check_last_log - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture -def current_file_token(url_safe_serializer) -> str: - return url_safe_serializer.dumps(__file__) - - -DAG_ID = "test_dag" -TASK_ID = "op1" -DAG2_ID = "test_dag2" -DAG3_ID = "test_dag3" -UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - - yield app - - delete_user(app, username="test_granular_permissions") - - -class TestDagEndpoint: - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_dags() - clear_db_serialized_dags() - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.dag_id = DAG_ID - self.dag2_id = DAG2_ID - self.dag3_id = DAG3_ID - - def teardown_method(self) -> None: - self.clean_db() - - @provide_session - def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): - for num in range(1, count + 1): - dag_model = DagModel( - dag_id=f"{dag_id_prefix}_{num}", - fileloc=f"/tmp/dag_{num}.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=is_paused, - ) - session.add(dag_model) - - @provide_session - def _create_dag_model_for_details_endpoint(self, dag_id, session=None): - dag_model = DagModel( - dag_id=dag_id, - fileloc="/tmp/dag.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=False, - ) - session.add(dag_model) - - @provide_session - def _create_dag_model_for_details_endpoint_with_asset_expression(self, dag_id, session=None): - dag_model = DagModel( - dag_id=dag_id, - fileloc="/tmp/dag.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=False, - asset_expression={ - "any": [ - "s3://dag1/output_1.txt", - {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, - ] - }, - ) - session.add(dag_model) - - @provide_session - def _create_deactivated_dag(self, session=None): - dag_model = DagModel( - dag_id="TEST_DAG_DELETED_1", - fileloc="/tmp/dag_del_1.py", - timetable_summary="2 2 * * *", - is_active=False, - ) - session.add(dag_model) - - -class TestGetDag(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(1) - response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - - def test_should_respond_403_with_granular_access_for_different_dag(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 403 - - -class TestGetDags(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - - -class TestPatchDag(TestDagEndpoint): - @provide_session - def _create_dag_model(self, session=None): - dag_model = DagModel( - dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True - ) - session.add(dag_model) - return dag_model - - def test_should_respond_200_on_patch_with_granular_dag_access(self, session): - self._create_dag_models(1) - response = self.client.patch( - "/api/v1/dags/TEST_DAG_1", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", logical_date=None) - - def test_validation_error_raises_400(self): - patch_body = { - "ispaused": True, - } - dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 400 - assert response.json == { - "detail": "{'ispaused': ['Unknown field.']}", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } - - -class TestPatchDags(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.patch( - "api/v1/dags?dag_id_pattern=~", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py deleted file mode 100644 index e745d3d655bdc..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py +++ /dev/null @@ -1,267 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import timedelta - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagrun import DagRun -from airflow.models.param import Param -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.state import DagRunState - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( - create_user, - delete_roles, - delete_user, -) -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -try: - from airflow.utils.types import DagRunTriggeredByType, DagRunType -except ImportError: - if AIRFLOW_V_3_0_PLUS: - raise - else: - pass - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_no_dag_run_create_permission", - role_name="TestNoDagRunCreatePermission", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_dag_view_only", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_view_dags", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID", - access_control={ - "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, - "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, - }, - ) - - yield app - - delete_user(app, username="test_dag_view_only") - delete_user(app, username="test_view_dags") - delete_user(app, username="test_granular_permissions") - delete_user(app, username="test_no_dag_run_create_permission") - delete_roles(app) - - -class TestDagRunEndpoint: - default_time = "2020-06-11T18:00:00+00:00" - default_time_2 = "2020-06-12T18:00:00+00:00" - default_time_3 = "2020-06-13T18:00:00+00:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_runs() - clear_db_serialized_dags() - clear_db_dags() - - @pytest.fixture(autouse=True) - def create_dag(self, dag_maker, setup_attrs): - with dag_maker( - "TEST_DAG_ID", schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)} - ): - pass - - dag_maker.sync_dagbag_to_db() - - def teardown_method(self) -> None: - clear_db_runs() - clear_db_dags() - clear_db_serialized_dags() - - def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): - dag_runs = [] - dags = [] - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - - for i in range(idx_start, idx_start + 2): - dagrun_model = DagRun( - dag_id="TEST_DAG_ID", - run_id=f"TEST_DAG_RUN_ID_{i}", - run_type=DagRunType.MANUAL, - logical_date=timezone.parse(self.default_time) + timedelta(days=i - 1), - start_date=timezone.parse(self.default_time), - external_trigger=True, - state=state, - **triggered_by_kwargs, - ) - dagrun_model.updated_at = timezone.parse(self.default_time) - dag_runs.append(dagrun_model) - - if extra_dag: - for i in range(idx_start + 2, idx_start + 4): - dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}")) - dag_runs.append( - DagRun( - dag_id=f"TEST_DAG_ID_{i}", - run_id=f"TEST_DAG_RUN_ID_{i}", - run_type=DagRunType.MANUAL, - logical_date=timezone.parse(self.default_time_2), - start_date=timezone.parse(self.default_time), - external_trigger=True, - state=state, - ) - ) - if commit: - with create_session() as session: - session.add_all(dag_runs) - session.add_all(dags) - return dag_runs - - -class TestGetDagRuns(TestDagRunEndpoint): - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] - response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] - assert dag_run_ids == expected_dag_run_ids - - -class TestGetDagRunBatch(TestDagRunEndpoint): - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_response_json_1 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_1", - "end_date": None, - "state": "running", - "logical_date": self.default_time, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - expected_response_json_2 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_2", - "end_date": None, - "state": "running", - "logical_date": self.default_time_2, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert response.json == { - "dag_runs": [ - expected_response_json_1, - expected_response_json_2, - ], - "total_entries": 2, - } - - -class TestPostDagRun(TestDagRunEndpoint): - def test_dagrun_trigger_with_dag_level_permissions(self): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={"conf": {"validated_number": 1}}, - environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, - ) - assert response.status_code == 200 - - @pytest.mark.parametrize( - "username", - ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], - ) - def test_should_raises_403_unauthorized(self, username): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={ - "dag_run_id": "TEST_DAG_RUN_ID_1", - "logical_date": self.default_time, - }, - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 403 diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py deleted file mode 100644 index 66cd6477c9e9b..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import ast -import os - -import pytest - -from airflow.models import DagBag -from airflow.providers.fab.www.security import permissions - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import ( - clear_db_dag_code, - clear_db_dags, - clear_db_serialized_dags, - parse_and_sync_to_db, -) -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" -NOT_READABLE_DAG_ID = "latest_only_with_trigger" -TEST_MULTIPLE_DAGS_ID = "asset_produces_1" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], - ) - app.appbuilder.sm.sync_perm_for_dag( - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - EXAMPLE_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - TEST_MULTIPLE_DAGS_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test") - - -class TestGetSource: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.clear_db() - - def teardown_method(self) -> None: - self.clear_db() - - @staticmethod - def clear_db(): - clear_db_dags() - clear_db_serialized_dags() - clear_db_dag_code() - - @staticmethod - def _get_dag_file_docstring(fileloc: str) -> str | None: - with open(fileloc) as f: - file_contents = f.read() - module = ast.parse(file_contents) - docstring = ast.get_docstring(module) - return docstring - - def test_should_respond_403_not_readable(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(NOT_READABLE_DAG_ID) - - response = self.client.get( - f"/api/v1/dagSources/{dag.dag_id}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - read_dag = self.client.get( - f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 403 - - def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(TEST_MULTIPLE_DAGS_ID) - - response = self.client.get( - f"/api/v1/dagSources/{dag.dag_id}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - - read_dag = self.client.get( - f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 200 diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py deleted file mode 100644 index 34862c6b4c7af..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagwarning import DagWarning -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import create_session - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_warnings, clear_db_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, # type:ignore - username="test_with_dag2_read", - role_name="TestWithDag2Read", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), - ], - ) - - yield app - - delete_user(app, username="test_with_dag2_read") - - -class TestBaseDagWarning: - timestamp = "2020-06-10T12:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - - def teardown_method(self) -> None: - clear_db_dag_warnings() - clear_db_dags() - - -class TestGetDagWarningEndpoint(TestBaseDagWarning): - def setup_class(self): - clear_db_dag_warnings() - clear_db_dags() - - def setup_method(self): - with create_session() as session: - session.add(DagModel(dag_id="dag1")) - session.add(DagWarning("dag1", "non-existent pool", "test message")) - session.commit() - - def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): - response = self.client.get( - "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, - ) - assert response.status_code == 403 diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py deleted file mode 100644 index 168bcae4b9f38..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models import Log -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_logs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_granular", - role_name="TestGranular", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID_1", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID_2", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test_granular") - - -@pytest.fixture -def task_instance(session, create_task_instance, request): - return create_task_instance( - session=session, - dag_id="TEST_DAG_ID", - task_id="TEST_TASK_ID", - run_id="TEST_RUN_ID", - logical_date=request.instance.default_time, - ) - - -@pytest.fixture -def create_log_model(create_task_instance, task_instance, session, request): - def maker(event, when, **kwargs): - log_model = Log( - event=event, - task_instance=task_instance, - **kwargs, - ) - log_model.dttm = when - - session.add(log_model) - session.flush() - return log_model - - return maker - - -class TestEventLogEndpoint: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_logs() - self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") - self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") - - def teardown_method(self) -> None: - clear_db_logs() - - -class TestGetEventLogs(TestEventLogEndpoint): - def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): - eventlog1 = create_log_model( - event="TEST_EVENT_1", - dag_id="TEST_DAG_ID_1", - task_id="TEST_TASK_ID_1", - owner="TEST_OWNER_1", - when=self.default_time, - ) - eventlog2 = create_log_model( - event="TEST_EVENT_2", - dag_id="TEST_DAG_ID_2", - task_id="TEST_TASK_ID_2", - owner="TEST_OWNER_2", - when=self.default_time_2, - ) - session.add_all([eventlog1, eventlog2]) - session.commit() - for attr in ["dag_id", "task_id", "owner", "event"]: - attr_value = f"TEST_{attr}_1".upper() - response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value - - def test_should_filter_eventlogs_by_included_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 2 - assert response_data["total_entries"] == 2 - assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} - - def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 1 - assert response_data["total_entries"] == 1 - assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py deleted file mode 100644 index 19509aa558146..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py +++ /dev/null @@ -1,222 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models.dag import DagModel -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.compat import ParseImportError -from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors -from tests_common.test_utils.permissions import _resource_name -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - -TEST_DAG_IDS = ["test_dag", "test_dag2"] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_single_dag", - role_name="TestSingleDAG", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], - ) - # For some reason, DAG level permissions are not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestSingleDAG", - "perms": [ - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), - ) - ], - } - ] - ) - - yield app - - delete_user(app, username="test_single_dag") - - -class TestBaseImportError: - timestamp = "2020-06-10T12:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - - clear_db_import_errors() - clear_db_dags() - - def teardown_method(self) -> None: - clear_db_import_errors() - clear_db_dags() - - @staticmethod - def _normalize_import_errors(import_errors): - for i, import_error in enumerate(import_errors, 1): - import_error["import_error_id"] = i - - -class TestGetImportErrorEndpoint(TestBaseImportError): - def test_should_raise_403_forbidden_without_dag_read(self, session): - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 403 - - def test_should_return_200_with_single_dag_read(self, session): - dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert response_data == { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - } - - def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert response_data == { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - } - - -class TestGetImportErrorsEndpoint(TestBaseImportError): - def test_get_import_errors_single_dag(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = f"/tmp/{dag_id}.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - importerror = ParseImportError( - filename=fake_filename, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert response_data == { - "import_errors": [ - { - "filename": "/tmp/test_dag.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } - - def test_get_import_errors_single_dag_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = "/tmp/all_in_one.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - - importerror = ParseImportError( - filename="/tmp/all_in_one.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert response_data == { - "import_errors": [ - { - "filename": "/tmp/all_in_one.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py deleted file mode 100644 index 24613fe7f2ae6..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py +++ /dev/null @@ -1,426 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import datetime as dt -import urllib - -import pytest - -from airflow.models import DagRun, TaskInstance -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import provide_session -from airflow.utils.state import State -from airflow.utils.timezone import datetime -from airflow.utils.types import DagRunType - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( - create_user, - delete_roles, - delete_user, -) -from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - -DEFAULT_DATETIME_1 = datetime(2020, 1, 1) -DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" -DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" - -QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) -QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_dag_read_only", - role_name="TestDagReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, - username="test_task_read_only", - role_name="TestTaskReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, - username="test_read_only_one_dag", - role_name="TestReadOnlyOneDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestReadOnlyOneDag", - "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], - } - ] - ) - - yield app - - delete_user(app, username="test_dag_read_only") - delete_user(app, username="test_task_read_only") - delete_user(app, username="test_read_only_one_dag") - delete_roles(app) - - -class TestTaskInstanceEndpoint: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app, dagbag) -> None: - self.default_time = DEFAULT_DATETIME_1 - self.ti_init = { - "logical_date": self.default_time, - "state": State.RUNNING, - } - self.ti_extras = { - "start_date": self.default_time + dt.timedelta(days=1), - "end_date": self.default_time + dt.timedelta(days=2), - "pid": 100, - "duration": 10000, - "pool": "default_pool", - "queue": "default_queue", - "job_id": 0, - } - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_runs() - clear_rendered_ti_fields() - self.dagbag = dagbag - - def create_task_instances( - self, - session, - dag_id: str = "example_python_operator", - update_extras: bool = True, - task_instances=None, - dag_run_state=State.RUNNING, - with_ti_history=False, - ): - """Method to create task instances using kwargs and default arguments""" - - dag = self.dagbag.get_dag(dag_id) - tasks = dag.tasks - counter = len(tasks) - if task_instances is not None: - counter = min(len(task_instances), counter) - - run_id = "TEST_DAG_RUN_ID" - logical_date = self.ti_init.pop("logical_date", self.default_time) - dr = None - - tis = [] - for i in range(counter): - if task_instances is None: - pass - elif update_extras: - self.ti_extras.update(task_instances[i]) - else: - self.ti_init.update(task_instances[i]) - - if "logical_date" in self.ti_init: - run_id = f"TEST_DAG_RUN_ID_{i}" - logical_date = self.ti_init.pop("logical_date") - dr = None - - if not dr: - dr = DagRun( - run_id=run_id, - dag_id=dag_id, - logical_date=logical_date, - run_type=DagRunType.MANUAL, - state=dag_run_state, - ) - session.add(dr) - ti = TaskInstance(task=tasks[i], **self.ti_init) - session.add(ti) - ti.dag_run = dr - ti.note = "placeholder-note" - - for key, value in self.ti_extras.items(): - setattr(ti, key, value) - tis.append(ti) - - session.commit() - if with_ti_history: - for ti in tis: - ti.try_number = 1 - session.merge(ti) - session.commit() - dag.clear() - for ti in tis: - ti.try_number = 2 - ti.queue = "default_queue" - session.merge(ti) - session.commit() - return tis - - -class TestGetTaskInstance(TestTaskInstanceEndpoint): - def setup_method(self): - clear_db_runs() - - def teardown_method(self): - clear_db_runs() - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - @provide_session - def test_should_respond_200(self, username, session): - self.create_task_instances(session) - # Update ti and set operator to None to - # test that operator field is nullable. - # This prevents issue when users upgrade to 2.0+ - # from 1.10.x - # https://github.com/apache/airflow/issues/14421 - session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") - session.commit() - response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 200 - - -class TestGetTaskInstances(TestTaskInstanceEndpoint): - @pytest.mark.parametrize( - "task_instances, user, expected_ti", - [ - pytest.param( - { - "example_python_operator": 2, - "example_skip_dag": 1, - }, - "test_read_only_one_dag", - 2, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test_read_only_one_dag", - 1, - ), - ], - ) - def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): - for dag_id in task_instances: - self.create_task_instances( - session, - task_instances=[ - {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} - for i in range(task_instances[dag_id]) - ], - dag_id=dag_id, - ) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti - - -class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): - @pytest.mark.parametrize( - "task_instances, update_extras, payload, expected_ti_count, username", - [ - pytest.param( - [ - {"pool": "test_pool_1"}, - {"pool": "test_pool_2"}, - {"pool": "test_pool_3"}, - ], - True, - {"pool": ["test_pool_1", "test_pool_2"]}, - 2, - "test_dag_read_only", - id="test pool filter", - ), - pytest.param( - [ - {"state": State.RUNNING}, - {"state": State.QUEUED}, - {"state": State.SUCCESS}, - {"state": State.NONE}, - ], - False, - {"state": ["running", "queued", "none"]}, - 3, - "test_task_read_only", - id="test state filter", - ), - pytest.param( - [ - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - ], - False, - {}, - 4, - "test_task_read_only", - id="test dag with null states", - ), - pytest.param( - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "end_date_gte": DEFAULT_DATETIME_STR_1, - "end_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_task_read_only", - id="test end date filter", - ), - pytest.param( - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "start_date_gte": DEFAULT_DATETIME_STR_1, - "start_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_dag_read_only", - id="test start date filter", - ), - ], - ) - def test_should_respond_200( - self, task_instances, update_extras, payload, expected_ti_count, username, session - ): - self.create_task_instances( - session, - update_extras=update_extras, - task_instances=task_instances, - ) - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": username}, - json=payload, - ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) - - def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): - self.create_task_instances(session=session) - self.create_task_instances(session=session, dag_id="example_skip_dag") - payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} - - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, - json=payload, - ) - assert response.status_code == 403 - assert response.json == { - "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", - "status": 403, - "title": "Forbidden", - "type": EXCEPTIONS_LINK_MAP[403], - } - - -class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): - response = self.client.post( - "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, - json={ - "dry_run": True, - "task_id": "print_the_context", - "logical_date": DEFAULT_DATETIME_1.isoformat(), - "include_upstream": True, - "include_downstream": True, - "include_future": True, - "include_past": True, - "new_state": "failed", - }, - ) - assert response.status_code == 403 - - -class TestPatchTaskInstance(TestTaskInstanceEndpoint): - ENDPOINT_URL = ( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" - ) - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): - response = self.client.patch( - self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, - json={ - "dry_run": True, - "new_state": "failed", - }, - ) - assert response.status_code == 403 - - -class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): - def setup_method(self): - clear_db_runs() - - def teardown_method(self): - clear_db_runs() - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - @provide_session - def test_should_respond_200(self, username, session): - self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) - - response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 200 diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py deleted file mode 100644 index fcf29ab1af964..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models import Variable -from airflow.providers.fab.www.security import permissions - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_variables -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_read_only", - role_name="TestReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, - username="test_delete_only", - role_name="TestDeleteOnly", - permissions=[ - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], - ) - - yield app - - delete_user(app, username="test_read_only") - delete_user(app, username="test_delete_only") - - -class TestVariableEndpoint: - @pytest.fixture(autouse=True) - def setup_method(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_variables() - - def teardown_method(self) -> None: - clear_db_variables() - - -class TestGetVariable(TestVariableEndpoint): - @pytest.mark.parametrize( - "user, expected_status_code", - [ - ("test_read_only", 200), - ("test_delete_only", 403), - ], - ) - def test_read_variable(self, user, expected_status_code): - expected_value = '{"foo": 1}' - Variable.set("TEST_VARIABLE_KEY", expected_value) - response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} - ) - assert response.status_code == expected_status_code - if expected_status_code == 200: - assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py deleted file mode 100644 index 85f40b7557c11..0000000000000 --- a/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py +++ /dev/null @@ -1,268 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import timedelta - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import BaseXCom, XCom -from airflow.operators.empty import EmptyOperator -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.types import DagRunType - -from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -class CustomXCom(BaseXCom): - @classmethod - def deserialize_value(cls, xcom: XCom): - return f"real deserialized {super().deserialize_value(xcom)}" - - def orm_deserialize_value(self): - return f"orm deserialized {super().orm_deserialize_value()}" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - app.appbuilder.sm.sync_perm_for_dag( - "test-dag-id-1", - access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test_granular_permissions") - - -def _compare_xcom_collections(collection1: dict, collection_2: dict): - assert collection1.get("total_entries") == collection_2.get("total_entries") - - def sort_key(record): - return ( - ( - record.get("dag_id"), - record.get("task_id"), - record.get("logical_date"), - record.get("map_index"), - record.get("key"), - ) - if AIRFLOW_V_3_0_PLUS - else ( - record.get("dag_id"), - record.get("task_id"), - record.get("execution_date"), - record.get("map_index"), - record.get("key"), - ) - ) - - assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( - collection_2.get("xcom_entries", []), key=sort_key - ) - - -class TestXComEndpoint: - @staticmethod - def clean_db(): - clear_db_dags() - clear_db_runs() - clear_db_xcom() - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - """ - Setup For XCom endpoint TC - """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore - # clear existing xcoms - self.clean_db() - - def teardown_method(self) -> None: - """ - Clear Hanging XComs - """ - self.clean_db() - - -class TestGetXComEntries(TestXComEndpoint): - def test_should_respond_200_with_tilde_and_granular_dag_access(self): - dag_id_1 = "test-dag-id-1" - task_id_1 = "test-task-id-1" - logical_date = "2005-04-02T00:00:00+00:00" - logical_date_parsed = timezone.parse(logical_date) - dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, logical_date_parsed) - self._create_xcom_entries(dag_id_1, dag_run_id_1, logical_date_parsed, task_id_1) - - dag_id_2 = "test-dag-id-2" - task_id_2 = "test-task-id-2" - run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, logical_date_parsed) - self._create_xcom_entries(dag_id_2, run_id_2, logical_date_parsed, task_id_2) - self._create_invalid_xcom_entries(logical_date_parsed) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - - assert response.status_code == 200 - response_data = response.json - for xcom_entry in response_data["xcom_entries"]: - xcom_entry["timestamp"] = "TIMESTAMP" - date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - _compare_xcom_collections( - response_data, - { - "xcom_entries": [ - { - "dag_id": dag_id_1, - date_key: logical_date, - "key": "test-xcom-key-1", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - { - "dag_id": dag_id_1, - date_key: logical_date, - "key": "test-xcom-key-2", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - ], - "total_entries": 2, - }, - ) - - def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=False): - with create_session() as session: - dag = DagModel(dag_id=dag_id) - session.add(dag) - if AIRFLOW_V_3_0_PLUS: - dagrun = DagRun( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun = DagRun( - dag_id=dag_id, - run_id=run_id, - execution_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun) - if mapped_ti: - for i in [0, 1]: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) - ti.dag_id = dag_id - session.add(ti) - else: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) - ti.dag_id = dag_id - session.add(ti) - - for i in [1, 2]: - if mapped_ti: - key = "test-xcom-key" - map_index = i - 1 - else: - key = f"test-xcom-key-{i}" - map_index = -1 - - XCom.set( - key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index - ) - - def _create_invalid_xcom_entries(self, logical_date): - """ - Invalid XCom entries to test join query - """ - with create_session() as session: - dag = DagModel(dag_id="invalid_dag") - session.add(dag) - if AIRFLOW_V_3_0_PLUS: - dagrun = DagRun( - dag_id="invalid_dag", - run_id="invalid_run_id", - logical_date=logical_date + timedelta(days=1), - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun = DagRun( - dag_id="invalid_dag", - run_id="invalid_run_id", - execution_date=logical_date + timedelta(days=1), - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun) - if AIRFLOW_V_3_0_PLUS: - dagrun1 = DagRun( - dag_id="invalid_dag", - run_id="not_this_run_id", - logical_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun1 = DagRun( - dag_id="invalid_dag", - run_id="not_this_run_id", - execution_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun1) - ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") - ti.dag_id = "invalid_dag" - session.add(ti) - for i in [1, 2]: - XCom.set( - key=f"invalid-xcom-key-{i}", - value="TEST", - run_id="not_this_run_id", - task_id="invalid_task", - dag_id="invalid_dag", - ) diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py index f26a08d19c3e3..1f3d0074f9258 100644 --- a/providers/tests/fab/auth_manager/conftest.py +++ b/providers/tests/fab/auth_manager/conftest.py @@ -20,7 +20,7 @@ import pytest -from airflow.www import app +from airflow.providers.fab.www import app from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import parse_and_sync_to_db diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py b/providers/tests/fab/auth_manager/test_fab_auth_manager.py index 077350ac10c6e..334e178195d5b 100644 --- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py @@ -27,6 +27,7 @@ from flask_appbuilder.menu import Menu from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user @@ -62,7 +63,6 @@ RESOURCE_VARIABLE, RESOURCE_WEBSITE, ) -from airflow.www.extensions.init_appbuilder import init_appbuilder if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod diff --git a/providers/tests/fab/auth_manager/test_security.py b/providers/tests/fab/auth_manager/test_security.py index 95c5545d87883..3ccd9a21ee1fc 100644 --- a/providers/tests/fab/auth_manager/test_security.py +++ b/providers/tests/fab/auth_manager/test_security.py @@ -36,6 +36,8 @@ from airflow.exceptions import AirflowException from airflow.models import DagModel from airflow.models.dag import DAG +from airflow.providers.fab.www.auth import get_access_denied_message, has_access_dag +from airflow.providers.fab.www.utils import CustomSQLAInterface from tests_common.test_utils.compat import ignore_provider_compatibility_error @@ -45,11 +47,9 @@ from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser from airflow.api_fastapi.app import get_auth_manager +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security.permissions import ACTION_CAN_READ -from airflow.www import app as application -from airflow.www.auth import get_access_denied_message -from airflow.www.utils import CustomSQLAInterface from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, @@ -1162,8 +1162,6 @@ def test_dag_id_consistency( fail: bool, ): with app.test_request_context() as mock_context: - from airflow.www.auth import has_access_dag - mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {} kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {} mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {} diff --git a/providers/tests/fab/auth_manager/views/test_permissions.py b/providers/tests/fab/auth_manager/views/test_permissions.py index b2eb0b47c5c1f..9341fe8479a60 100644 --- a/providers/tests/fab/auth_manager/views/test_permissions.py +++ b/providers/tests/fab/auth_manager/views/test_permissions.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning diff --git a/providers/tests/fab/auth_manager/views/test_roles_list.py b/providers/tests/fab/auth_manager/views/test_roles_list.py index e728b2ae32837..0f58cb10a812d 100644 --- a/providers/tests/fab/auth_manager/views/test_roles_list.py +++ b/providers/tests/fab/auth_manager/views/test_roles_list.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning diff --git a/providers/tests/fab/auth_manager/views/test_user.py b/providers/tests/fab/auth_manager/views/test_user.py index 7dadeeaf525de..8a35c327c0733 100644 --- a/providers/tests/fab/auth_manager/views/test_user.py +++ b/providers/tests/fab/auth_manager/views/test_user.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning diff --git a/providers/tests/fab/auth_manager/views/test_user_edit.py b/providers/tests/fab/auth_manager/views/test_user_edit.py index afd2e537125d3..4874186270c11 100644 --- a/providers/tests/fab/auth_manager/views/test_user_edit.py +++ b/providers/tests/fab/auth_manager/views/test_user_edit.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning diff --git a/providers/tests/fab/auth_manager/views/test_user_stats.py b/providers/tests/fab/auth_manager/views/test_user_stats.py index 1e08c94dfb719..14b06f2a9027c 100644 --- a/providers/tests/fab/auth_manager/views/test_user_stats.py +++ b/providers/tests/fab/auth_manager/views/test_user_stats.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning