import asyncio
import csv
import json
import time
import typing
from typing import Optional

import requests
from fastavro import parse_schema, reader, writer

from . import EmbedResponse, EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse, ApiMeta, \
    EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset
from .datasets import DatasetsCreateResponse, DatasetsGetResponse
from .overrides import get_fields


def get_terminal_states():
    """Return every status at which polling stops.

    The result is a new set combining the success states with the failed
    states.
    """
    succeeded = get_success_states()
    return succeeded.union(get_failed_states())


def get_success_states():
    """Return a new set of the status strings treated as a successful outcome."""
    success_states = {"validated", "complete"}
    return success_states


def get_failed_states():
    """Return a new set of the status strings treated as a failed outcome.

    The original literal listed "failed" twice; the duplicate is dropped here,
    which leaves the resulting set unchanged.
    """
    return {"unknown", "failed", "skipped", "cancelled"}


def get_id(
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]):
    """Extract an identifier from a job or dataset object.

    The first truthy value among ``awaitable.job_id`` and ``awaitable.id`` is
    returned. If neither is truthy, the result is ``awaitable.dataset.id``,
    or ``None`` when that attribute chain is missing.
    """
    for attribute in ("job_id", "id"):
        candidate = getattr(awaitable, attribute, None)
        if candidate:
            return candidate
    nested_dataset = getattr(awaitable, "dataset", None)
    return getattr(nested_dataset, "id", None)


def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
    """Return the status of a job or dataset object.

    A truthy ``awaitable.status`` wins. Otherwise the result is
    ``awaitable.dataset.validation_status``, or ``None`` when that attribute
    chain is missing.
    """
    direct_status = getattr(awaitable, "status", None)
    if direct_status:
        return direct_status
    nested_dataset = getattr(awaitable, "dataset", None)
    return getattr(nested_dataset, "validation_status", None)


