# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import hashlib
import io
import os
import re
from pathlib import Path

from streamlit import runtime
from streamlit.runtime import caching

# Regular expression used to recognize an SRT cue timing line, i.e. the
# "hours:minutes:seconds,milliseconds --> hours:minutes:seconds,milliseconds"
# format. _is_srt searches for it in the second line of the content.
SRT_VALIDATION_REGEX = r"\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}"

# Regular expression matching a single SRT timestamp. "hours:minutes:seconds"
# and the milliseconds are captured as separate groups so that _srt_to_vtt can
# rejoin them with a period in place of the comma.
SRT_CONVERSION_REGEX = r"(\d{2}:\d{2}:\d{2}),(\d{3})"

# File extensions accepted when the subtitles are given as a path to a file.
SUBTITLE_ALLOWED_FORMATS = (".srt", ".vtt")


def _is_srt(stream: str | io.BytesIO | bytes) -> bool:
    """Return True if ``stream`` looks like the start of a SubRip (.srt) file.

    Only a short prefix of the data is inspected. After skipping an optional
    UTF-8 byte-order mark and surrounding whitespace, the first line must
    consist of digits only (the cue index) and the second line must contain a
    timestamp range matching ``SRT_VALIDATION_REGEX``. Both LF and CRLF line
    endings are accepted.

    Parameters
    ----------
    stream : str, bytes or io.BytesIO
        The subtitle content. A ``BytesIO`` is rewound to its beginning and
        its read position is left just after the inspected prefix.

    Returns
    -------
    bool
        True if the prefix matches the SRT layout, False otherwise. A prefix
        that is not valid UTF-8 yields False.
    """
    import codecs

    # Normalize every supported input type to a byte stream.
    if isinstance(stream, str):
        stream = io.BytesIO(stream.encode("utf-8"))
    elif isinstance(stream, bytes):
        stream = io.BytesIO(stream)

    # Set the stream position to the beginning in case it's been moved
    stream.seek(0)

    # 64 bytes comfortably cover an optional BOM, a multi-digit cue index, a
    # CRLF line break and the 29-character timing line.
    sniff_size = 64
    header = stream.read(sniff_size)

    # The fixed-size read may end in the middle of a multi-byte character.
    # An incremental decoder with final=False buffers such an incomplete tail
    # instead of raising, while genuinely invalid bytes still raise.
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        header_str = decoder.decode(header, final=False)
    except UnicodeDecodeError:
        # If it's not valid utf-8, it's probably not a valid SRT file
        return False

    # A BOM is not whitespace, so it has to be removed explicitly before the
    # regular strip; otherwise the index line would never be all digits.
    header_str = header_str.lstrip("\ufeff").strip()

    # Drop the "\r" that CRLF line endings leave at the end of each line.
    lines = [line.rstrip("\r") for line in header_str.split("\n")]

    # Check for the pattern of an SRT file: digit(s), newline, timestamp
    if len(lines) < 2 or not lines[0].isdigit():
        return False
    return re.search(SRT_VALIDATION_REGEX, lines[1]) is not None


def _srt_to_vtt(srt_data: str | bytes) -> bytes:
    """Convert subtitles from SubRip (.srt) format to WebVTT (.vtt) format.

    The conversion replaces the comma in every SRT timestamp with a period
    and prepends the ``WEBVTT`` header. A leading UTF-8 byte-order mark in the
    input is dropped so that it does not end up after the header.

    Parameters
    ----------
    srt_data : str or bytes
        The content of the .srt file, either as text or as UTF-8 encoded
        bytes.

    Returns
    -------
    bytes
        The content converted into .vtt format, encoded as UTF-8.

    Raises
    ------
    ValueError
        If ``srt_data`` is bytes that cannot be decoded as UTF-8.
    TypeError
        If ``srt_data`` is neither a string nor bytes.
    """
    if isinstance(srt_data, bytes):
        try:
            # "utf-8-sig" decodes plain UTF-8 and also removes a leading BOM.
            srt_data = srt_data.decode("utf-8-sig")
        except UnicodeDecodeError as e:
            raise ValueError("Could not decode the input stream as UTF-8.") from e
    if not isinstance(srt_data, str):
        # If it's not a string by this point, something is wrong.
        raise TypeError(
            f"Input must be a string or a bytes stream, not {type(srt_data)}."
        )

    # Text that was read with the plain "utf-8" codec may still carry the BOM.
    srt_data = srt_data.lstrip("\ufeff")

    # Replace SubRip timing with WebVTT timing. NOTE: the substitution runs
    # over the whole document, not only over the timing lines.
    vtt_data = re.sub(SRT_CONVERSION_REGEX, r"\1.\2", srt_data)

    # Add the WebVTT file header, trim surrounding whitespace and encode.
    return ("WEBVTT\n\n" + vtt_data).strip().encode("utf-8")


