# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from openai.types.audio.transcription_create_params import TranscriptionCreateParams
from openai.types.chat.completion_create_params import (
    CompletionCreateParamsNonStreaming,
    CompletionCreateParamsStreaming,
)
from openai.types.completion_create_params import (
    CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
)
from openai.types.completion_create_params import (
    CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
)
from openai.types.embedding_create_params import EmbeddingCreateParams
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.types.caching import *
from litellm.types.rerank import RerankRequest
from litellm.types.utils import all_litellm_params

from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache  # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache


def print_verbose(print_statement):
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass


class CacheMode(str, Enum):
    default_on = "default_on"
    default_off = "default_off"


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
        mode: Optional[
            CacheMode
        ] = CacheMode.default_on,  # default_on: cache is always on; default_off: caching is opt-in per request
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        ttl: Optional[float] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_in_redis_ttl: Optional[float] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
        ],
        # s3 Bucket, boto3 configuration
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        s3_path: Optional[str] = None,
        redis_semantic_cache_use_async: bool = False,
        redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir: Optional[str] = None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".

            # Redis Cache Args
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (str, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            namespace (str, optional): The namespace for the Redis cache keys. Defaults to None.
            ttl (float, optional): The ttl for the Redis cache
            redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
            redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.

            # Qdrant Cache Args
            qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
            qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
            qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".

            # Disk Cache Args
            disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.

            # S3 Cache Args
            s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
            s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
            s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
            s3_use_ssl (bool, optional): Whether to use SSL when connecting to the s3 cache. Defaults to True.
            s3_verify (bool or str, optional): Whether to verify SSL certificates for the s3 cache, or the path to a CA bundle to use. Defaults to None.
            s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
            s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
            s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
            s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
            s3_config (dict, optional): The config for the s3 cache. Defaults to None.

            # Common Cache Args
            supported_call_types (list, optional): List of call types to cache for. Defaults to caching all supported call types.
            **kwargs: Additional keyword arguments for redis.Redis() cache

        Raises:
            ValueError: If an invalid cache type is provided.

        Returns:
            None. Cache is set as a litellm param
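
        Example:
            A minimal sketch (values are illustrative; the redis example assumes
            a reachable Redis instance):

                # in-memory cache with a 60s default TTL
                litellm.cache = Cache(type=LiteLLMCacheType.LOCAL, default_in_memory_ttl=60)

                # redis cache
                litellm.cache = Cache(
                    type=LiteLLMCacheType.REDIS, host="localhost", port="6379", password=""
                )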
        """
        if type == LiteLLMCacheType.REDIS:
            self.cache: BaseCache = RedisCache(
                host=host,
                port=port,
                password=password,
                redis_flush_size=redis_flush_size,
                startup_nodes=redis_startup_nodes,
                **kwargs,
            )
        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
            self.cache = RedisSemanticCache(
                host=host,
                port=port,
                password=password,
                similarity_threshold=similarity_threshold,
                use_async=redis_semantic_cache_use_async,
                embedding_model=redis_semantic_cache_embedding_model,
                **kwargs,
            )
        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
            self.cache = QdrantSemanticCache(
                qdrant_api_base=qdrant_api_base,
                qdrant_api_key=qdrant_api_key,
                collection_name=qdrant_collection_name,
                similarity_threshold=similarity_threshold,
                quantization_config=qdrant_quantization_config,
                embedding_model=qdrant_semantic_cache_embedding_model,
            )
        elif type == LiteLLMCacheType.LOCAL:
            self.cache = InMemoryCache()
        elif type == LiteLLMCacheType.S3:
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
                s3_api_version=s3_api_version,
                s3_use_ssl=s3_use_ssl,
                s3_verify=s3_verify,
                s3_endpoint_url=s3_endpoint_url,
                s3_aws_access_key_id=s3_aws_access_key_id,
                s3_aws_secret_access_key=s3_aws_secret_access_key,
                s3_aws_session_token=s3_aws_session_token,
                s3_config=s3_config,
                s3_path=s3_path,
                **kwargs,
            )
        elif type == LiteLLMCacheType.DISK:
            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.success_callback.append("cache")
        if "cache" not in litellm._async_success_callback:
            litellm._async_success_callback.append("cache")
        self.supported_call_types = supported_call_types  # defaults to caching all supported call types
        self.type = type
        self.namespace = namespace
        self.redis_flush_size = redis_flush_size
        self.ttl = ttl
        self.mode: CacheMode = mode or CacheMode.default_on

        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
            self.ttl = default_in_memory_ttl

        if (
            self.type == LiteLLMCacheType.REDIS
            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
        ) and default_in_redis_ttl is not None:
            self.ttl = default_in_redis_ttl

        if self.namespace is not None and isinstance(self.cache, RedisCache):
            self.cache.namespace = self.namespace

    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments.
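
        Example:
            A sketch of the expected behavior (assuming `cache` is a Cache
            instance; values are illustrative) - identical relevant kwargs
            yield the same deterministic key:

                k1 = cache.get_cache_key(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
                k2 = cache.get_cache_key(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
                assert k1 == k2  # SHA-256 over the filtered kwargs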
        """
        cache_key = ""
        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
        if preset_cache_key is not None:
            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
            return preset_cache_key

        combined_kwargs = self._get_relevant_args_to_use_for_cache_key()
        litellm_param_kwargs = all_litellm_params
        for param in kwargs:
            if param in combined_kwargs:
                param_value: Optional[str] = self._get_param_value(param, kwargs)
                if param_value is not None:
                    cache_key += f"{str(param)}: {str(param_value)}"
            elif (
                param not in litellm_param_kwargs
            ):  # check if user passed in optional param - e.g. top_k
                if (
                    litellm.enable_caching_on_provider_specific_optional_params is True
                ):  # feature flagged for now
                    if kwargs[param] is None:
                        continue  # ignore None params
                    param_value = kwargs[param]
                    cache_key += f"{str(param)}: {str(param_value)}"

        verbose_logger.debug("\nCreated cache key: %s", cache_key)
        hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
        hashed_cache_key = self._add_redis_namespace_to_cache_key(
            hashed_cache_key, **kwargs
        )
        self._set_preset_cache_key_in_kwargs(
            preset_cache_key=hashed_cache_key, **kwargs
        )
        return hashed_cache_key

    def _get_param_value(
        self,
        param: str,
        kwargs: dict,
    ) -> Optional[str]:
        """
        Get the value for the given param from kwargs
        """
        if param == "model":
            return self._get_model_param_value(kwargs)
        elif param == "file":
            return self._get_file_param_value(kwargs)
        return kwargs[param]

    def _get_model_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'model' param from kwargs

        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
        3. Else use the `model` passed in kwargs
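
        E.g. (illustrative values): with metadata={"model_group": "gpt-group",
        "caching_groups": [("gpt-group", "gpt-backup")]}, the returned value is
        "('gpt-group', 'gpt-backup')", so cache hits are shared across the group.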
        """
        metadata: Dict = kwargs.get("metadata", {}) or {}
        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
        model_group: Optional[str] = metadata.get(
            "model_group"
        ) or metadata_in_litellm_params.get("model_group")
        caching_group = self._get_caching_group(metadata, model_group)
        return caching_group or model_group or kwargs["model"]

    def _get_caching_group(
        self, metadata: dict, model_group: Optional[str]
    ) -> Optional[str]:
        caching_groups: Optional[List] = metadata.get("caching_groups", [])
        if caching_groups:
            for group in caching_groups:
                if model_group in group:
                    return str(group)
        return None

    def _get_file_param_value(self, kwargs: dict) -> Optional[str]:
        """
        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
        """
        file = kwargs.get("file")
        metadata = kwargs.get("metadata", {})
        litellm_params = kwargs.get("litellm_params", {})
        return (
            metadata.get("file_checksum")
            or getattr(file, "name", None)
            or metadata.get("file_name")
            or litellm_params.get("file_name")
        )

    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

        We use a preset cache key for two reasons:

        1. optional params like max_tokens get transformed for bedrock -> max_new_tokens
        2. to avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
                return kwargs["litellm_params"].get("preset_cache_key", None)
        return None

    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
        """
        Set the calculated cache key in kwargs

        This is used to avoid doing duplicate / repeated work

        Placed in kwargs["litellm_params"]
        """
        if kwargs:
            if "litellm_params" in kwargs:
                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key

    def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]:
        """
        Gets the supported kwargs for each call type and combines them
        """
        chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs()
        text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs()
        embedding_kwargs = self._get_litellm_supported_embedding_kwargs()
        transcription_kwargs = self._get_litellm_supported_transcription_kwargs()
        rerank_kwargs = self._get_litellm_supported_rerank_kwargs()
        exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key()

        combined_kwargs = chat_completion_kwargs.union(
            text_completion_kwargs,
            embedding_kwargs,
            transcription_kwargs,
            rerank_kwargs,
        )
        combined_kwargs = combined_kwargs.difference(exclude_kwargs)
        return combined_kwargs

    def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]:
        """
        Get the litellm supported chat completion kwargs

        This follows the OpenAI API Spec
        """
        all_chat_completion_kwargs = set(
            CompletionCreateParamsNonStreaming.__annotations__.keys()
        ).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
        return all_chat_completion_kwargs

    def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]:
        """
        Get the litellm supported text completion kwargs

        This follows the OpenAI API Spec
        """
        all_text_completion_kwargs = set(
            TextCompletionCreateParamsNonStreaming.__annotations__.keys()
        ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
        return all_text_completion_kwargs

    def _get_litellm_supported_rerank_kwargs(self) -> Set[str]:
        """
        Get the litellm supported rerank kwargs
        """
        return set(RerankRequest.model_fields.keys())

    def _get_litellm_supported_embedding_kwargs(self) -> Set[str]:
        """
        Get the litellm supported embedding kwargs

        This follows the OpenAI API Spec
        """
        return set(EmbeddingCreateParams.__annotations__.keys())

    def _get_litellm_supported_transcription_kwargs(self) -> Set[str]:
        """
        Get the litellm supported transcription kwargs

        This follows the OpenAI API Spec
        """
        return set(TranscriptionCreateParams.__annotations__.keys())

    def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]:
        """
        Get the kwargs to exclude from the cache key
        """
        return set(["metadata"])

    @staticmethod
    def _get_hashed_cache_key(cache_key: str) -> str:
        """
        Get the hashed cache key for the given cache key.

        Use hashlib to create a sha256 hash of the cache key

        Args:
            cache_key (str): The cache key to hash.

        Returns:
            str: The hashed cache key.
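
        Example:
            A sketch of the hashing step (standard library only; the input
            string is illustrative):

                import hashlib
                hashlib.sha256("model: gpt-3.5-turbo".encode()).hexdigest()
                # -> deterministic 64-char hex digest used as the stored key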
        """
        hash_object = hashlib.sha256(cache_key.encode())
        # Hexadecimal representation of the hash
        hash_hex = hash_object.hexdigest()
        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
        return hash_hex

    def _add_redis_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
        """
        If a redis namespace is provided, add it to the cache key

        Args:
            hash_hex (str): The hashed cache key.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The final hashed cache key with the redis namespace.
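
        Example:
            With namespace "org-1" and hashed key "abc1..." (both illustrative),
            the returned key is "org-1:abc1...".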
        """
        namespace = (kwargs.get("metadata") or {}).get("redis_namespace") or self.namespace
        if namespace:
            hash_hex = f"{namespace}:{hash_hex}"
        verbose_logger.debug("Final hashed key: %s", hash_hex)
        return hash_hex

    def generate_streaming_content(self, content):
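        """
        Re-stream a cached completion as OpenAI-style chunks.

        Splits `content` into fixed-size pieces and yields each one as a
        `choices[0].delta` dict, sleeping briefly between chunks to mimic
        streaming latency.
        """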
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(0.02)

    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
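
        Entries are stored as {"timestamp": <epoch seconds>, "response": <result>};
        if max_age is set and (now - timestamp) > max_age, None is returned so the
        caller treats the entry as a cache miss.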
        """
        # Check if a timestamp was stored with the cached response
        if (
            cached_result is not None
            and isinstance(cached_result, dict)
            and "timestamp" in cached_result
        ):
            timestamp = cached_result["timestamp"]
            current_time = time.time()

            # Calculate age of the cached response
            response_age = current_time - timestamp

            # Check if the cached response is older than the max-age
            if max_age is not None and response_age > max_age:
                return None  # Cached response is too old

            # If the response is fresh, or there's no max-age requirement, return the cached response
            # cached_response may be stored as a serialized string (e.g. b"{...}"); convert it back to a dict
            cached_response = cached_result.get("response")
            try:
                if isinstance(cached_response, dict):
                    pass
                else:
                    cached_response = json.loads(
                        cached_response  # type: ignore
                    )  # Convert string to dictionary
            except Exception:
                cached_response = ast.literal_eval(cached_response)  # type: ignore
            return cached_response
        return cached_result

    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
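
        Example:
            Per-request staleness can be bounded with Cache-Control-style args
            read from kwargs["cache"] (a sketch; `messages` is illustrative):

                litellm.completion(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    cache={"s-maxage": 60},  # reject cached responses older than 60s
                )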
        """
        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.cache.get_cache(cache_key, messages=messages)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """

        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        if "cache_key" in kwargs:
            cache_key = kwargs["cache_key"]
        else:
            cache_key = self.get_cache_key(**kwargs)
        if cache_key is not None:
            if isinstance(result, BaseModel):
                result = result.model_dump_json()

            ## DEFAULT TTL ##
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl
            ## Get Cache-Controls ##
            _cache_kwargs = kwargs.get("cache", None)
            if isinstance(_cache_kwargs, dict):
                for k, v in _cache_kwargs.items():
                    if k == "ttl":
                        kwargs["ttl"] = v

            cached_data = {"timestamp": time.time(), "response": result}
            return cache_key, cached_data, kwargs
        else:
            raise Exception("cache key is None")

    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, **kwargs
                )

                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")

    def should_use_cache(self, **kwargs):
        """
        Returns true if we should use the cache for LLM API calls

        If mode is default_on, this is always True
        If mode is default_off, this is True only when the user has opted in to using the cache
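
        Example:
            With mode == CacheMode.default_off, a caller opts in per request
            (a sketch; `messages` is illustrative):

                litellm.completion(model="gpt-3.5-turbo", messages=messages, cache={"use-cache": True})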
        """
        if self.mode == CacheMode.default_on:
            return True

        # when mode == default_off -> Cache is opt in only
        _cache = kwargs.get("cache", None)
        verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
        if _cache and isinstance(_cache, dict):
            if _cache.get("use-cache", False) is True:
                return True
        return False

    async def batch_cache_write(self, result, **kwargs):
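        """
        Buffer a cache write instead of writing immediately.

        Used for redis under high traffic: results accumulate in memory and the
        underlying cache flushes them once `redis_flush_size` writes are buffered.
        """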
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)
        if cache_ping:
            return await cache_ping()
        return None

    async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
        if cache_delete_cache_keys:
            return await cache_delete_cache_keys(keys)
        return None

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()

    def _supports_async(self) -> bool:
        """
        Internal method to check if the cache type supports async get/set operations

        Only the S3 cache does NOT support async operations

        """
        if self.type == LiteLLMCacheType.S3:
            return False
        return True


def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache to enable: "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to caching all supported call types.
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
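
    Example:
        A minimal sketch - enable a local in-memory cache:

            import litellm

            litellm.enable_cache(type="local")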
    """
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.success_callback.append("cache")
    if "cache" not in litellm._async_success_callback:
        litellm._async_success_callback.append("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.
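
    Unlike `enable_cache`, this always overwrites any existing `litellm.cache`
    configuration.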

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache: "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to caching all supported call types.
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None

    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function disables the cache used by the LiteLLM module. It removes the
    cache-related callbacks from the input_callback, success_callback, and
    _async_success_callback lists, and sets litellm.cache to None.

    Parameters:
    None

    Returns:
    None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    for callback_list in (
        litellm.input_callback,
        litellm.success_callback,
        litellm._async_success_callback,
    ):
        with suppress(ValueError):
            callback_list.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