def get_job(cohere: typing.Any,
            awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
        typing.Union[
            EmbedJob, DatasetsGetResponse]:
    """Fetch the current state of the job or dataset behind ``awaitable``.

    Dispatch is on the class *name* of ``awaitable``: embed-job objects are
    looked up through ``cohere.embed_jobs.get`` and dataset objects through
    ``cohere.datasets.get``, both keyed by ``get_id(awaitable)``.

    Raises:
        ValueError: if the class name of ``awaitable`` is not one of the four
            recognised names.
    """
    kind = awaitable.__class__.__name__
    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return cohere.embed_jobs.get(id=get_id(awaitable))
    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return cohere.datasets.get(id=get_id(awaitable))
    raise ValueError(f"Unexpected awaitable type {awaitable}")


async def async_get_job(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
        typing.Union[
            EmbedJob, DatasetsGetResponse]:
    """Asynchronously fetch the current state of the job or dataset behind ``awaitable``.

    Dispatch is on the class *name* of ``awaitable``: embed-job objects are
    looked up by awaiting ``cohere.embed_jobs.get`` and dataset objects by
    awaiting ``cohere.datasets.get``, both keyed by ``get_id(awaitable)``.

    The ``awaitable`` annotation lists all four classes the body handles,
    matching the synchronous ``get_job`` signature.

    Raises:
        ValueError: if the class name of ``awaitable`` is not one of the four
            recognised names.
    """
    kind = awaitable.__class__.__name__
    if kind in ("EmbedJob", "CreateEmbedJobResponse"):
        return await cohere.embed_jobs.get(id=get_id(awaitable))
    if kind in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return await cohere.datasets.get(id=get_id(awaitable))
    raise ValueError(f"Unexpected awaitable type {awaitable}")


def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
    """Build a human-readable failure message for a job or dataset.

    Returns:
        A message naming the job id and status for an ``EmbedJob``, a message
        with the validation status and validation error for a
        ``DatasetsGetResponse``, or ``None`` for any other object.
    """
    if isinstance(job, EmbedJob):
        return f"Embed job {job.job_id} failed with status {job.status}"
    if not isinstance(job, DatasetsGetResponse):
        return None
    dataset = job.dataset
    return f"Dataset creation failed with status {dataset.validation_status} and error : {dataset.validation_error}"


@typing.overload
def wait(
        cohere: typing.Any,
        awaitable: CreateEmbedJobResponse,
        timeout: Optional[float] = None,
        interval: float = 2,
) -> EmbedJob:
    ...


@typing.overload
def wait(
        cohere: typing.Any,
        awaitable: DatasetsCreateResponse,
        timeout: Optional[float] = None,
        interval: float = 2,
) -> DatasetsGetResponse:
    ...


def wait(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
        timeout: Optional[float] = None,
        interval: float = 2,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Block until the job or dataset behind ``awaitable`` reaches a terminal state.

    The object is re-fetched with ``get_job`` every ``interval`` seconds, and
    ``"..."`` is printed after each sleep as a progress marker.

    The overloads above declare the same ``interval`` default as this
    implementation (2 seconds), so type checkers and IDEs show the value that
    is actually used at runtime.

    Args:
        cohere: Client object passed through to ``get_job``.
        awaitable: The creation response whose job/dataset should be polled.
        timeout: Maximum number of seconds to wait, or ``None`` for no limit.
            It is only checked before each sleep, so the call can overrun it
            by roughly one ``interval`` plus the time of one fetch.
        interval: Seconds to sleep between polls.

    Returns:
        The last fetched job/dataset object, whose status is a success state.

    Raises:
        TimeoutError: if ``timeout`` elapsed before a terminal state was seen.
        Exception: if the terminal state is one of the failed states; the
            message comes from ``get_failure_reason``.
    """
    start_time = time.time()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.time() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")

        time.sleep(interval)
        print("...")

        job = get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: CreateEmbedJobResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> EmbedJob:
    ...


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: DatasetsCreateResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> DatasetsGetResponse:
    ...


async def async_wait(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
        timeout: Optional[float] = None,
        interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    """Await until the job or dataset behind ``awaitable`` reaches a terminal state.

    The object is re-fetched with ``async_get_job`` every ``interval`` seconds
    (sleeping via ``asyncio.sleep``), and ``"..."`` is printed after each sleep.

    Args:
        cohere: Client object passed through to ``async_get_job``.
        awaitable: The creation response whose job/dataset should be polled.
        timeout: Maximum number of seconds to wait, or ``None`` for no limit.
            It is only checked before each sleep.
        interval: Seconds to sleep between polls.

    Returns:
        The last fetched job/dataset object, whose status is a success state.

    Raises:
        TimeoutError: if ``timeout`` elapsed before a terminal state was seen.
        Exception: if the terminal state is one of the failed states; the
            message comes from ``get_failure_reason``.
    """
    started_at = time.time()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = await async_get_job(cohere, awaitable)
    while True:
        if get_validation_status(job) in terminal_states:
            break

        timed_out = timeout is not None and time.time() - started_at > timeout
        if timed_out:
            raise TimeoutError(f"wait timed out after {timeout} seconds")

        await asyncio.sleep(interval)
        print("...")

        job = await async_get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job


def sum_fields_if_not_none(obj: typing.Any, field: str) -> Optional[int]:
    """Sum attribute ``field`` over the items of ``obj``, ignoring missing values.

    Args:
        obj: An iterable of objects. Items that are ``None`` are skipped
            instead of raising ``AttributeError``.
        field: Name of the attribute to read from every non-``None`` item.

    Returns:
        The sum of all attribute values that are not ``None``, or ``None``
        when there is no such value (including for an empty iterable).
    """
    values = []
    for item in obj:
        if item is None:
            # A missing container contributes nothing to the total.
            continue
        value = getattr(item, field)
        if value is not None:
            values.append(value)
    return sum(values) if values else None


def merge_meta_field(metas: typing.List[ApiMeta]) -> ApiMeta:
    """Combine several ``ApiMeta`` objects into one.

    Args:
        metas: The metas to merge; may be empty.

    Returns:
        A new ``ApiMeta`` whose ``api_version`` is taken from the first meta
        (``None`` when ``metas`` is empty), whose billed-unit counters are the
        per-field sums produced by ``sum_fields_if_not_none``, and whose
        ``warnings`` are the distinct warnings of all metas in first-seen
        order.
    """
    api_version = metas[0].api_version if metas else None
    # NOTE(review): billed_units is assumed to be optional on ApiMeta - confirm.
    # Metas without it are skipped so the attribute reads below cannot hit None.
    billed_units = [meta.billed_units for meta in metas if meta.billed_units is not None]
    input_tokens = sum_fields_if_not_none(billed_units, "input_tokens")
    output_tokens = sum_fields_if_not_none(billed_units, "output_tokens")
    search_units = sum_fields_if_not_none(billed_units, "search_units")
    classifications = sum_fields_if_not_none(billed_units, "classifications")
    # dict.fromkeys de-duplicates like a set but keeps first-seen order, so the
    # merged warning list is the same on every run.
    warnings = list(dict.fromkeys(
        warning for meta in metas if meta.warnings for warning in meta.warnings
    ))
    return ApiMeta(
        api_version=api_version,
        billed_units=ApiMetaBilledUnits(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            search_units=search_units,
            classifications=classifications
        ),
        warnings=warnings
    )


def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedResponse:
    """Merge several embed responses into a single response.

    The merged response has the metas combined by ``merge_meta_field``, an id
    made of all response ids joined with ``", "``, and the texts and
    embeddings of every response concatenated in order.

    The ``response_type`` of the first response selects the output type:
    ``"embeddings_floats"`` yields an ``EmbeddingsFloatsEmbedResponse``; any
    other value yields an ``EmbeddingsByTypeEmbedResponse``. For the by-type
    case, only the embedding fields that are not ``None`` on the first
    response are merged.
    """
    present_metas = [response.meta for response in responses if response.meta]
    meta = merge_meta_field(present_metas)
    response_id = ", ".join(response.id for response in responses)

    texts: typing.List[typing.Any] = []
    for response in responses:
        texts.extend(response.texts)

    if responses[0].response_type == "embeddings_floats":
        float_responses = typing.cast(typing.List[EmbeddingsFloatsEmbedResponse], responses)

        embeddings: typing.List[typing.Any] = []
        for float_response in float_responses:
            embeddings.extend(float_response.embeddings)

        return EmbeddingsFloatsEmbedResponse(
            response_type="embeddings_floats",
            id=response_id,
            texts=texts,
            embeddings=embeddings,
            meta=meta
        )

    typed_responses = typing.cast(typing.List[EmbeddingsByTypeEmbedResponse], responses)
    per_response_embeddings = [typed_response.embeddings for typed_response in typed_responses]

    # only get set keys from the pydantic model (i.e. exclude fields that are set to 'None')
    first_embeddings = typed_responses[0].embeddings
    populated_fields = [
        name
        for name in get_fields(first_embeddings)
        if getattr(first_embeddings, name) is not None
    ]

    merged_by_field = {}
    for name in populated_fields:
        combined: typing.List[typing.Any] = []
        for embeddings_of_response in per_response_embeddings:
            combined.extend(getattr(embeddings_of_response, name))
        merged_by_field[name] = combined

    merged_embeddings = EmbedByTypeResponseEmbeddings.parse_obj(merged_by_field)

    return EmbeddingsByTypeEmbedResponse(
        response_type="embeddings_by_type",
        id=response_id,
        embeddings=merged_embeddings,
        texts=texts,
        meta=meta
    )


# Format names accepted by SdkUtils.save_dataset; the list is interpolated into
# the error message raised for any other format.
supported_formats = ["jsonl", "csv", "avro"]


def save_avro(dataset: Dataset, filepath: str):
    """Write every record of ``dataset`` to ``filepath`` as an Avro file.

    The Avro schema is parsed from the JSON text in ``dataset.schema_`` and
    the records come from ``dataset_generator``.

    Raises:
        ValueError: if ``dataset.schema_`` is empty or missing.
    """
    raw_schema = dataset.schema_
    if not raw_schema:
        raise ValueError("Dataset does not have a schema")
    parsed_schema = parse_schema(json.loads(raw_schema))
    records = dataset_generator(dataset)
    with open(filepath, "wb") as sink:
        writer(sink, parsed_schema, records)


def save_jsonl(dataset: Dataset, filepath: str):
    """Write every record of ``dataset`` to ``filepath`` as JSON Lines.

    Each record from ``dataset_generator`` is serialised with ``json.dump``
    and followed by a single newline.
    """
    records = dataset_generator(dataset)
    with open(filepath, "w") as sink:
        for record in records:
            json.dump(record, sink)
            sink.write("\n")


def save_csv(dataset: Dataset, filepath: str):
    """Write every record of ``dataset`` to ``filepath`` as CSV.

    The header row is taken from the keys of the first record yielded by
    ``dataset_generator``; every record is then written with
    ``csv.DictWriter``. A dataset that yields no records produces an empty
    file.
    """
    # The csv module emits its own "\r\n" line terminators, so the file must be
    # opened with newline="" to stop the text layer translating them again
    # (which yields "\r\r\n" on Windows).
    with open(filepath, "w", newline="") as outfile:
        # Named csv_writer so it does not shadow the fastavro `writer` import.
        csv_writer = None
        for data in dataset_generator(dataset):
            if csv_writer is None:
                csv_writer = csv.DictWriter(outfile, fieldnames=list(data.keys()))
                csv_writer.writeheader()
            csv_writer.writerow(data)


def dataset_generator(dataset: Dataset):
    """Yield the records of every part of ``dataset``, one at a time.

    Each part is downloaded from ``part.url`` as a streamed HTTP response and
    decoded with fastavro's ``reader``.

    Raises:
        ValueError: if the dataset has no parts, or a part has no url.
        requests.HTTPError: if a part download returns an error status.
    """
    if not dataset.dataset_parts:
        raise ValueError("Dataset does not have dataset_parts")
    for part in dataset.dataset_parts:
        if not part.url:
            raise ValueError("Dataset part does not have a url")
        # The context manager closes the streamed response (releasing its
        # connection) once the part is consumed or the generator is closed.
        with requests.get(part.url, stream=True) as resp:
            # Fail on an HTTP error instead of handing an error body to the
            # Avro reader.
            resp.raise_for_status()
            for record in reader(resp.raw):  # type: ignore
                yield record


class SdkUtils:
    """Namespace for dataset helper functions exposed as static methods."""

    @staticmethod
    def save_dataset(dataset: Dataset, filepath: str, format: typing.Literal["jsonl", "csv", "avro"] = "jsonl"):
        """Save ``dataset`` to ``filepath`` in the requested ``format``.

        Delegates to ``save_jsonl``, ``save_csv`` or ``save_avro`` and returns
        whatever that function returns.

        Raises:
            Exception: if ``format`` is not "jsonl", "csv" or "avro".
        """
        if format == "jsonl":
            saver = save_jsonl
        elif format == "csv":
            saver = save_csv
        elif format == "avro":
            saver = save_avro
        else:
            raise Exception(f"unsupported format must be one of : {supported_formats}")
        return saver(dataset, filepath)


class SyncSdkUtils(SdkUtils):
    """``SdkUtils`` subclass that adds no behaviour of its own.

    NOTE(review): presumably the variant attached to the synchronous client -
    confirm against the caller.
    """


class AsyncSdkUtils(SdkUtils):
    """``SdkUtils`` subclass that adds no behaviour of its own.

    NOTE(review): presumably the variant attached to the asynchronous client -
    confirm against the caller.
    """
