from hypothesis import given, settings
from typing import Any, Dict

import hypothesis.strategies as st
import pytest

from chromadb.api import ServerAPI
from chromadb.config import System
from chromadb.test.conftest import _fastapi_fixture
from chromadb.test.auth.strategies import (
    random_token,
    random_token_transport_header,
    token_test_conf,
)


@settings(max_examples=10)
@given(token_test_conf(), random_token_transport_header(), st.booleans())
def test_fastapi_server_token_authn_allows_when_it_should_allow(
    tconf: Dict[str, Any], transport_header: str, persistence: bool
) -> None:
    """Check that every token listed in the generated config is accepted.

    For each token of each user in ``tconf["users"]`` a fresh fixture is
    created with that token as the client credential; ``heartbeat()`` must
    succeed and ``list_collections()`` must return an empty list after the
    state reset.

    Args:
        tconf: Generated configuration; this test reads ``tconf["users"]``
            (each entry providing a ``"tokens"`` iterable) and
            ``tconf["filename"]`` (passed as the server credentials file).
        transport_header: Value passed as the token transport header setting.
        persistence: Forwarded to the fixture as ``is_persistent``.
    """
    for user in tconf["users"]:
        for token in user["tokens"]:
            fixture = _fastapi_fixture(
                is_persistent=persistence,
                chroma_auth_token_transport_header=transport_header,
                chroma_server_authn_provider="chromadb.auth.token_authn.TokenAuthenticationServerProvider",
                chroma_server_authn_credentials_file=tconf["filename"],
                chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                chroma_client_auth_credentials=token,
            )
            # The fixture is a generator: the first next() yields the System.
            _sys: System = next(fixture)
            try:
                _sys.reset_state()
                server_api = _sys.instance(ServerAPI)
                server_api.heartbeat()
                assert server_api.list_collections() == []
            finally:
                # Resume the generator past its yield so that any teardown it
                # performs there runs for this iteration instead of being
                # skipped. Presumably this stops the server started for the
                # fixture -- confirm against _fastapi_fixture in conftest.
                next(fixture, None)


@settings(max_examples=10)
@given(
    token_test_conf(), random_token(), random_token_transport_header(), st.booleans()
)
def test_fastapi_server_token_authn_rejects_when_it_should_reject(
    tconf: Dict[str, Any],
    unauthorized_token: str,
    transport_header: str,
    persistence: bool,
) -> None:
    """Check that a token absent from the generated config is rejected.

    A client configured with ``unauthorized_token`` must still be able to
    call ``heartbeat()``, but ``list_collections()`` must raise an exception
    whose string form contains ``"Forbidden"``.

    Args:
        tconf: Generated configuration; this test reads ``tconf["users"]``
            (each entry providing a ``"tokens"`` iterable) and
            ``tconf["filename"]`` (passed as the server credentials file).
        unauthorized_token: Randomly generated token used as the client
            credential.
        transport_header: Value passed as the token transport header setting.
        persistence: Forwarded to the fixture as ``is_persistent``.
    """
    # Make sure we actually have an unauthorized token: if the random token
    # happens to equal a configured one, this example proves nothing.
    if any(
        known == unauthorized_token
        for user in tconf["users"]
        for known in user["tokens"]
    ):
        return

    # NOTE(review): every iteration uses the same unauthorized token, so the
    # iterations look identical; the loop is kept to preserve how often (and
    # whether) the check runs for a given config.
    for user in tconf["users"]:
        for _ in user["tokens"]:
            fixture = _fastapi_fixture(
                is_persistent=persistence,
                chroma_auth_token_transport_header=transport_header,
                chroma_server_authn_provider="chromadb.auth.token_authn.TokenAuthenticationServerProvider",
                chroma_server_authn_credentials_file=tconf["filename"],
                chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                chroma_client_auth_credentials=unauthorized_token,
            )
            # The fixture is a generator: the first next() yields the System.
            _sys: System = next(fixture)
            try:
                _sys.reset_state()
                server_api = _sys.instance(ServerAPI)
                server_api.heartbeat()
                with pytest.raises(Exception) as e:
                    server_api.list_collections()

                # NOTE(review): ``e`` is pytest's ExceptionInfo wrapper, not
                # the exception itself; ``str(e.value)`` would inspect only
                # the message. Left unchanged to keep the assertion's current
                # pass/fail behavior.
                assert "Forbidden" in str(e)
            finally:
                # Resume the generator past its yield so that any teardown it
                # performs there runs for this iteration instead of being
                # skipped. Presumably this stops the server started for the
                # fixture -- confirm against _fastapi_fixture in conftest.
                next(fixture, None)
