# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for the open metrics view."""

from collections.abc import Callable, Iterable
from typing import assert_never, cast
from unittest import mock

import requests
from django.test import SimpleTestCase
from django.urls import reverse
from django_prometheus.testutils import assert_metric_diff, save_registry
from prometheus_client import CONTENT_TYPE_LATEST as PROMETHEUS_CONTENT_TYPE
from prometheus_client import (
    CollectorRegistry,
    Counter,
    Gauge,
    Histogram,
    Metric,
    Summary,
)
from prometheus_client.openmetrics.exposition import (
    CONTENT_TYPE_LATEST as OPENMETRICS_CONTENT_TYPE,
)
from prometheus_client.openmetrics.parser import (
    text_string_to_metric_families as openmetrics_parser,
)
from prometheus_client.parser import (
    text_string_to_metric_families as prometheus_parser,
)

from debusine.server.views.open_metrics import (
    OpenMetricsView,
    RendererFormats,
    extract_media_type,
)
from debusine.test.django import TestCase


class HelperTests(SimpleTestCase):
    """Tests for the standalone helper functions of the open metrics view."""

    def test_extract_media_type(self) -> None:
        # Everything from the first ";" onwards (the parameters) is dropped,
        # leaving only the bare media type.
        header_value = "text/plain; version=1; charset=utf-8"
        media_type = extract_media_type(header_value)
        self.assertEqual(media_type, "text/plain")


class OpenMetricsViewTests(TestCase):
    """Tests for OpenMetricsView."""

    def get_metrics(
        self, renderer_format: RendererFormats = RendererFormats.PROMETHEUS
    ) -> dict[str, Metric]:
        headers: dict[str, str]
        expected_content_type: str
        parser: Callable[[str], Iterable[Metric]]

        match renderer_format:
            case RendererFormats.PROMETHEUS:
                headers = {}
                expected_content_type = PROMETHEUS_CONTENT_TYPE
                parser = prometheus_parser
            case RendererFormats.OPENMETRICS:
                headers = {"Accept": "application/openmetrics-text"}
                expected_content_type = OPENMETRICS_CONTENT_TYPE
                parser = cast(
                    Callable[[str], Iterable[Metric]], openmetrics_parser
                )
            case _ as unreachable:
                assert_never(unreachable)

        response = self.client.get(reverse("api:open-metrics"), headers=headers)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.headers["Content-Type"], expected_content_type
        )
        text = response.content.decode()
        if renderer_format == RendererFormats.OPENMETRICS:
            self.assertEqual(text.splitlines().count("# EOF"), 1)
        return {metric.name: metric for metric in parser(text)}

    def test_output_headers_prometheus(self) -> None:
        self.get_metrics(RendererFormats.PROMETHEUS)

    def test_output_headers_openmetrics(self) -> None:
        self.get_metrics(RendererFormats.OPENMETRICS)

    def test_default_collector_metrics_disabled(self) -> None:
        metrics = self.get_metrics()

        self.assertNotIn("python_gc_objects_collected", metrics)
        self.assertNotIn("python_info", metrics)
        self.assertNotIn("process_virtual_memory_bytes", metrics)

    def test_default_collector_metrics_disabled_single_process(self) -> None:
        with mock.patch.object(
            OpenMetricsView,
            "_is_multiprocess",
            new_callable=mock.PropertyMock(return_value=False),
        ):
            metrics = self.get_metrics()

        self.assertNotIn("python_gc_objects_collected", metrics)
        self.assertNotIn("python_info", metrics)
        self.assertNotIn("process_virtual_memory_bytes", metrics)

    def test_non_database_metrics_registered(self) -> None:
        # Ensure that websocket metrics are registered.
        import debusine.server.consumers  # noqa: F401

        self.client.get(reverse("homepage:homepage"))

        metrics = self.get_metrics()
        responses = {
            (
                sample.labels["status"],
                sample.labels["view"],
                sample.labels["method"],
            ): sample.value
            for sample in metrics[
                "django_http_responses_total_by_status_view_method"
            ].samples
        }
        self.assertGreaterEqual(
            responses[("200", "homepage:homepage", "GET")], 1.0
        )
        self.assertIn("debusine_active_websocket_connections", metrics)


class OpenMetricsEmitViewTests(TestCase):
    """Tests for OpenMetricsEmitView."""

    def setUp(self) -> None:
        """Register one metric of each type in a private registry."""
        super().setUp()
        self.registry = CollectorRegistry()
        # Prefix names with the class name to keep them distinct.
        self.prefix = self.__class__.__name__
        self.counter = Counter(
            f"{self.prefix}_test_counter",
            "test_counter",
            labelnames=("foo",),
            registry=self.registry,
        )
        self.gauge = Gauge(
            f"{self.prefix}_test_gauge",
            "test_gauge",
            labelnames=("foo",),
            registry=self.registry,
        )
        self.summary = Summary(
            f"{self.prefix}_test_summary",
            "test_summary",
            labelnames=("foo",),
            registry=self.registry,
        )
        self.histogram = Histogram(
            f"{self.prefix}_test_histogram",
            "test_histogram",
            labelnames=("foo",),
            registry=self.registry,
        )
        # Replace debusine.db.metrics.REGISTRY with the private registry for
        # the duration of each test; presumably the emit view looks metrics
        # up there.
        self.enterContext(
            mock.patch("debusine.db.metrics.REGISTRY", new=self.registry)
        )
        # Snapshot taken before any request, used to compute diffs later.
        self.frozen_registry = save_registry(registry=self.registry)

    def assert_registry_unchanged(self) -> None:
        """Assert that no metric in the test registry changed since setUp."""
        self.assertEqual(
            save_registry(registry=self.registry), self.frozen_registry
        )

    def test_anonymous(self) -> None:
        response = self.client.post(
            reverse("api:open-metrics-emit"),
            data={
                "metric_type": "counter",
                "name": f"{self.prefix}_test_counter",
                "labels": {},
                "value": 1,
            },
            content_type="application/json",
        )
        self.assertResponseProblem(
            response,
            "Error",
            detail_pattern=r"Authentication credentials were not provided\.",
            status_code=requests.codes.forbidden,
        )
        self.assert_registry_unchanged()

    def test_user_auth(self) -> None:
        # Only workers may emit metrics; an ordinary user token is rejected.
        token = self.playground.create_user_token()
        response = self.client.post(
            reverse("api:open-metrics-emit"),
            data={
                "metric_type": "counter",
                "name": f"{self.prefix}_test_counter",
                "labels": {},
                "value": 1,
            },
            headers={"token": token.key},
            content_type="application/json",
        )
        self.assertResponseProblem(
            response,
            "Error",
            detail_pattern=(
                r"You do not have permission to perform this action\."
            ),
            status_code=requests.codes.forbidden,
        )
        self.assert_registry_unchanged()

    def test_metric_not_found(self) -> None:
        token = self.playground.create_worker_token()
        response = self.client.post(
            reverse("api:open-metrics-emit"),
            data={
                "metric_type": "counter",
                "name": "nonexistent",
                "labels": {},
                "value": 1,
            },
            headers={"token": token.key},
            content_type="application/json",
        )
        self.assertResponseProblem(
            response,
            "Metric not found",
            detail_pattern="Metric 'nonexistent' not found",
            status_code=requests.codes.not_found,
        )
        self.assert_registry_unchanged()

    def test_bad_metric_type(self) -> None:
        token = self.playground.create_worker_token()
        # Each requested metric_type is paired with a metric of another type.
        for metric_type, name in (
            ("counter", f"{self.prefix}_test_gauge"),
            ("gauge", f"{self.prefix}_test_counter"),
            ("summary", f"{self.prefix}_test_histogram"),
            ("histogram", f"{self.prefix}_test_summary"),
        ):
            with self.subTest(metric_type=metric_type, name=name):
                response = self.client.post(
                    reverse("api:open-metrics-emit"),
                    data={
                        "metric_type": metric_type,
                        "name": name,
                        "labels": {},
                        "value": 1,
                    },
                    headers={"token": token.key},
                    content_type="application/json",
                )
                self.assertResponseProblem(
                    response,
                    "Bad metric type",
                    detail_pattern=(
                        f"Metric '{name}' is not a {metric_type.capitalize()}"
                    ),
                )
                self.assert_registry_unchanged()

    def test_success(self) -> None:
        token = self.playground.create_worker_token()
        # Each case: (metric_type, name, value, expected sample diffs), where
        # each diff is (sample name, change, extra labels beyond foo="bar").
        for metric_type, name, value, expected_diffs in (
            (
                "counter",
                f"{self.prefix}_test_counter",
                1,
                ((f"{self.prefix}_test_counter_total", 1, {}),),
            ),
            (
                "gauge",
                f"{self.prefix}_test_gauge",
                2.5,
                ((f"{self.prefix}_test_gauge", 2.5, {}),),
            ),
            (
                "summary",
                f"{self.prefix}_test_summary",
                512,
                (
                    (f"{self.prefix}_test_summary_count", 1, {}),
                    (f"{self.prefix}_test_summary_sum", 512, {}),
                ),
            ),
            (
                "histogram",
                f"{self.prefix}_test_histogram",
                10.0,
                # 10.0 lands in the le="10.0" bucket but not the le="7.5" one.
                (
                    (f"{self.prefix}_test_histogram_bucket", 0, {"le": "7.5"}),
                    (f"{self.prefix}_test_histogram_bucket", 1, {"le": "10.0"}),
                    (f"{self.prefix}_test_histogram_count", 1, {}),
                    (f"{self.prefix}_test_histogram_sum", 10.0, {}),
                ),
            ),
        ):
            with self.subTest(metric_type=metric_type):
                response = self.client.post(
                    reverse("api:open-metrics-emit"),
                    data={
                        "metric_type": metric_type,
                        "name": name,
                        "labels": {"foo": "bar"},
                        "value": value,
                    },
                    headers={"token": token.key},
                    content_type="application/json",
                )
                self.assertEqual(
                    response.status_code, requests.codes.no_content
                )
                for (
                    expected_name,
                    expected_diff,
                    expected_labels,
                ) in expected_diffs:
                    with self.subTest(expected_name=expected_name):
                        assert_metric_diff(
                            self.frozen_registry,
                            expected_diff,
                            expected_name,
                            registry=self.registry,
                            foo="bar",
                            **expected_labels,
                        )
