from __future__ import annotations

import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Tuple, cast

from ..datastructures import Headers
from ..exceptions import InvalidHeader
from ..headers import build_www_authenticate_basic, parse_authorization_basic
from .server import HTTPResponse, WebSocketServerProtocol


# Public names exported by this module.
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]

# A ``(username, password)`` pair of strings.
# Change to tuple[str, str] when dropping Python < 3.9.
Credentials = Tuple[str, str]


def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` is a ``(username, password)`` pair of strings.

    Args:
        value: Arbitrary object to inspect.

    Returns:
        :obj:`True` if ``value`` is a re-iterable container that unpacks into
        exactly two :class:`str` items; :obj:`False` otherwise.

    """
    # A two-character string unpacks into two one-character strings; it must
    # not be mistaken for a (username, password) pair.
    if isinstance(value, str):
        return False
    try:
        # A one-shot iterator returns itself from iter(). Unpacking it would
        # consume items that the caller can never read again, so reject it
        # without touching its content.
        if iter(value) is value:
            return False
        username, password = value
    except (TypeError, ValueError):
        # TypeError: not iterable; ValueError: wrong number of items.
        return False
    return isinstance(username, str) and isinstance(password, str)


class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that requires HTTP Basic Auth.

    Requests whose credentials are missing, cannot be parsed, or are rejected
    by :meth:`check_credentials` receive an HTTP 401 response.

    """

    realm: str = ""
    """
    Scope of protection.

    If provided, it should contain only ASCII characters because the
    encoding of non-ASCII characters is undefined.
    """

    username: str | None = None
    """Username of the authenticated user."""

    def __init__(
        self,
        *args: Any,
        realm: str | None = None,
        check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Store the Basic Auth settings, then initialize the base protocol.

        Args:
            realm: Scope of protection; when omitted, the class-level
                ``realm`` stays in effect.
            check_credentials: Optional coroutine that :meth:`check_credentials`
                delegates to.

        """
        # Only an explicit realm overrides the class-level default.
        if realm is not None:
            self.realm = realm
        self._check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def check_credentials(self, username: str, password: str) -> bool:
        """
        Decide whether the given credentials are authorized.

        Subclasses may override this coroutine, for example to authenticate
        against a database or an external service.

        Args:
            username: HTTP Basic Auth username.
            password: HTTP Basic Auth password.

        Returns:
            :obj:`True` if the handshake should continue;
            :obj:`False` if it should fail with an HTTP 401 error.

        """
        checker = self._check_credentials
        if checker is None:
            # No checker was configured: nobody is authorized.
            return False
        return await checker(username, password)

    async def process_request(
        self,
        path: str,
        request_headers: Headers,
    ) -> HTTPResponse | None:
        """
        Enforce HTTP Basic Auth, answering with HTTP 401 when it fails.

        On success, the username is recorded in :attr:`username` and
        processing is handed over to the parent class.

        """

        def unauthorized(body: bytes) -> HTTPResponse:
            # Built on demand so the challenge is only computed on failure.
            challenge = build_www_authenticate_basic(self.realm)
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", challenge)],
                body,
            )

        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return unauthorized(b"Missing credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return unauthorized(b"Unsupported credentials\n")

        authorized = await self.check_credentials(username, password)
        if not authorized:
            return unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)


def basic_auth_protocol_factory(
    realm: str | None = None,
    credentials: Credentials | Iterable[Credentials] | None = None,
    check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
    create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    :func:`basic_auth_protocol_factory` is designed to integrate with
    :func:`~websockets.legacy.server.serve` like this::

        serve(
            ...,
            create_protocol=basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    Args:
        realm: Scope of protection. It should contain only ASCII characters
            because the encoding of non-ASCII characters is undefined.
            Refer to section 2.2 of :rfc:`7235` for details.
        credentials: Hard coded authorized credentials. It can be a
            ``(username, password)`` pair or an iterable of such pairs.
            One-shot iterables such as generators are supported.
        check_credentials: Coroutine that verifies credentials.
            It receives ``username`` and ``password`` arguments
            and returns a :class:`bool`. One of ``credentials`` or
            ``check_credentials`` must be provided but not both.
        create_protocol: Factory that creates the protocol. By default, this
            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
            by a subclass.

    Returns:
        A callable that forwards its arguments to ``create_protocol`` along
        with the ``realm`` and ``check_credentials`` keyword arguments.

    Raises:
        TypeError: If the ``credentials`` or ``check_credentials`` argument is
            wrong.

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        # A one-shot iterator returns itself from iter(). Snapshot it before
        # probing its shape: the probe unpacks its argument, which would
        # otherwise consume pairs and silently drop them.
        if isinstance(credentials, Iterable) and iter(credentials) is credentials:
            credentials = list(credentials)

        if is_credentials(credentials):
            credentials_list = [cast(Credentials, credentials)]
        elif isinstance(credentials, Iterable):
            credentials_list = list(cast(Iterable[Credentials], credentials))
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        # Passwords are compared as bytes: hmac.compare_digest() raises
        # TypeError when given str values containing non-ASCII characters.
        # "surrogatepass" makes the encoding total, so it cannot fail for
        # any str; equal strings still map to equal bytes and vice versa.
        expected_passwords = {
            username: password.encode("utf-8", "surrogatepass")
            for username, password in credentials_list
        }

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = expected_passwords[username]
            except KeyError:
                return False
            return hmac.compare_digest(
                expected_password,
                password.encode("utf-8", "surrogatepass"),
            )

    if create_protocol is None:
        create_protocol = BasicAuthWebSocketServerProtocol

    # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
    # Callable[..., BasicAuthWebSocketServerProtocol]" not callable  [misc]
    create_protocol = cast(
        Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
    )
    return functools.partial(
        create_protocol,
        realm=realm,
        check_credentials=check_credentials,
    )
