# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import uuid
from collections import defaultdict
from typing import TYPE_CHECKING

from streamlit import util
from streamlit.runtime.stats import CacheStat, group_stats
from streamlit.runtime.uploaded_file_manager import (
    UploadedFileManager,
    UploadedFileRec,
    UploadFileUrlInfo,
)

if TYPE_CHECKING:
    from collections.abc import Sequence


class MemoryUploadedFileManager(UploadedFileManager):
    """Holds files uploaded by users of the running Streamlit app.

    Files are kept in memory, keyed first by session ID and then by file ID.

    This class can be used safely from multiple threads simultaneously.
    """

    def __init__(self, upload_endpoint: str) -> None:
        # session_id -> {file_id -> UploadedFileRec}
        self.file_storage: dict[str, dict[str, UploadedFileRec]] = defaultdict(dict)
        # URL prefix used by get_upload_urls to build per-file URLs.
        self.endpoint = upload_endpoint

    def get_files(
        self, session_id: str, file_ids: Sequence[str]
    ) -> list[UploadedFileRec]:
        """Return a list of UploadedFileRec for a given sequence of file_ids.

        Parameters
        ----------
        session_id
            The ID of the session that owns the files.
        file_ids
            The sequence of ids associated with files to retrieve.

        Returns
        -------
        List[UploadedFileRec]
            The UploadedFileRec instances found for the given ids, in the
            order of ``file_ids``. Ids with no stored file are skipped.
        """
        # Use .get() instead of indexing: indexing the defaultdict would
        # insert an empty entry for an unknown session (including one that
        # was already removed via remove_session_files) as a side effect of
        # a read-only lookup.
        session_storage = self.file_storage.get(session_id, {})
        file_recs: list[UploadedFileRec] = []

        for file_id in file_ids:
            # Single .get() (rather than `in` + index) so a concurrent
            # removal between two lookups cannot raise KeyError.
            file_rec = session_storage.get(file_id)
            if file_rec is not None:
                file_recs.append(file_rec)

        return file_recs

    def remove_session_files(self, session_id: str) -> None:
        """Remove all files associated with a given session."""
        self.file_storage.pop(session_id, None)

    def __repr__(self) -> str:
        return util.repr_(self)

    def add_file(
        self,
        session_id: str,
        file: UploadedFileRec,
    ) -> None:
        """Store a file for a session, replacing any record with the same file_id.

        Safe to call from any thread.

        Parameters
        ----------
        session_id
            The ID of the session that owns the file.
        file
            The file to add.
        """
        # Indexing the defaultdict intentionally creates the session's
        # storage on first upload.
        self.file_storage[session_id][file.file_id] = file

    def remove_file(self, session_id: str, file_id: str) -> None:
        """Remove file with given file_id associated with a given session.

        Does nothing if the session or the file is unknown.
        """
        # .get() avoids creating an empty entry for an unknown session.
        # Empty per-session dicts are deliberately left in place: add_file
        # may concurrently hold a reference to the same inner dict.
        session_storage = self.file_storage.get(session_id)
        if session_storage is not None:
            session_storage.pop(file_id, None)

    def get_upload_urls(
        self, session_id: str, file_names: Sequence[str]
    ) -> list[UploadFileUrlInfo]:
        """Return a list of UploadFileUrlInfo for a given sequence of file_names.

        One entry is produced per name, each with a freshly generated file ID.
        The names themselves are not used in the URLs; only their count
        matters. The upload and delete URLs are the same string.
        """
        result: list[UploadFileUrlInfo] = []
        for _ in file_names:
            file_id = str(uuid.uuid4())
            file_url = f"{self.endpoint}/{session_id}/{file_id}"
            result.append(
                UploadFileUrlInfo(
                    file_id=file_id,
                    upload_url=file_url,
                    delete_url=file_url,
                )
            )
        return result

    def get_stats(self) -> list[CacheStat]:
        """Return the manager's CacheStats.

        Safe to call from any thread.
        """
        # Flatten all files into a single list. Iterate over a copy of the
        # outer mapping so that sessions being added or removed from another
        # thread do not change its size while we iterate over it.
        all_files: list[UploadedFileRec] = []
        for session_storage in self.file_storage.copy().values():
            all_files.extend(session_storage.values())

        stats: list[CacheStat] = [
            CacheStat(
                category_name="UploadedFileManager",
                cache_name="",
                byte_length=len(file_rec.data),
            )
            for file_rec in all_files
        ]
        return group_stats(stats)
