diff --git a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py
index c74c4f120836d..897c232e15cec 100644
--- a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -109,10 +109,10 @@
from airflow.providers.fab.auth_manager.views.user_stats import CustomUserStatsChartView
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
from airflow.providers.fab.www.session import (
+    AirflowDatabaseSessionInterface,
     AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface,
 )
from airflow.security import permissions
-from airflow.www.session import AirflowDatabaseSessionInterface
if TYPE_CHECKING:
from airflow.security.permissions import RESOURCE_ASSET
diff --git a/providers/src/airflow/providers/fab/www/app.py b/providers/src/airflow/providers/fab/www/app.py
index 0414fc5e408b5..43dd36742997d 100644
--- a/providers/src/airflow/providers/fab/www/app.py
+++ b/providers/src/airflow/providers/fab/www/app.py
@@ -17,22 +17,25 @@
# under the License.
from __future__ import annotations
-from os.path import isabs
-
from flask import Flask
from flask_appbuilder import SQLA
from flask_wtf.csrf import CSRFProtect
-from sqlalchemy.engine.url import make_url
from airflow import settings
from airflow.configuration import conf
-from airflow.exceptions import AirflowConfigException
from airflow.logging_config import configure_logging
from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder
from airflow.providers.fab.www.extensions.init_jinja_globals import init_jinja_globals
from airflow.providers.fab.www.extensions.init_manifest_files import configure_manifest_files
from airflow.providers.fab.www.extensions.init_security import init_api_auth, init_xframe_protection
-from airflow.providers.fab.www.extensions.init_views import init_error_handlers, init_plugins
+from airflow.providers.fab.www.extensions.init_views import (
+    init_api_auth_provider,
+    init_api_connexion,
+    init_api_error_handlers,
+    init_error_handlers,
+    init_plugins,
+)
+from airflow.utils.json import AirflowJsonProvider
app: Flask | None = None
@@ -41,44 +44,55 @@
csrf = CSRFProtect()
-def create_app():
+def create_app(config=None, testing=False):
     """Create a new instance of Airflow WWW app."""
     flask_app = Flask(__name__)
     flask_app.secret_key = conf.get("webserver", "SECRET_KEY")
 
+    webserver_config = conf.get_mandatory_value("webserver", "config_file")
+    # Enable customizations in webserver_config.py to be applied via Flask.current_app.
+    with flask_app.app_context():
+        flask_app.config.from_pyfile(webserver_config, silent=True)
+
+    flask_app.config["TESTING"] = testing
     flask_app.config["SQLALCHEMY_DATABASE_URI"] = conf.get("database", "SQL_ALCHEMY_CONN")
-    url = make_url(flask_app.config["SQLALCHEMY_DATABASE_URI"])
-    if url.drivername == "sqlite" and url.database and not isabs(url.database):
-        raise AirflowConfigException(
-            f'Cannot use relative path: `{conf.get("database", "SQL_ALCHEMY_CONN")}` to connect to sqlite. '
-            "Please use absolute path such as `sqlite:////tmp/airflow.db`."
-        )
+    if config:
+        flask_app.config.from_mapping(config)
 
     if "SQLALCHEMY_ENGINE_OPTIONS" not in flask_app.config:
         flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args()
 
+    # Configure the JSON encoder used by `|tojson` filter from Flask
+    flask_app.json_provider_class = AirflowJsonProvider
+    flask_app.json = AirflowJsonProvider(flask_app)
+
+    csrf.init_app(flask_app)
+
     db = SQLA()
     db.session = settings.Session
     db.init_app(flask_app)
 
+    init_api_auth(flask_app)
     configure_logging()
     configure_manifest_files(flask_app)
-    init_api_auth(flask_app)
 
     with flask_app.app_context():
         init_appbuilder(flask_app)
         init_plugins(flask_app)
+        init_api_auth_provider(flask_app)
         init_error_handlers(flask_app)
+        init_api_connexion(flask_app)
+        init_api_error_handlers(flask_app)  # needs to be after all api inits to let them add their path first
         init_jinja_globals(flask_app)
         init_xframe_protection(flask_app)
         return flask_app
 
 
-def cached_app():
+def cached_app(config=None, testing=False):
     """Return cached instance of Airflow WWW app."""
     global app
     if not app:
-        app = create_app()
+        app = create_app(config=config, testing=testing)
     return app
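
A minimal usage sketch of the reworked factory functions, assuming a configured Airflow environment; the WTF_CSRF_ENABLED override below is illustrative only:

from airflow.providers.fab.www.app import cached_app, create_app

# Build a fresh app with an extra config mapping applied on top of webserver_config.py.
app = create_app(config={"WTF_CSRF_ENABLED": False}, testing=True)
assert app.config["TESTING"] is True

# cached_app() memoizes the instance in the module-level ``app`` global,
# so repeated calls return the same Flask application.
assert cached_app() is cached_app()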
diff --git a/providers/src/airflow/providers/fab/www/auth.py b/providers/src/airflow/providers/fab/www/auth.py
new file mode 100644
index 0000000000000..198acb29f9a69
--- /dev/null
+++ b/providers/src/airflow/providers/fab/www/auth.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, TypeVar, cast
+
+from flask import flash, redirect, render_template, request, url_for
+
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.auth.managers.models.resource_details import (
+    AccessView,
+    DagAccessEntity,
+    DagDetails,
+)
+from airflow.configuration import conf
+from airflow.utils.net import get_hostname
+
+if TYPE_CHECKING:
+    from airflow.auth.managers.base_auth_manager import ResourceMethod
+
+T = TypeVar("T", bound=Callable)
+
+log = logging.getLogger(__name__)
+
+
+def get_access_denied_message():
+    return conf.get("webserver", "access_denied_message")
+
+
+def _has_access(*, is_authorized: bool, func: Callable, args, kwargs):
+    """
+    Define the behavior depending on whether the user is authorized to access the resource.
+
+    :param is_authorized: whether the user is authorized to access the resource
+    :param func: the function to call if the user is authorized
+    :param args: the arguments of ``func``
+    :param kwargs: the keyword arguments of ``func``
+
+    :meta private:
+    """
+    if is_authorized:
+        return func(*args, **kwargs)
+    elif get_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view(
+        access_view=AccessView.WEBSITE
+    ):
+        return (
+            render_template(
+                "airflow/no_roles_permissions.html",
+                hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "",
+                logout_url=get_auth_manager().get_url_logout(),
+            ),
+            403,
+        )
+    elif not get_auth_manager().is_logged_in():
+        return redirect(get_auth_manager().get_url_login(next_url=request.url))
+    else:
+        access_denied = get_access_denied_message()
+        flash(access_denied, "danger")
+        return redirect(url_for("Airflow.index"))
+
+
+def has_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable[[T], T]:
+    def has_access_decorator(func: T):
+        @wraps(func)
+        def decorated(*args, **kwargs):
+            dag_id_kwargs = kwargs.get("dag_id")
+            dag_id_args = request.args.get("dag_id")
+            dag_id_form = request.form.get("dag_id")
+            dag_id_json = request.json.get("dag_id") if request.is_json else None
+            all_dag_ids = [dag_id_kwargs, dag_id_args, dag_id_form, dag_id_json]
+            unique_dag_ids = set(dag_id for dag_id in all_dag_ids if dag_id is not None)
+
+            if len(unique_dag_ids) > 1:
+                log.warning(
+                    "There are different dag_ids passed in the request: %s. Returning 403.", unique_dag_ids
+                )
+                log.warning(
+                    "kwargs: %s, args: %s, form: %s, json: %s",
+                    dag_id_kwargs,
+                    dag_id_args,
+                    dag_id_form,
+                    dag_id_json,
+                )
+                return (
+                    render_template(
+                        "airflow/no_roles_permissions.html",
+                        hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "",
+                        logout_url=get_auth_manager().get_url_logout(),
+                    ),
+                    403,
+                )
+            dag_id = unique_dag_ids.pop() if unique_dag_ids else None
+
+            is_authorized = get_auth_manager().is_authorized_dag(
+                method=method,
+                access_entity=access_entity,
+                details=None if not dag_id else DagDetails(id=dag_id),
+            )
+
+            return _has_access(
+                is_authorized=is_authorized,
+                func=func,
+                args=args,
+                kwargs=kwargs,
+            )
+
+        return cast(T, decorated)
+
+    return has_access_decorator
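
As an illustration of how the new has_access_dag decorator factory might be applied, here is a hypothetical FAB view; the view class, route, and template name are illustrative and not part of the provider:

from flask_appbuilder import BaseView, expose

from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.providers.fab.www.auth import has_access_dag


class ExampleDagView(BaseView):
    @expose("/example_dag_details/<string:dag_id>")
    @has_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
    def example_dag_details(self, dag_id):
        # Reached only when the auth manager authorizes the request; otherwise
        # _has_access redirects to login or renders the 403 template.
        return self.render_template("example_dag_details.html", dag_id=dag_id)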
diff --git a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
index 9cf353490c3ac..ce2d559d3ad2a 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_appbuilder.py
@@ -39,7 +39,7 @@
from flask_appbuilder.views import IndexView
from airflow import settings
-from airflow.api_fastapi.app import create_auth_manager
+from airflow.api_fastapi.app import create_auth_manager, get_auth_manager
from airflow.configuration import conf
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
@@ -283,6 +283,8 @@ def _add_admin_views(self):
         self.indexview = self._check_and_init(self.indexview)
         self.add_view_no_menu(self.indexview)
 
+        get_auth_manager().register_views()
+
     def _add_addon_views(self):
         """Register declared addons."""
         for addon in self._addon_managers:
diff --git a/providers/src/airflow/providers/fab/www/extensions/init_views.py b/providers/src/airflow/providers/fab/www/extensions/init_views.py
index 382bcaf9ca748..e8e6c6fa6c41a 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_views.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_views.py
@@ -18,17 +18,47 @@
import logging
from functools import cached_property
+from pathlib import Path
from typing import TYPE_CHECKING
-from connexion import Resolver
+from connexion import FlaskApi, Resolver
from connexion.decorators.validation import RequestBodyValidator
-from connexion.exceptions import BadRequestProblem
+from connexion.exceptions import BadRequestProblem, ProblemException
+from flask import request
+
+from airflow.api_connexion.exceptions import common_error_handler
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.configuration import conf
+from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
+from airflow.utils.yaml import safe_load
if TYPE_CHECKING:
     from flask import Flask
 
 log = logging.getLogger(__name__)
 
+# providers/src/airflow/providers/fab/www/extensions/init_views.py => airflow/
+ROOT_APP_DIR = Path(__file__).parents[7].joinpath("airflow").resolve()
+
+
+def set_cors_headers_on_response(response):
+    """Add response headers."""
+    allow_headers = conf.get("api", "access_control_allow_headers")
+    allow_methods = conf.get("api", "access_control_allow_methods")
+    allow_origins = conf.get("api", "access_control_allow_origins")
+    if allow_headers:
+        response.headers["Access-Control-Allow-Headers"] = allow_headers
+    if allow_methods:
+        response.headers["Access-Control-Allow-Methods"] = allow_methods
+    if allow_origins == "*":
+        response.headers["Access-Control-Allow-Origin"] = "*"
+    elif allow_origins:
+        allowed_origins = allow_origins.split(" ")
+        origin = request.environ.get("HTTP_ORIGIN", allowed_origins[0])
+        if origin in allowed_origins:
+            response.headers["Access-Control-Allow-Origin"] = origin
+    return response
+
 
 class _LazyResolution:
     """
@@ -78,6 +108,59 @@ def validate_schema(self, data, url):
         return super().validate_schema(data, url)
 
 
+base_paths: list[str] = []  # contains the list of base paths that have api endpoints
+
+
+def init_api_error_handlers(app: Flask) -> None:
+    """Add error handlers for 404 and 405 errors for existing API paths."""
+
+    @app.errorhandler(404)
+    def _handle_api_not_found(ex):
+        if any([request.path.startswith(p) for p in base_paths]):
+            # 404 errors are never handled on the blueprint level
+            # unless raised from a view func so actual 404 errors,
+            # i.e. "no route for it" defined, need to be handled
+            # here on the application level
+            return common_error_handler(ex)
+        else:
+            from airflow.providers.fab.www.views import not_found
+
+            return not_found(ex)
+
+    @app.errorhandler(405)
+    def _handle_method_not_allowed(ex):
+        if any([request.path.startswith(p) for p in base_paths]):
+            return common_error_handler(ex)
+        else:
+            from airflow.providers.fab.www.views import method_not_allowed
+
+            return method_not_allowed(ex)
+
+    app.register_error_handler(ProblemException, common_error_handler)
+
+
+def init_api_connexion(app: Flask) -> None:
+    """Initialize Stable API."""
+    base_path = "/api/v1"
+    base_paths.append(base_path)
+
+    with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f:
+        specification = safe_load(f)
+    api_bp = FlaskApi(
+        specification=specification,
+        resolver=_LazyResolver(),
+        base_path=base_path,
+        options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()},
+        strict_validation=True,
+        validate_responses=True,
+        validator_map={"body": _CustomErrorRequestBodyValidator},
+    ).blueprint
+    api_bp.after_request(set_cors_headers_on_response)
+
+    app.register_blueprint(api_bp)
+    app.extensions["csrf"].exempt(api_bp)
+
+
 def init_plugins(app):
     """Integrate Flask and FAB with plugins."""
     from airflow import plugins_manager
@@ -118,3 +201,13 @@ def init_error_handlers(app: Flask):
     app.register_error_handler(500, views.show_traceback)
     app.register_error_handler(404, views.not_found)
+
+
+def init_api_auth_provider(app):
+    """Initialize the API offered by the auth manager."""
+    auth_mgr = get_auth_manager()
+    blueprint = auth_mgr.get_api_endpoints()
+    if blueprint:
+        base_paths.append(blueprint.url_prefix)
+        app.register_blueprint(blueprint)
+        app.extensions["csrf"].exempt(blueprint)
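
A rough sketch of the routing behaviour the handlers above aim for, assuming a fully configured Airflow environment: unknown paths under a registered API base path are routed to the connexion problem handler, while everything else falls back to the HTML error views.

from airflow.providers.fab.www.app import create_app

client = create_app(testing=True).test_client()

# Unknown API route: handled by common_error_handler (problem+json payload).
api_resp = client.get("/api/v1/this-route-does-not-exist")
assert api_resp.status_code == 404

# Unknown non-API route: handled by the HTML not_found view.
html_resp = client.get("/some-unknown-page")
assert html_resp.status_code == 404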
diff --git a/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html b/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html
new file mode 100644
index 0000000000000..fa619c403c030
--- /dev/null
+++ b/providers/src/airflow/providers/fab/www/templates/airflow/no_roles_permissions.html
@@ -0,0 +1,42 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+#}
+
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>Airflow</title>
+  </head>
+  <body>
+    <h1>Your user has no roles and/or permissions!</h1>
+    <p>Unfortunately your user has no roles, and therefore you cannot use Airflow.</p>
+    <p>
+      Please contact your Airflow administrator
+      (authentication may be misconfigured) or
+      <a href="{{ logout_url }}">log out</a>.
+    </p>
+    <p>{{ hostname }}</p>
+  </body>
+</html>
diff --git a/providers/src/airflow/providers/fab/www/views.py b/providers/src/airflow/providers/fab/www/views.py
index 48bf0bfddffaf..ef270237cc87d 100644
--- a/providers/src/airflow/providers/fab/www/views.py
+++ b/providers/src/airflow/providers/fab/www/views.py
@@ -43,6 +43,19 @@ def not_found(error):
     )
 
 
+def method_not_allowed(error):
+    """Show Method Not Allowed on screen for any error in the Webserver."""
+    return (
+        render_template(
+            "airflow/error.html",
+            hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "",
+            status_code=405,
+            error_message="Received an invalid request.",
+        ),
+        405,
+    )
+
+
 def show_traceback(error):
     """Show Traceback for a given error."""
     is_logged_in = get_auth_manager().is_logged_in()
diff --git a/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py b/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py
index af893b87c8b60..c4a03486de938 100644
--- a/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py
+++ b/providers/tests/fab/auth_manager/api/auth/backend/test_basic_auth.py
@@ -23,7 +23,7 @@
from flask_appbuilder.const import AUTH_LDAP
from airflow.providers.fab.auth_manager.api.auth.backend.basic_auth import requires_authentication
-from airflow.www import app as application
+from airflow.providers.fab.www import app as application
@pytest.fixture
diff --git a/providers/tests/fab/auth_manager/api/auth/backend/test_session.py b/providers/tests/fab/auth_manager/api/auth/backend/test_session.py
index ed7cd2bf45869..0b1d4f512ec8d 100644
--- a/providers/tests/fab/auth_manager/api/auth/backend/test_session.py
+++ b/providers/tests/fab/auth_manager/api/auth/backend/test_session.py
@@ -22,7 +22,7 @@
from flask import Response
from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication
-from airflow.www import app as application
+from airflow.providers.fab.www import app as application
@pytest.fixture
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py
deleted file mode 100644
index 79949b4cd8df0..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py
+++ /dev/null
@@ -1,325 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from collections.abc import Generator
-
-import pytest
-import time_machine
-
-from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.security import permissions
-from airflow.utils import timezone
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_assets, clear_db_runs
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-from tests_common.test_utils.www import _check_last_log
-
-try:
- from airflow.models.asset import AssetDagRunQueue, AssetModel
-except ImportError:
- if AIRFLOW_V_3_0_PLUS:
- raise
- else:
- pass
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app,
- username="test_queued_event",
- role_name="TestQueuedEvent",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET),
- ],
- )
-
- yield app
-
- delete_user(app, username="test_queued_event")
-
-
-class TestAssetEndpoint:
- default_time = "2020-06-11T18:00:00+00:00"
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client()
- clear_db_assets()
- clear_db_runs()
-
- def teardown_method(self) -> None:
- clear_db_assets()
- clear_db_runs()
-
- def _create_asset(self, session):
- asset_model = AssetModel(
- id=1,
- uri="s3://bucket/key",
- extra={"foo": "bar"},
- created_at=timezone.parse(self.default_time),
- updated_at=timezone.parse(self.default_time),
- )
- session.add(asset_model)
- session.commit()
- return asset_model
-
-
-class TestQueuedEventEndpoint(TestAssetEndpoint):
- @pytest.fixture
- def time_freezer(self) -> Generator:
- freezer = time_machine.travel(self.default_time, tick=False)
- freezer.start()
-
- yield
-
- freezer.stop()
-
- def _create_asset_dag_run_queues(self, dag_id, asset_id, session):
- ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id)
- session.add(ddrq)
- session.commit()
- return ddrq
-
-
-class TestGetDagAssetQueuedEvent(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- asset_id = self._create_asset(session).id
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
- asset_uri = "s3://bucket/key"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
- asset_uri = "not_exists"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
-
-
-class TestDeleteDagAssetQueuedEvent(TestAssetEndpoint):
- def test_delete_should_respond_204(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- asset_uri = "s3://bucket/key"
- asset_id = self._create_asset(session).id
-
- ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id)
- session.add(ddrq)
- session.commit()
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 1
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 204
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 0
- _check_last_log(session, dag_id=dag_id, event="api.delete_dag_asset_queued_event", logical_date=None)
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
- asset_uri = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
-
-
-class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- asset_id = self._create_asset(session).id
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "queued_events": [
- {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
- ],
- "total_entries": 1,
- }
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with dag_id: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
-
-
-class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint):
- def test_should_respond_404(self):
- dag_id = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/assets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with dag_id: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
-
-
-class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- asset_id = self._create_asset(session).id
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
- asset_uri = "s3://bucket/key"
-
- response = self.client.get(
- f"/api/v1/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "queued_events": [
- {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
- ],
- "total_entries": 1,
- }
-
- def test_should_respond_404(self):
- asset_uri = "not_exists"
-
- response = self.client.get(
- f"/api/v1/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
-
-
-class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint):
- def test_delete_should_respond_204(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- asset_id = self._create_asset(session).id
- self._create_asset_dag_run_queues(dag_id, asset_id, session)
- asset_uri = "s3://bucket/key"
-
- response = self.client.delete(
- f"/api/v1/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 204
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 0
- _check_last_log(session, dag_id=None, event="api.delete_asset_queued_events", logical_date=None)
-
- def test_should_respond_404(self):
- asset_uri = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/assets/queuedEvent/{asset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert response.json == {
- "detail": "Queue event with asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- }
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_auth.py b/providers/tests/fab/auth_manager/api_endpoints/test_auth.py
index 4f8bc12702ff4..63a381b080271 100644
--- a/providers/tests/fab/auth_manager/api_endpoints/test_auth.py
+++ b/providers/tests/fab/auth_manager/api_endpoints/test_auth.py
@@ -25,7 +25,6 @@
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_pools
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-from tests_common.test_utils.www import client_with_login
pytestmark = [
pytest.mark.db_test,
@@ -55,7 +54,7 @@ def set_attrs(self, minimal_app_for_auth_api):
class TestBasicAuth(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
def with_basic_auth_backend(self, minimal_app_for_auth_api):
- from airflow.www.extensions.init_security import init_api_auth
+ from airflow.providers.fab.www.extensions.init_security import init_api_auth
old_auth = getattr(minimal_app_for_auth_api, "api_auth")
@@ -132,44 +131,3 @@ def test_invalid_auth_header(self, token):
assert response.headers["Content-Type"] == "application/problem+json"
assert response.headers["WWW-Authenticate"] == "Basic"
assert_401(response)
-
-
-class TestSessionWithBasicAuthFallback(BaseTestAuth):
- @pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_auth_api):
- from airflow.www.extensions.init_security import init_api_auth
-
- old_auth = getattr(minimal_app_for_auth_api, "api_auth")
-
- try:
- with conf_vars(
- {
- (
- "api",
- "auth_backends",
- ): "airflow.providers.fab.auth_manager.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"
- }
- ):
- init_api_auth(minimal_app_for_auth_api)
- yield
- finally:
- setattr(minimal_app_for_auth_api, "api_auth", old_auth)
-
- def test_basic_auth_fallback(self):
- token = "Basic " + b64encode(b"test:test").decode()
- clear_db_pools()
-
- # request uses session
- admin_user = client_with_login(self.app, username="test", password="test")
- response = admin_user.get("/api/v1/pools")
- assert response.status_code == 200
-
- # request uses basic auth
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools", headers={"Authorization": token})
- assert response.status_code == 200
-
- # request without session or basic auth header
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools")
- assert response.status_code == 401
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_cors.py b/providers/tests/fab/auth_manager/api_endpoints/test_cors.py
index b8947925b1ec5..ca2ec12c0422f 100644
--- a/providers/tests/fab/auth_manager/api_endpoints/test_cors.py
+++ b/providers/tests/fab/auth_manager/api_endpoints/test_cors.py
@@ -52,7 +52,7 @@ def set_attrs(self, minimal_app_for_auth_api):
class TestEmptyCors(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
def with_basic_auth_backend(self, minimal_app_for_auth_api):
- from airflow.www.extensions.init_security import init_api_auth
+ from airflow.providers.fab.www.extensions.init_security import init_api_auth
old_auth = getattr(minimal_app_for_auth_api, "api_auth")
@@ -80,7 +80,7 @@ def test_empty_cors_headers(self):
class TestCorsOrigin(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
def with_basic_auth_backend(self, minimal_app_for_auth_api):
- from airflow.www.extensions.init_security import init_api_auth
+ from airflow.providers.fab.www.extensions.init_security import init_api_auth
old_auth = getattr(minimal_app_for_auth_api, "api_auth")
@@ -124,7 +124,7 @@ def test_cors_origin_reflection(self):
class TestCorsWildcard(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
def with_basic_auth_backend(self, minimal_app_for_auth_api):
- from airflow.www.extensions.init_security import init_api_auth
+ from airflow.providers.fab.www.extensions.init_security import init_api_auth
old_auth = getattr(minimal_app_for_auth_api, "api_auth")
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py
deleted file mode 100644
index 853ba3e643606..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pendulum
-import pytest
-
-from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models import DagModel
-from airflow.security import permissions
-from airflow.utils.session import provide_session
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-from tests_common.test_utils.www import _check_last_log
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture
-def current_file_token(url_safe_serializer) -> str:
- return url_safe_serializer.dumps(__file__)
-
-
-DAG_ID = "test_dag"
-TASK_ID = "op1"
-DAG2_ID = "test_dag2"
-DAG3_ID = "test_dag3"
-UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')"
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
-
- create_user(app, username="test_granular_permissions", role_name="TestGranularDag")
- app.appbuilder.sm.sync_perm_for_dag(
- "TEST_DAG_1",
- access_control={
- "TestGranularDag": {
- permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
- },
- },
- )
- app.appbuilder.sm.sync_perm_for_dag(
- "TEST_DAG_1",
- access_control={
- "TestGranularDag": {
- permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
- },
- },
- )
-
- yield app
-
- delete_user(app, username="test_granular_permissions")
-
-
-class TestDagEndpoint:
- @staticmethod
- def clean_db():
- clear_db_runs()
- clear_db_dags()
- clear_db_serialized_dags()
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.clean_db()
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- self.dag_id = DAG_ID
- self.dag2_id = DAG2_ID
- self.dag3_id = DAG3_ID
-
- def teardown_method(self) -> None:
- self.clean_db()
-
- @provide_session
- def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None):
- for num in range(1, count + 1):
- dag_model = DagModel(
- dag_id=f"{dag_id_prefix}_{num}",
- fileloc=f"/tmp/dag_{num}.py",
- timetable_summary="2 2 * * *",
- is_active=True,
- is_paused=is_paused,
- )
- session.add(dag_model)
-
- @provide_session
- def _create_dag_model_for_details_endpoint(self, dag_id, session=None):
- dag_model = DagModel(
- dag_id=dag_id,
- fileloc="/tmp/dag.py",
- timetable_summary="2 2 * * *",
- is_active=True,
- is_paused=False,
- )
- session.add(dag_model)
-
- @provide_session
- def _create_dag_model_for_details_endpoint_with_asset_expression(self, dag_id, session=None):
- dag_model = DagModel(
- dag_id=dag_id,
- fileloc="/tmp/dag.py",
- timetable_summary="2 2 * * *",
- is_active=True,
- is_paused=False,
- asset_expression={
- "any": [
- "s3://dag1/output_1.txt",
- {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
- ]
- },
- )
- session.add(dag_model)
-
- @provide_session
- def _create_deactivated_dag(self, session=None):
- dag_model = DagModel(
- dag_id="TEST_DAG_DELETED_1",
- fileloc="/tmp/dag_del_1.py",
- timetable_summary="2 2 * * *",
- is_active=False,
- )
- session.add(dag_model)
-
-
-class TestGetDag(TestDagEndpoint):
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(1)
- response = self.client.get(
- "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
-
- def test_should_respond_403_with_granular_access_for_different_dag(self):
- self._create_dag_models(3)
- response = self.client.get(
- "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 403
-
-
-class TestGetDags(TestDagEndpoint):
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(3)
- response = self.client.get(
- "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
- assert len(response.json["dags"]) == 1
- assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
-
-
-class TestPatchDag(TestDagEndpoint):
- @provide_session
- def _create_dag_model(self, session=None):
- dag_model = DagModel(
- dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True
- )
- session.add(dag_model)
- return dag_model
-
- def test_should_respond_200_on_patch_with_granular_dag_access(self, session):
- self._create_dag_models(1)
- response = self.client.patch(
- "/api/v1/dags/TEST_DAG_1",
- json={
- "is_paused": False,
- },
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", logical_date=None)
-
- def test_validation_error_raises_400(self):
- patch_body = {
- "ispaused": True,
- }
- dag_model = self._create_dag_model()
- response = self.client.patch(
- f"/api/v1/dags/{dag_model.dag_id}",
- json=patch_body,
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 400
- assert response.json == {
- "detail": "{'ispaused': ['Unknown field.']}",
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- }
-
-
-class TestPatchDags(TestDagEndpoint):
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(3)
- response = self.client.patch(
- "api/v1/dags?dag_id_pattern=~",
- json={
- "is_paused": False,
- },
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- assert len(response.json["dags"]) == 1
- assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
deleted file mode 100644
index 52c8bbca185fa..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from datetime import timedelta
-
-import pytest
-
-from airflow.models.dag import DAG, DagModel
-from airflow.models.dagrun import DagRun
-from airflow.models.param import Param
-from airflow.security import permissions
-from airflow.utils import timezone
-from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import (
- create_user,
- delete_roles,
- delete_user,
-)
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-try:
- from airflow.utils.types import DagRunTriggeredByType, DagRunType
-except ImportError:
- if AIRFLOW_V_3_0_PLUS:
- raise
- else:
- pass
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
-
- create_user(
- app,
- username="test_no_dag_run_create_permission",
- role_name="TestNoDagRunCreatePermission",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app,
- username="test_dag_view_only",
- role_name="TestViewDags",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app,
- username="test_view_dags",
- role_name="TestViewDags",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app,
- username="test_granular_permissions",
- role_name="TestGranularDag",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)],
- )
- app.appbuilder.sm.sync_perm_for_dag(
- "TEST_DAG_ID",
- access_control={
- "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ},
- "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}},
- },
- )
-
- yield app
-
- delete_user(app, username="test_dag_view_only")
- delete_user(app, username="test_view_dags")
- delete_user(app, username="test_granular_permissions")
- delete_user(app, username="test_no_dag_run_create_permission")
- delete_roles(app)
-
-
-class TestDagRunEndpoint:
- default_time = "2020-06-11T18:00:00+00:00"
- default_time_2 = "2020-06-12T18:00:00+00:00"
- default_time_3 = "2020-06-13T18:00:00+00:00"
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- clear_db_runs()
- clear_db_serialized_dags()
- clear_db_dags()
-
- def teardown_method(self) -> None:
- clear_db_runs()
- clear_db_dags()
- clear_db_serialized_dags()
-
- def _create_dag(self, dag_id):
- dag_instance = DagModel(dag_id=dag_id)
- dag_instance.is_active = True
- with create_session() as session:
- session.add(dag_instance)
- dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)})
- self.app.dag_bag.bag_dag(dag)
- self.app.dag_bag.sync_to_db()
- return dag_instance
-
- def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1):
- dag_runs = []
- dags = []
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-
- for i in range(idx_start, idx_start + 2):
- if i == 1:
- dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True))
- dagrun_model = DagRun(
- dag_id="TEST_DAG_ID",
- run_id=f"TEST_DAG_RUN_ID_{i}",
- run_type=DagRunType.MANUAL,
- logical_date=timezone.parse(self.default_time) + timedelta(days=i - 1),
- start_date=timezone.parse(self.default_time),
- external_trigger=True,
- state=state,
- **triggered_by_kwargs,
- )
- dagrun_model.updated_at = timezone.parse(self.default_time)
- dag_runs.append(dagrun_model)
-
- if extra_dag:
- for i in range(idx_start + 2, idx_start + 4):
- dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}"))
- dag_runs.append(
- DagRun(
- dag_id=f"TEST_DAG_ID_{i}",
- run_id=f"TEST_DAG_RUN_ID_{i}",
- run_type=DagRunType.MANUAL,
- logical_date=timezone.parse(self.default_time_2),
- start_date=timezone.parse(self.default_time),
- external_trigger=True,
- state=state,
- )
- )
- if commit:
- with create_session() as session:
- session.add_all(dag_runs)
- session.add_all(dags)
- return dag_runs
-
-
-class TestGetDagRuns(TestDagRunEndpoint):
- def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
- self._create_test_dag_run(extra_dag=True)
- expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"]
- response = self.client.get(
- "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
- dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]]
- assert dag_run_ids == expected_dag_run_ids
-
-
-class TestGetDagRunBatch(TestDagRunEndpoint):
- def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
- self._create_test_dag_run(extra_dag=True)
- expected_response_json_1 = {
- "dag_id": "TEST_DAG_ID",
- "dag_run_id": "TEST_DAG_RUN_ID_1",
- "end_date": None,
- "state": "running",
- "logical_date": self.default_time,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {},
- "data_interval_end": None,
- "data_interval_start": None,
- "last_scheduling_decision": None,
- "run_type": "manual",
- "note": None,
- }
- expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
- expected_response_json_2 = {
- "dag_id": "TEST_DAG_ID",
- "dag_run_id": "TEST_DAG_RUN_ID_2",
- "end_date": None,
- "state": "running",
- "logical_date": self.default_time_2,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {},
- "data_interval_end": None,
- "data_interval_start": None,
- "last_scheduling_decision": None,
- "run_type": "manual",
- "note": None,
- }
- expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
-
- response = self.client.post(
- "api/v1/dags/~/dagRuns/list",
- json={"dag_ids": []},
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- assert response.json == {
- "dag_runs": [
- expected_response_json_1,
- expected_response_json_2,
- ],
- "total_entries": 2,
- }
-
-
-class TestPostDagRun(TestDagRunEndpoint):
- def test_dagrun_trigger_with_dag_level_permissions(self):
- self._create_dag("TEST_DAG_ID")
- response = self.client.post(
- "api/v1/dags/TEST_DAG_ID/dagRuns",
- json={"conf": {"validated_number": 1}},
- environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"},
- )
- assert response.status_code == 200
-
- @pytest.mark.parametrize(
- "username",
- ["test_dag_view_only", "test_view_dags", "test_granular_permissions"],
- )
- def test_should_raises_403_unauthorized(self, username):
- self._create_dag("TEST_DAG_ID")
- response = self.client.post(
- "api/v1/dags/TEST_DAG_ID/dagRuns",
- json={
- "dag_run_id": "TEST_DAG_RUN_ID_1",
- "logical_date": self.default_time,
- },
- environ_overrides={"REMOTE_USER": username},
- )
- assert response.status_code == 403
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py
deleted file mode 100644
index 39fd6ed4445b7..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import ast
-import os
-from typing import TYPE_CHECKING
-
-import pytest
-
-from airflow.models import DagBag
-from airflow.security import permissions
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-if TYPE_CHECKING:
- from airflow.models.dag import DAG
-
-ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
-EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py")
-EXAMPLE_DAG_ID = "example_bash_operator"
-TEST_DAG_ID = "latest_only"
-NOT_READABLE_DAG_ID = "latest_only_with_trigger"
-TEST_MULTIPLE_DAGS_ID = "asset_produces_1"
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app,
- username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)],
- )
- app.appbuilder.sm.sync_perm_for_dag(
- TEST_DAG_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag(
- EXAMPLE_DAG_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag(
- TEST_MULTIPLE_DAGS_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
-
- yield app
-
- delete_user(app, username="test")
-
-
-class TestGetSource:
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- self.clear_db()
-
- def teardown_method(self) -> None:
- self.clear_db()
-
- @staticmethod
- def clear_db():
- clear_db_dags()
- clear_db_serialized_dags()
- clear_db_dag_code()
-
- @staticmethod
- def _get_dag_file_docstring(fileloc: str) -> str | None:
- with open(fileloc) as f:
- file_contents = f.read()
- module = ast.parse(file_contents)
- docstring = ast.get_docstring(module)
- return docstring
-
- def test_should_respond_403_not_readable(self, url_safe_serializer):
- dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
- dagbag.sync_to_db()
- dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID]
-
- response = self.client.get(
- f"/api/v1/dagSources/{dag.dag_id}",
- headers={"Accept": "text/plain"},
- environ_overrides={"REMOTE_USER": "test"},
- )
- read_dag = self.client.get(
- f"/api/v1/dags/{NOT_READABLE_DAG_ID}",
- environ_overrides={"REMOTE_USER": "test"},
- )
- assert response.status_code == 403
- assert read_dag.status_code == 403
-
- def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer):
- dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
- dagbag.sync_to_db()
- dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID]
-
- response = self.client.get(
- f"/api/v1/dagSources/{dag.dag_id}",
- headers={"Accept": "text/plain"},
- environ_overrides={"REMOTE_USER": "test"},
- )
-
- read_dag = self.client.get(
- f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}",
- environ_overrides={"REMOTE_USER": "test"},
- )
- assert response.status_code == 403
- assert read_dag.status_code == 200
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py
deleted file mode 100644
index e06146a988fe1..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.models.dag import DagModel
-from airflow.models.dagwarning import DagWarning
-from airflow.security import permissions
-from airflow.utils.session import create_session
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_dag_warnings, clear_db_dags
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app, # type:ignore
- username="test_with_dag2_read",
- role_name="TestWithDag2Read",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
- (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"),
- ],
- )
-
- yield app
-
- delete_user(app, username="test_with_dag2_read")
-
-
-class TestBaseDagWarning:
- timestamp = "2020-06-10T12:00"
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
-
- def teardown_method(self) -> None:
- clear_db_dag_warnings()
- clear_db_dags()
-
-
-class TestGetDagWarningEndpoint(TestBaseDagWarning):
- def setup_class(self):
- clear_db_dag_warnings()
- clear_db_dags()
-
- def setup_method(self):
- with create_session() as session:
- session.add(DagModel(dag_id="dag1"))
- session.add(DagWarning("dag1", "non-existent pool", "test message"))
- session.commit()
-
- def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self):
- response = self.client.get(
- "/api/v1/dagWarnings",
- environ_overrides={"REMOTE_USER": "test_with_dag2_read"},
- query_string={"dag_id": "dag1"},
- )
- assert response.status_code == 403
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py
deleted file mode 100644
index f5935dcd93c1e..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_event_log_endpoint.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.models import Log
-from airflow.security import permissions
-from airflow.utils import timezone
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_logs
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app,
- username="test_granular",
- role_name="TestGranular",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)],
- )
- app.appbuilder.sm.sync_perm_for_dag(
- "TEST_DAG_ID_1",
- access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag(
- "TEST_DAG_ID_2",
- access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
- )
-
- yield app
-
- delete_user(app, username="test_granular")
-
-
-@pytest.fixture
-def task_instance(session, create_task_instance, request):
- return create_task_instance(
- session=session,
- dag_id="TEST_DAG_ID",
- task_id="TEST_TASK_ID",
- run_id="TEST_RUN_ID",
- logical_date=request.instance.default_time,
- )
-
-
-@pytest.fixture
-def create_log_model(create_task_instance, task_instance, session, request):
- def maker(event, when, **kwargs):
- log_model = Log(
- event=event,
- task_instance=task_instance,
- **kwargs,
- )
- log_model.dttm = when
-
- session.add(log_model)
- session.flush()
- return log_model
-
- return maker
-
-
-class TestEventLogEndpoint:
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- clear_db_logs()
- self.default_time = timezone.parse("2020-06-10T20:00:00+00:00")
- self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00")
-
- def teardown_method(self) -> None:
- clear_db_logs()
-
-
-class TestGetEventLogs(TestEventLogEndpoint):
- def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session):
- eventlog1 = create_log_model(
- event="TEST_EVENT_1",
- dag_id="TEST_DAG_ID_1",
- task_id="TEST_TASK_ID_1",
- owner="TEST_OWNER_1",
- when=self.default_time,
- )
- eventlog2 = create_log_model(
- event="TEST_EVENT_2",
- dag_id="TEST_DAG_ID_2",
- task_id="TEST_TASK_ID_2",
- owner="TEST_OWNER_2",
- when=self.default_time_2,
- )
- session.add_all([eventlog1, eventlog2])
- session.commit()
- for attr in ["dag_id", "task_id", "owner", "event"]:
- attr_value = f"TEST_{attr}_1".upper()
- response = self.client.get(
- f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"}
- )
- assert response.status_code == 200
- assert response.json["total_entries"] == 1
- assert len(response.json["event_logs"]) == 1
- assert response.json["event_logs"][0][attr] == attr_value
-
- def test_should_filter_eventlogs_by_included_events(self, create_log_model):
- for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
- create_log_model(event=event, when=self.default_time)
- response = self.client.get(
- "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2",
- environ_overrides={"REMOTE_USER": "test_granular"},
- )
- assert response.status_code == 200
- response_data = response.json
- assert len(response_data["event_logs"]) == 2
- assert response_data["total_entries"] == 2
- assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]}
-
- def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
- for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
- create_log_model(event=event, when=self.default_time)
- response = self.client.get(
- "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2",
- environ_overrides={"REMOTE_USER": "test_granular"},
- )
- assert response.status_code == 200
- response_data = response.json
- assert len(response_data["event_logs"]) == 1
- assert response_data["total_entries"] == 1
- assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]}
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py
deleted file mode 100644
index 84b3cb8ed347d..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_import_error_endpoint.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.models.dag import DagModel
-from airflow.security import permissions
-from airflow.utils import timezone
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.compat import ParseImportError
-from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors
-from tests_common.test_utils.permissions import _resource_name
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-TEST_DAG_IDS = ["test_dag", "test_dag2"]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app,
- username="test_single_dag",
- role_name="TestSingleDAG",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)],
- )
- # For some reason, DAG level permissions are not synced when in the above list of perms,
- # so do it manually here:
- app.appbuilder.sm.bulk_sync_roles(
- [
- {
- "role": "TestSingleDAG",
- "perms": [
- (
- permissions.ACTION_CAN_READ,
- _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG),
- )
- ],
- }
- ]
- )
-
- yield app
-
- delete_user(app, username="test_single_dag")
-
-
-class TestBaseImportError:
- timestamp = "2020-06-10T12:00"
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
-
- clear_db_import_errors()
- clear_db_dags()
-
- def teardown_method(self) -> None:
- clear_db_import_errors()
- clear_db_dags()
-
- @staticmethod
- def _normalize_import_errors(import_errors):
- for i, import_error in enumerate(import_errors, 1):
- import_error["import_error_id"] = i
-
-
-class TestGetImportErrorEndpoint(TestBaseImportError):
- def test_should_raise_403_forbidden_without_dag_read(self, session):
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 403
-
- def test_should_return_200_with_single_dag_read(self, session):
- dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
- session.add(dag_model)
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- response_data["import_error_id"] = 1
- assert response_data == {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- }
-
- def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
- for dag_id in TEST_DAG_IDS:
- dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
- session.add(dag_model)
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- response_data["import_error_id"] = 1
- assert response_data == {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
- "timestamp": "2020-06-10T12:00:00+00:00",
- }
-
-
-class TestGetImportErrorsEndpoint(TestBaseImportError):
- def test_get_import_errors_single_dag(self, session):
- for dag_id in TEST_DAG_IDS:
- fake_filename = f"/tmp/{dag_id}.py"
- dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
- session.add(dag_model)
- importerror = ParseImportError(
- filename=fake_filename,
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(importerror)
- session.commit()
-
- response = self.client.get(
- "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- self._normalize_import_errors(response_data["import_errors"])
- assert response_data == {
- "import_errors": [
- {
- "filename": "/tmp/test_dag.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- ],
- "total_entries": 1,
- }
-
- def test_get_import_errors_single_dag_in_dagfile(self, session):
- for dag_id in TEST_DAG_IDS:
- fake_filename = "/tmp/all_in_one.py"
- dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
- session.add(dag_model)
-
- importerror = ParseImportError(
- filename="/tmp/all_in_one.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(importerror)
- session.commit()
-
- response = self.client.get(
- "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- self._normalize_import_errors(response_data["import_errors"])
- assert response_data == {
- "import_errors": [
- {
- "filename": "/tmp/all_in_one.py",
- "import_error_id": 1,
- "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- ],
- "total_entries": 1,
- }
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py
deleted file mode 100644
index 5f755109a0644..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py
+++ /dev/null
@@ -1,426 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import datetime as dt
-import urllib
-
-import pytest
-
-from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models import DagRun, TaskInstance
-from airflow.security import permissions
-from airflow.utils.session import provide_session
-from airflow.utils.state import State
-from airflow.utils.timezone import datetime
-from airflow.utils.types import DagRunType
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import (
- create_user,
- delete_roles,
- delete_user,
-)
-from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-DEFAULT_DATETIME_1 = datetime(2020, 1, 1)
-DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00"
-DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00"
-
-QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1)
-QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2)
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
- create_user(
- app,
- username="test_dag_read_only",
- role_name="TestDagReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- create_user(
- app,
- username="test_task_read_only",
- role_name="TestTaskReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- create_user(
- app,
- username="test_read_only_one_dag",
- role_name="TestReadOnlyOneDag",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms,
- # so do it manually here:
- app.appbuilder.sm.bulk_sync_roles(
- [
- {
- "role": "TestReadOnlyOneDag",
- "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")],
- }
- ]
- )
-
- yield app
-
- delete_user(app, username="test_dag_read_only")
- delete_user(app, username="test_task_read_only")
- delete_user(app, username="test_read_only_one_dag")
- delete_roles(app)
-
-
-class TestTaskInstanceEndpoint:
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app, dagbag) -> None:
- self.default_time = DEFAULT_DATETIME_1
- self.ti_init = {
- "logical_date": self.default_time,
- "state": State.RUNNING,
- }
- self.ti_extras = {
- "start_date": self.default_time + dt.timedelta(days=1),
- "end_date": self.default_time + dt.timedelta(days=2),
- "pid": 100,
- "duration": 10000,
- "pool": "default_pool",
- "queue": "default_queue",
- "job_id": 0,
- }
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- clear_db_runs()
- clear_rendered_ti_fields()
- self.dagbag = dagbag
-
- def create_task_instances(
- self,
- session,
- dag_id: str = "example_python_operator",
- update_extras: bool = True,
- task_instances=None,
- dag_run_state=State.RUNNING,
- with_ti_history=False,
- ):
- """Method to create task instances using kwargs and default arguments"""
-
- dag = self.dagbag.get_dag(dag_id)
- tasks = dag.tasks
- counter = len(tasks)
- if task_instances is not None:
- counter = min(len(task_instances), counter)
-
- run_id = "TEST_DAG_RUN_ID"
- logical_date = self.ti_init.pop("logical_date", self.default_time)
- dr = None
-
- tis = []
- for i in range(counter):
- if task_instances is None:
- pass
- elif update_extras:
- self.ti_extras.update(task_instances[i])
- else:
- self.ti_init.update(task_instances[i])
-
- if "logical_date" in self.ti_init:
- run_id = f"TEST_DAG_RUN_ID_{i}"
- logical_date = self.ti_init.pop("logical_date")
- dr = None
-
- if not dr:
- dr = DagRun(
- run_id=run_id,
- dag_id=dag_id,
- logical_date=logical_date,
- run_type=DagRunType.MANUAL,
- state=dag_run_state,
- )
- session.add(dr)
- ti = TaskInstance(task=tasks[i], **self.ti_init)
- session.add(ti)
- ti.dag_run = dr
- ti.note = "placeholder-note"
-
- for key, value in self.ti_extras.items():
- setattr(ti, key, value)
- tis.append(ti)
-
- session.commit()
- if with_ti_history:
- for ti in tis:
- ti.try_number = 1
- session.merge(ti)
- session.commit()
- dag.clear()
- for ti in tis:
- ti.try_number = 2
- ti.queue = "default_queue"
- session.merge(ti)
- session.commit()
- return tis
-
-
-class TestGetTaskInstance(TestTaskInstanceEndpoint):
- def setup_method(self):
- clear_db_runs()
-
- def teardown_method(self):
- clear_db_runs()
-
- @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
- @provide_session
- def test_should_respond_200(self, username, session):
- self.create_task_instances(session)
- # Update ti and set operator to None to
- # test that operator field is nullable.
-        # This prevents issues when users upgrade to 2.0+
-        # from 1.10.x
- # https://github.com/apache/airflow/issues/14421
- session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch")
- session.commit()
- response = self.client.get(
- "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
- environ_overrides={"REMOTE_USER": username},
- )
- assert response.status_code == 200
-
-
-class TestGetTaskInstances(TestTaskInstanceEndpoint):
- @pytest.mark.parametrize(
- "task_instances, user, expected_ti",
- [
- pytest.param(
- {
- "example_python_operator": 2,
- "example_skip_dag": 1,
- },
- "test_read_only_one_dag",
- 2,
- ),
- pytest.param(
- {
- "example_python_operator": 1,
- "example_skip_dag": 2,
- },
- "test_read_only_one_dag",
- 1,
- ),
- ],
- )
- def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session):
- for dag_id in task_instances:
- self.create_task_instances(
- session,
- task_instances=[
- {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)}
- for i in range(task_instances[dag_id])
- ],
- dag_id=dag_id,
- )
- response = self.client.get(
- "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user}
- )
- assert response.status_code == 200
- assert response.json["total_entries"] == expected_ti
- assert len(response.json["task_instances"]) == expected_ti
-
-
-class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
- @pytest.mark.parametrize(
- "task_instances, update_extras, payload, expected_ti_count, username",
- [
- pytest.param(
- [
- {"pool": "test_pool_1"},
- {"pool": "test_pool_2"},
- {"pool": "test_pool_3"},
- ],
- True,
- {"pool": ["test_pool_1", "test_pool_2"]},
- 2,
- "test_dag_read_only",
- id="test pool filter",
- ),
- pytest.param(
- [
- {"state": State.RUNNING},
- {"state": State.QUEUED},
- {"state": State.SUCCESS},
- {"state": State.NONE},
- ],
- False,
- {"state": ["running", "queued", "none"]},
- 3,
- "test_task_read_only",
- id="test state filter",
- ),
- pytest.param(
- [
- {"state": State.NONE},
- {"state": State.NONE},
- {"state": State.NONE},
- {"state": State.NONE},
- ],
- False,
- {},
- 4,
- "test_task_read_only",
- id="test dag with null states",
- ),
- pytest.param(
- [
- {"end_date": DEFAULT_DATETIME_1},
- {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
- {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
- ],
- True,
- {
- "end_date_gte": DEFAULT_DATETIME_STR_1,
- "end_date_lte": DEFAULT_DATETIME_STR_2,
- },
- 2,
- "test_task_read_only",
- id="test end date filter",
- ),
- pytest.param(
- [
- {"start_date": DEFAULT_DATETIME_1},
- {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
- {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
- ],
- True,
- {
- "start_date_gte": DEFAULT_DATETIME_STR_1,
- "start_date_lte": DEFAULT_DATETIME_STR_2,
- },
- 2,
- "test_dag_read_only",
- id="test start date filter",
- ),
- ],
- )
- def test_should_respond_200(
- self, task_instances, update_extras, payload, expected_ti_count, username, session
- ):
- self.create_task_instances(
- session,
- update_extras=update_extras,
- task_instances=task_instances,
- )
- response = self.client.post(
- "/api/v1/dags/~/dagRuns/~/taskInstances/list",
- environ_overrides={"REMOTE_USER": username},
- json=payload,
- )
- assert response.status_code == 200, response.json
- assert expected_ti_count == response.json["total_entries"]
- assert expected_ti_count == len(response.json["task_instances"])
-
- def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session):
- self.create_task_instances(session=session)
- self.create_task_instances(session=session, dag_id="example_skip_dag")
- payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]}
-
- response = self.client.post(
- "/api/v1/dags/~/dagRuns/~/taskInstances/list",
- environ_overrides={"REMOTE_USER": "test_read_only_one_dag"},
- json=payload,
- )
- assert response.status_code == 403
- assert response.json == {
- "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']",
- "status": 403,
- "title": "Forbidden",
- "type": EXCEPTIONS_LINK_MAP[403],
- }
-
-
-class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
- @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
- def test_should_raise_403_forbidden(self, username):
- response = self.client.post(
- "/api/v1/dags/example_python_operator/updateTaskInstancesState",
- environ_overrides={"REMOTE_USER": username},
- json={
- "dry_run": True,
- "task_id": "print_the_context",
- "logical_date": DEFAULT_DATETIME_1.isoformat(),
- "include_upstream": True,
- "include_downstream": True,
- "include_future": True,
- "include_past": True,
- "new_state": "failed",
- },
- )
- assert response.status_code == 403
-
-
-class TestPatchTaskInstance(TestTaskInstanceEndpoint):
- ENDPOINT_URL = (
- "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
- )
-
- @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
- def test_should_raise_403_forbidden(self, username):
- response = self.client.patch(
- self.ENDPOINT_URL,
- environ_overrides={"REMOTE_USER": username},
- json={
- "dry_run": True,
- "new_state": "failed",
- },
- )
- assert response.status_code == 403
-
-
-class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
- def setup_method(self):
- clear_db_runs()
-
- def teardown_method(self):
- clear_db_runs()
-
- @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
- @provide_session
- def test_should_respond_200(self, username, session):
- self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True)
-
- response = self.client.get(
- "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1",
- environ_overrides={"REMOTE_USER": username},
- )
- assert response.status_code == 200
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py
deleted file mode 100644
index 954a2de130ddb..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_variable_endpoint.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from airflow.models import Variable
-from airflow.security import permissions
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_variables
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
-
- create_user(
- app,
- username="test_read_only",
- role_name="TestReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
- ],
- )
- create_user(
- app,
- username="test_delete_only",
- role_name="TestDeleteOnly",
- permissions=[
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE),
- ],
- )
-
- yield app
-
- delete_user(app, username="test_read_only")
- delete_user(app, username="test_delete_only")
-
-
-class TestVariableEndpoint:
- @pytest.fixture(autouse=True)
- def setup_method(self, configured_app) -> None:
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- clear_db_variables()
-
- def teardown_method(self) -> None:
- clear_db_variables()
-
-
-class TestGetVariable(TestVariableEndpoint):
- @pytest.mark.parametrize(
- "user, expected_status_code",
- [
- ("test_read_only", 200),
- ("test_delete_only", 403),
- ],
- )
- def test_read_variable(self, user, expected_status_code):
- expected_value = '{"foo": 1}'
- Variable.set("TEST_VARIABLE_KEY", expected_value)
- response = self.client.get(
- "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user}
- )
- assert response.status_code == expected_status_code
- if expected_status_code == 200:
- assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None}
diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py
deleted file mode 100644
index fb46f52a402ed..0000000000000
--- a/providers/tests/fab/auth_manager/api_endpoints/test_xcom_endpoint.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from datetime import timedelta
-
-import pytest
-
-from airflow.models.dag import DagModel
-from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.xcom import BaseXCom, XCom
-from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
-from airflow.utils import timezone
-from airflow.utils.session import create_session
-from airflow.utils.types import DagRunType
-
-from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-pytestmark = [
- pytest.mark.db_test,
- pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
-]
-
-
-class CustomXCom(BaseXCom):
- @classmethod
- def deserialize_value(cls, xcom: XCom):
- return f"real deserialized {super().deserialize_value(xcom)}"
-
- def orm_deserialize_value(self):
- return f"orm deserialized {super().orm_deserialize_value()}"
-
-
-@pytest.fixture(scope="module")
-def configured_app(minimal_app_for_auth_api):
- app = minimal_app_for_auth_api
-
- create_user(
- app,
- username="test_granular_permissions",
- role_name="TestGranularDag",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
- ],
- )
- app.appbuilder.sm.sync_perm_for_dag(
- "test-dag-id-1",
- access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
- )
-
- yield app
-
- delete_user(app, username="test_granular_permissions")
-
-
-def _compare_xcom_collections(collection1: dict, collection_2: dict):
- assert collection1.get("total_entries") == collection_2.get("total_entries")
-
- def sort_key(record):
- return (
- (
- record.get("dag_id"),
- record.get("task_id"),
- record.get("logical_date"),
- record.get("map_index"),
- record.get("key"),
- )
- if AIRFLOW_V_3_0_PLUS
- else (
- record.get("dag_id"),
- record.get("task_id"),
- record.get("execution_date"),
- record.get("map_index"),
- record.get("key"),
- )
- )
-
- assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted(
- collection_2.get("xcom_entries", []), key=sort_key
- )
-
-
-class TestXComEndpoint:
- @staticmethod
- def clean_db():
- clear_db_dags()
- clear_db_runs()
- clear_db_xcom()
-
- @pytest.fixture(autouse=True)
- def setup_attrs(self, configured_app) -> None:
- """
-        Setup for XCom endpoint test cases
- """
- self.app = configured_app
- self.client = self.app.test_client() # type:ignore
- # clear existing xcoms
- self.clean_db()
-
- def teardown_method(self) -> None:
- """
- Clear Hanging XComs
- """
- self.clean_db()
-
-
-class TestGetXComEntries(TestXComEndpoint):
- def test_should_respond_200_with_tilde_and_granular_dag_access(self):
- dag_id_1 = "test-dag-id-1"
- task_id_1 = "test-task-id-1"
- logical_date = "2005-04-02T00:00:00+00:00"
- logical_date_parsed = timezone.parse(logical_date)
- dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, logical_date_parsed)
- self._create_xcom_entries(dag_id_1, dag_run_id_1, logical_date_parsed, task_id_1)
-
- dag_id_2 = "test-dag-id-2"
- task_id_2 = "test-task-id-2"
- run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, logical_date_parsed)
- self._create_xcom_entries(dag_id_2, run_id_2, logical_date_parsed, task_id_2)
- self._create_invalid_xcom_entries(logical_date_parsed)
- response = self.client.get(
- "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
-
- assert response.status_code == 200
- response_data = response.json
- for xcom_entry in response_data["xcom_entries"]:
- xcom_entry["timestamp"] = "TIMESTAMP"
- date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
- _compare_xcom_collections(
- response_data,
- {
- "xcom_entries": [
- {
- "dag_id": dag_id_1,
- date_key: logical_date,
- "key": "test-xcom-key-1",
- "task_id": task_id_1,
- "timestamp": "TIMESTAMP",
- "map_index": -1,
- },
- {
- "dag_id": dag_id_1,
- date_key: logical_date,
- "key": "test-xcom-key-2",
- "task_id": task_id_1,
- "timestamp": "TIMESTAMP",
- "map_index": -1,
- },
- ],
- "total_entries": 2,
- },
- )
-
- def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=False):
- with create_session() as session:
- dag = DagModel(dag_id=dag_id)
- session.add(dag)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = DagRun(
- dag_id=dag_id,
- run_id=run_id,
- logical_date=logical_date,
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- else:
- dagrun = DagRun(
- dag_id=dag_id,
- run_id=run_id,
- execution_date=logical_date,
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- session.add(dagrun)
- if mapped_ti:
- for i in [0, 1]:
- ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i)
- ti.dag_id = dag_id
- session.add(ti)
- else:
- ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
- ti.dag_id = dag_id
- session.add(ti)
-
- for i in [1, 2]:
- if mapped_ti:
- key = "test-xcom-key"
- map_index = i - 1
- else:
- key = f"test-xcom-key-{i}"
- map_index = -1
-
- XCom.set(
- key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index
- )
-
- def _create_invalid_xcom_entries(self, logical_date):
- """
-        Create invalid XCom entries to test the join query
- """
- with create_session() as session:
- dag = DagModel(dag_id="invalid_dag")
- session.add(dag)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = DagRun(
- dag_id="invalid_dag",
- run_id="invalid_run_id",
- logical_date=logical_date + timedelta(days=1),
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- else:
- dagrun = DagRun(
- dag_id="invalid_dag",
- run_id="invalid_run_id",
- execution_date=logical_date + timedelta(days=1),
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- session.add(dagrun)
- if AIRFLOW_V_3_0_PLUS:
- dagrun1 = DagRun(
- dag_id="invalid_dag",
- run_id="not_this_run_id",
- logical_date=logical_date,
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- else:
- dagrun1 = DagRun(
- dag_id="invalid_dag",
- run_id="not_this_run_id",
- execution_date=logical_date,
- start_date=logical_date,
- run_type=DagRunType.MANUAL,
- )
- session.add(dagrun1)
- ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
- ti.dag_id = "invalid_dag"
- session.add(ti)
- for i in [1, 2]:
- XCom.set(
- key=f"invalid-xcom-key-{i}",
- value="TEST",
- run_id="not_this_run_id",
- task_id="invalid_task",
- dag_id="invalid_dag",
- )
diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py
index d400a7b86a027..a301806a50500 100644
--- a/providers/tests/fab/auth_manager/conftest.py
+++ b/providers/tests/fab/auth_manager/conftest.py
@@ -18,7 +18,7 @@
import pytest
-from airflow.www import app
+from airflow.providers.fab.www import app
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules
diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
index c6c53371223fd..048207a680017 100644
--- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
@@ -27,6 +27,7 @@
from flask_appbuilder.menu import Menu
from airflow.exceptions import AirflowConfigException, AirflowException
+from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user
@@ -62,7 +63,6 @@
RESOURCE_VARIABLE,
RESOURCE_WEBSITE,
)
-from airflow.www.extensions.init_appbuilder import init_appbuilder
if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
diff --git a/providers/tests/fab/auth_manager/test_security.py b/providers/tests/fab/auth_manager/test_security.py
index 67dd4179b09a5..e718589060885 100644
--- a/providers/tests/fab/auth_manager/test_security.py
+++ b/providers/tests/fab/auth_manager/test_security.py
@@ -36,6 +36,8 @@
from airflow.exceptions import AirflowException
from airflow.models import DagModel
from airflow.models.dag import DAG
+from airflow.providers.fab.www.auth import get_access_denied_message, has_access_dag
+from airflow.providers.fab.www.utils import CustomSQLAInterface
from tests_common.test_utils.compat import ignore_provider_compatibility_error
@@ -45,11 +47,9 @@
from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser
from airflow.api_fastapi.app import get_auth_manager
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
from airflow.security.permissions import ACTION_CAN_READ
-from airflow.www import app as application
-from airflow.www.auth import get_access_denied_message
-from airflow.www.utils import CustomSQLAInterface
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import (
create_user,
@@ -1162,8 +1162,6 @@ def test_dag_id_consistency(
fail: bool,
):
with app.test_request_context() as mock_context:
- from airflow.www.auth import has_access_dag
-
mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {}
kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {}
mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {}
diff --git a/providers/tests/fab/auth_manager/views/test_permissions.py b/providers/tests/fab/auth_manager/views/test_permissions.py
index 1ef8f8d552131..4cb516e42942f 100644
--- a/providers/tests/fab/auth_manager/views/test_permissions.py
+++ b/providers/tests/fab/auth_manager/views/test_permissions.py
@@ -19,8 +19,8 @@
import pytest
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
-from airflow.www import app as application
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning
diff --git a/providers/tests/fab/auth_manager/views/test_roles_list.py b/providers/tests/fab/auth_manager/views/test_roles_list.py
index dd7429339f4f1..a50f883ead0b7 100644
--- a/providers/tests/fab/auth_manager/views/test_roles_list.py
+++ b/providers/tests/fab/auth_manager/views/test_roles_list.py
@@ -19,8 +19,8 @@
import pytest
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
-from airflow.www import app as application
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning
diff --git a/providers/tests/fab/auth_manager/views/test_user.py b/providers/tests/fab/auth_manager/views/test_user.py
index 3db1bd9e463c3..5ed8224bac82e 100644
--- a/providers/tests/fab/auth_manager/views/test_user.py
+++ b/providers/tests/fab/auth_manager/views/test_user.py
@@ -19,8 +19,8 @@
import pytest
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
-from airflow.www import app as application
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning
diff --git a/providers/tests/fab/auth_manager/views/test_user_edit.py b/providers/tests/fab/auth_manager/views/test_user_edit.py
index c28d11e286ba4..e616d435f26b3 100644
--- a/providers/tests/fab/auth_manager/views/test_user_edit.py
+++ b/providers/tests/fab/auth_manager/views/test_user_edit.py
@@ -19,8 +19,8 @@
import pytest
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
-from airflow.www import app as application
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning
diff --git a/providers/tests/fab/auth_manager/views/test_user_stats.py b/providers/tests/fab/auth_manager/views/test_user_stats.py
index 8a4fb820635e1..840d358a5aada 100644
--- a/providers/tests/fab/auth_manager/views/test_user_stats.py
+++ b/providers/tests/fab/auth_manager/views/test_user_stats.py
@@ -19,8 +19,8 @@
import pytest
+from airflow.providers.fab.www import app as application
from airflow.security import permissions
-from airflow.www import app as application
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning