"""This module provides a thin wrapper around the OpenTelemetry propagate API to allow
the OpenTelemetry contexts (and therefore Logfire contexts) to be transferred between
different code running in different threads, processes or even services.

In general, you should not need to use this module since Logfire will automatically
patch [`ThreadPoolExecutor`][concurrent.futures.ThreadPoolExecutor] and
[`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to carry over the context.
Plugins also exist to propagate the context with
[requests](https://pypi.org/project/opentelemetry-instrumentation-requests/) and
[httpx](https://pypi.org/project/opentelemetry-instrumentation-httpx/).
"""  # noqa: D205

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Mapping

from opentelemetry import context as otel_context, propagate, trace
from opentelemetry.propagators.textmap import TextMapPropagator

import logfire
from logfire._internal.stack_info import warn_at_user_stacklevel

# Type alias for a carrier of propagated context: any string-keyed mapping,
# for example HTTP headers or a plain dict.
ContextCarrier = Mapping[str, Any]

# Names exported as this module's public API.
__all__ = ('get_context', 'attach_context', 'ContextCarrier')


def get_context() -> ContextCarrier:
    """Return a fresh carrier dict with the context injected into it.

    A new empty dict is created on every call and handed to the OpenTelemetry
    propagate API, which injects the context into it.

    Returns:
        The newly created dict, after injection.

    Usage:

    ```py
    from logfire.propagate import get_context, attach_context

    logfire_context = get_context()

    ...

    # later on in another thread, process or service
    with attach_context(logfire_context):
        ...
    ```

    The returned carrier can also be merged into a mapping you already have, such as headers:

    ```py
    from logfire.propagate import get_context

    existing_headers = {'X-Foobar': 'baz'}
    existing_headers.update(get_context())
    ...
    ```
    """
    injected: dict[str, Any] = {}
    propagate.inject(injected)
    return injected


@contextmanager
def attach_context(carrier: ContextCarrier, *, third_party: bool = False) -> Iterator[None]:
    """Make a context produced by [`get_context`][logfire.propagate.get_context] the current one for the `with` block.

    This is a context manager: the context that was current on entry is re-attached when the block exits,
    whether it exits normally or with an exception.

    Pass `third_party=True` when calling this from a library that other people will use.
    Extraction then respects the [`distributed_tracing` argument of `logfire.configure()`][logfire.configure(distributed_tracing)],
    so by default users are warned about unintentional distributed tracing and can suppress it.
    See [Unintentional Distributed Tracing](https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/#unintentional-distributed-tracing) for more information.
    """
    # Remember what is current now; it is re-attached in the `finally` below.
    previous = otel_context.get_current()

    textmap = propagate.get_global_textmap()
    if not third_party:
        # Skip past the warning / non-extracting wrappers so the underlying
        # propagator performs the extraction directly.
        guard_types = (WarnOnExtractTraceContextPropagator, NoExtractTraceContextPropagator)
        while isinstance(textmap, guard_types):
            textmap = textmap.wrapped

    extracted = textmap.extract(carrier=carrier)
    try:
        otel_context.attach(extracted)
        yield
    finally:
        # Restoration is done by attaching the saved context again.
        otel_context.attach(previous)


@dataclass
class WrapperPropagator(TextMapPropagator):
    """Base class for propagators that delegate to another propagator.

    Every operation defined here is forwarded untouched to `wrapped`.
    """

    # The propagator that all calls are forwarded to.
    wrapped: TextMapPropagator

    def extract(self, *args: Any, **kwargs: Any) -> otel_context.Context:
        """Forward to `wrapped.extract` and return its result."""
        inner = self.wrapped
        return inner.extract(*args, **kwargs)

    def inject(self, *args: Any, **kwargs: Any):
        """Forward to `wrapped.inject` and return its result."""
        inner = self.wrapped
        return inner.inject(*args, **kwargs)

    @property
    def fields(self):
        """The `fields` of the wrapped propagator."""
        inner = self.wrapped
        return inner.fields


class NoExtractTraceContextPropagator(WrapperPropagator):
    """A propagator that discards any trace context extracted by the wrapped propagator.

    Used when `logfire.configure(distributed_tracing=False)` is called.
    """

    def extract(
        self,
        carrier: Any,
        context: otel_context.Context | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> otel_context.Context:
        """Extract with the wrapped propagator, then put back the span of the incoming `context`.

        Only the current span is overridden: the returned context is the
        wrapped propagator's result with the span of `context` set in it.
        """
        extracted = super().extract(carrier, context, *args, **kwargs)
        if extracted == context:  # pragma: no cover
            # Optimization: the wrapped propagator extracted nothing, so there is nothing to undo.
            return extracted
        # Re-install the span that `context` had, replacing whichever span was extracted.
        span_before = trace.get_current_span(context)
        return trace.set_span_in_context(span_before, extracted)


@dataclass
class WarnOnExtractTraceContextPropagator(WrapperPropagator):
    """A propagator that warns the first time the wrapped propagator extracts trace context.

    Used when `logfire.configure(distributed_tracing=None)` is called. This is the default behavior.
    """

    # Set to True once the warning has been emitted, so it is emitted at most once per instance.
    warned: bool = False

    def extract(
        self,
        carrier: Any,
        context: otel_context.Context | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> otel_context.Context:
        """Extract with the wrapped propagator, warning once if the current span changed.

        The wrapped propagator's result is always returned unmodified. The
        warning is issued both as a Python `RuntimeWarning` and through `logfire.warn`.
        """
        result = super().extract(carrier, context, *args, **kwargs)
        if not self.warned and result != context:
            span_before = trace.get_current_span(context)
            span_after = trace.get_current_span(result)
            if span_before != span_after:
                # Flip the flag before emitting so the warning cannot repeat.
                self.warned = True
                message = (
                    'Found propagated trace context. See '
                    'https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/#unintentional-distributed-tracing.'
                )
                warn_at_user_stacklevel(message, RuntimeWarning)
                logfire.warn(message)
        return result