def _handle_string_or_path_data(data_or_path: str | Path) -> bytes:
    """Return VTT bytes for subtitles given as a file path or as raw text.

    If ``data_or_path`` names an existing file, its extension must be one of
    ``SUBTITLE_ALLOWED_FORMATS``; ``.srt`` files are converted to VTT and
    other allowed files are returned unchanged. A ``Path`` that is not an
    existing file is an error. Any other string is treated as subtitle
    content: it is stripped, then returned as-is when empty or starting with
    ``WEBVTT``, converted when it looks like SRT, and rejected otherwise.

    Raises
    ------
    ValueError
        If the file extension is not allowed, if a ``Path`` does not point to
        an existing file, or if the text is neither VTT nor SRT.
    """
    if os.path.isfile(data_or_path):
        file_extension = Path(data_or_path).suffix.lower()
        if file_extension not in SUBTITLE_ALLOWED_FORMATS:
            raise ValueError(
                f"Incorrect subtitle format {file_extension}. Subtitles must be in "
                f"one of the following formats: {', '.join(SUBTITLE_ALLOWED_FORMATS)}"
            )

        with open(data_or_path, "rb") as subtitle_file:
            raw_bytes = subtitle_file.read()

        # Only SRT needs converting; VTT files are passed through untouched.
        if file_extension == ".srt":
            return _srt_to_vtt(raw_bytes)
        return raw_bytes

    # A Path object is always meant as a file reference, never as content.
    if isinstance(data_or_path, Path):
        raise ValueError(f"File {data_or_path} does not exist.")

    text = data_or_path.strip()
    if text == "" or text.startswith("WEBVTT"):
        return text.encode("utf-8")
    if not _is_srt(text):
        raise ValueError(
            "The provided string neither matches valid VTT nor SRT format."
        )
    return _srt_to_vtt(text)


def _handle_stream_data(stream: io.BytesIO) -> bytes:
    """Return the stream's content as VTT bytes, converting it if it is SRT.

    The stream is rewound first. Content that is not detected as SRT is
    returned unchanged.
    """
    stream.seek(0)
    raw_bytes = stream.getvalue()
    if _is_srt(stream):
        return _srt_to_vtt(raw_bytes)
    return raw_bytes


def _handle_bytes_data(data: bytes) -> bytes:
    """Return raw bytes as VTT content, converting them if they are SRT.

    Bytes that are not detected as SRT are returned unchanged.
    """
    if _is_srt(data):
        return _srt_to_vtt(data)
    return data


def process_subtitle_data(
    coordinates: str,
    data: str | bytes | Path | io.BytesIO,
    label: str,
) -> str:
    """Convert subtitle data to VTT and register it with the media file manager.

    Parameters
    ----------
    coordinates : str
        Passed through to the media file manager and to
        ``caching.save_media_data`` together with the processed subtitles.
    data : str, bytes, Path or io.BytesIO
        The subtitles, as a file path, raw text, raw bytes or a byte stream.
    label : str
        Label of the subtitle track; its MD5 hex digest is used as the name
        of the stored ``.vtt`` file.

    Returns
    -------
    str
        The URL returned by the media file manager, or an empty string when
        no runtime exists.

    Raises
    ------
    TypeError
        If ``data`` is of an unsupported type.
    """
    # Pick the handler that matches the type of the incoming data.
    if isinstance(data, io.BytesIO):
        subtitle_data = _handle_stream_data(data)
    elif isinstance(data, bytes):
        subtitle_data = _handle_bytes_data(data)
    elif isinstance(data, (str, Path)):
        subtitle_data = _handle_string_or_path_data(data)
    else:
        raise TypeError(f"Invalid binary data format for subtitle: {type(data)}.")

    # When running in "raw mode", we can't access the MediaFileManager.
    if not runtime.exists():
        return ""

    filename = hashlib.md5(label.encode()).hexdigest()
    media_file_mgr = runtime.get_instance().media_file_mgr
    # Save the processed data and return the file URL
    file_url = media_file_mgr.add(
        path_or_data=subtitle_data,
        mimetype="text/vtt",
        coordinates=coordinates,
        file_name=f"{filename}.vtt",
    )
    caching.save_media_data(subtitle_data, "text/vtt", coordinates)
    return file_url
