import asyncio
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ..types import DEFAULT_TIMEOUT, Callback, _Hook

if TYPE_CHECKING:
    from .channel import AsyncRealtimeChannel

logger = logging.getLogger(__name__)


class AsyncPush:
    """One outbound channel message plus the bookkeeping for its reply.

    A push sends ``event``/``payload`` through the channel's socket, binds a
    handler on a per-push reply event, and triggers a ``"timeout"`` reply on
    that event when no response has been processed after ``timeout``.
    Callbacks registered with :meth:`receive` run when a reply whose status
    matches is processed.
    """

    def __init__(
        self,
        channel: "AsyncRealtimeChannel",
        event: str,
        payload: Optional[Dict[str, Any]] = None,
        timeout: int = DEFAULT_TIMEOUT,
    ):
        """Create a push that has not been sent yet.

        Args:
            channel: Channel whose socket carries the message and whose event
                bindings deliver the reply.
            event: Event name placed in the outgoing message.
            payload: Message payload; ``None`` (or any falsy value) becomes an
                empty dict.
            timeout: Delay handed to ``asyncio.sleep`` before a ``"timeout"``
                reply is triggered.
        """
        self.channel = channel
        self.event = event
        self.payload = payload or {}
        self.timeout = timeout
        # Hooks registered through receive(); each has .status and .callback.
        self.rec_hooks: List[_Hook] = []
        # Message ref and the reply event name derived from it; both are
        # assigned by start_timeout().
        self.ref: Optional[str] = None
        self.ref_event: Optional[str] = None
        # Last reply payload processed by the reply handler, if any.
        self.received_resp: Optional[Dict[str, Any]] = None
        self.sent = False
        # Task that triggers the "timeout" reply; None when no timer is armed.
        self.timeout_task: Optional[asyncio.Task] = None

    async def resend(self) -> None:
        """Reset all reply state and send the message again.

        The pending timeout task is cancelled as part of the reset. Without
        that, ``start_timeout`` would return early on the leftover task and
        the message would go out with an empty ref and no reply handler bound.
        """
        self._cancel_ref_event()
        self._cancel_timeout()
        self.ref = ""
        self.ref_event = None
        self.received_resp = None
        self.sent = False
        await self.send()

    async def send(self) -> None:
        """Arm the timeout and send the message through the channel's socket.

        Does nothing when a ``"timeout"`` reply has already been recorded.
        Exceptions raised by the socket are logged and not re-raised.
        """
        if self._has_received("timeout"):
            return

        self.start_timeout()
        self.sent = True

        message = {
            "topic": self.channel.topic,
            "event": self.event,
            "payload": self.payload,
            "ref": self.ref,
            "join_ref": self.channel.join_push.ref,
        }
        try:
            await self.channel.socket.send(message)
        except Exception as e:
            logger.error("send push failed: %s", e)

    def update_payload(self, payload: Dict[str, Any]) -> None:
        """Merge ``payload`` into the stored payload; new keys win on conflict."""
        self.payload = {**self.payload, **payload}

    def receive(self, status: str, callback: Callback) -> "AsyncPush":
        """Register ``callback`` for replies whose status equals ``status``.

        If a reply with that status has already been recorded, the callback is
        invoked immediately with its ``"response"`` value (``{}`` if absent).
        The hook is stored in either case.

        Returns:
            This push, so calls can be chained.
        """
        if self._has_received(status):
            callback(self.received_resp.get("response", {}))

        self.rec_hooks.append(_Hook(status, callback))
        return self

    def start_timeout(self) -> None:
        """Allocate a ref, bind the reply handler and start the timeout task.

        Returns immediately when a timeout task is already set, leaving the
        existing ref and reply binding untouched. Must be called while an
        event loop is running because it uses ``asyncio.create_task``.
        """
        if self.timeout_task:
            return

        self.ref = self.channel.socket._make_ref()
        self.ref_event = self.channel._reply_event_name(self.ref)

        def on_reply(payload, *args):
            self._cancel_ref_event()
            self._cancel_timeout()
            self.received_resp = payload
            # Read the two expected fields explicitly so a reply carrying
            # extra keys (or lacking one) does not raise TypeError here.
            self._match_receive(
                payload.get("status"), payload.get("response", {})
            )

        self.channel._on(self.ref_event, on_reply)

        async def fire_timeout() -> None:
            await asyncio.sleep(self.timeout)
            # Drop the handle before triggering so that the reply handler's
            # _cancel_timeout() does not cancel the task it is running in.
            self.timeout_task = None
            self.trigger("timeout", {})

        self.timeout_task = asyncio.create_task(fire_timeout())

    def trigger(self, status: str, response: Any) -> None:
        """Trigger the reply event on the channel with ``status``/``response``.

        Does nothing when no reply event name is currently set.
        """
        if self.ref_event:
            payload = {
                "status": status,
                "response": response,
            }
            self.channel._trigger(self.ref_event, payload)

    def destroy(self) -> None:
        """Unbind the reply event and cancel the timeout task."""
        self._cancel_ref_event()
        self._cancel_timeout()

    def _cancel_ref_event(self) -> None:
        """Remove the channel binding for the current reply event, if any."""
        if not self.ref_event:
            return

        self.channel._off(self.ref_event, {})

    def _cancel_timeout(self) -> None:
        """Cancel the timeout task, if one is set, and clear its handle."""
        if not self.timeout_task:
            return

        self.timeout_task.cancel()
        self.timeout_task = None

    def _match_receive(self, status: Optional[str], response: Any) -> None:
        """Call every registered hook whose status equals ``status``."""
        for hook in self.rec_hooks:
            if hook.status == status:
                hook.callback(response)

    def _has_received(self, status: str) -> bool:
        """Return True when a reply with the given status has been recorded."""
        resp = self.received_resp
        return resp is not None and resp.get("status") == status
