import os
import logging
import re

from sympy.external import import_module
from sympy.parsing.latex.lark.transformer import TransformToSymPyExpr

_lark = import_module("lark")


class LarkLaTeXParser:
    r"""Class for converting input `\mathrm{\LaTeX}` strings into SymPy Expressions.
    It holds all the necessary internal data for doing so, and exposes hooks for
    customizing its behavior.

    Parameters
    ==========

    print_debug_output : bool, optional

        If set to ``True``, prints debug output to the logger. Defaults to ``False``.

        Note that this raises the level of Lark's logger to ``DEBUG`` on the
        first call to :py:meth:`doparse`; the level is not restored afterwards.

    transform : bool, optional

        If set to ``True``, the class runs the Transformer class on the parse tree
        generated by running ``Lark.parse`` on the input string. Defaults to ``True``.

        Setting it to ``False`` can help with debugging the `\mathrm{\LaTeX}` grammar.

    grammar_file : str, optional

        The path to the grammar file that the parser should use. If set to ``None``,
        it uses the default grammar, which is in ``grammar/latex.lark``, relative to
        the ``sympy/parsing/latex/lark/`` directory.

    transformer : type, optional

        The Transformer class to use. This must be the class itself (or any
        zero-argument callable returning a transformer), not an instance and
        not its name: it is called with no arguments to build the transformer.
        If set to ``None``, it uses the default transformer class, which is
        :py:func:`TransformToSymPyExpr`.

    Raises
    ======

    ImportError
        If Lark could not be imported.

    """
    def __init__(self, print_debug_output=False, transform=True, grammar_file=None, transformer=None):
        # Fail with a clear message instead of an AttributeError on ``None``
        # when the optional Lark dependency is missing.
        if _lark is None:
            raise ImportError("Lark is probably not installed")

        # The trailing slash is kept on purpose: this value is also handed to
        # Lark as ``source_path`` below.
        grammar_dir_path = os.path.join(os.path.dirname(__file__), "grammar/")

        if grammar_file is None:
            grammar_file = os.path.join(grammar_dir_path, "latex.lark")

        with open(grammar_file, encoding="utf-8") as f:
            latex_grammar = f.read()

        self.parser = _lark.Lark(
            latex_grammar,
            source_path=grammar_dir_path,
            parser="earley",
            start="latex_string",
            lexer="auto",
            # Keep every ambiguous derivation in the tree (as ``_ambig`` nodes)
            # rather than letting Lark pick one.
            ambiguity="explicit",
            propagate_positions=False,
            maybe_placeholders=False,
            keep_all_tokens=True)

        self.print_debug_output = print_debug_output
        self.transform_expr = transform

        transformer_cls = TransformToSymPyExpr if transformer is None else transformer
        self.transformer = transformer_cls()

    def doparse(self, s: str):
        """Parse the LaTeX string ``s``.

        Parameters
        ==========

        s : str
            The LaTeX source to parse.

        Returns
        =======

        The result of running the configured transformer on the parse tree,
        or the raw Lark parse tree itself when the parser was created with
        ``transform=False``.
        """
        if self.print_debug_output:
            _lark.logger.setLevel(logging.DEBUG)

        parse_tree = self.parser.parse(s)

        if not self.transform_expr:
            # exit early and return the parse tree; these messages only show
            # up if Lark's logger is at DEBUG level
            _lark.logger.debug("expression = %s", s)
            _lark.logger.debug(parse_tree)
            _lark.logger.debug(parse_tree.pretty())
            return parse_tree

        if self.print_debug_output:
            # print this stuff before attempting to run the transformer, so
            # that it is available even if the transformer raises
            _lark.logger.debug("expression = %s", s)
            # print the `parse_tree` variable
            _lark.logger.debug(parse_tree.pretty())

        sympy_expression = self.transformer.transform(parse_tree)

        if self.print_debug_output:
            _lark.logger.debug("SymPy expression = %s", sympy_expression)

        return sympy_expression


# Build the shared default parser once, at import time, but only when Lark
# could be imported. Otherwise ``_lark_latex_parser`` is left undefined;
# ``parse_latex_lark`` checks ``_lark`` and raises ``ImportError`` before
# it would touch the missing name.
if _lark is not None:
    _lark_latex_parser = LarkLaTeXParser()


def parse_latex_lark(s: str):
    """
    Experimental LaTeX parser using Lark.

    This function is still under development and its API may change with the
    next releases of SymPy.

    Parameters
    ==========

    s : str
        The LaTeX string to parse.

    Returns
    =======

    Whatever the module-level default ``LarkLaTeXParser`` instance returns
    from ``doparse(s)``.

    Raises
    ======

    ImportError
        If Lark could not be imported.
    """
    if _lark is not None:
        # Delegate to the parser instance created when this module was loaded.
        return _lark_latex_parser.doparse(s)

    raise ImportError("Lark is probably not installed")


def _pretty_print_lark_trees(tree, indent=0, show_expr=True):
    """Render a Lark parse tree as a compact string.

    Tokens render as their ``value``. Any other node renders as
    ``name(child,child,...)``, with two special cases:

    * a node whose name starts with ``expression`` has that prefix shortened
      to ``E``; when ``show_expr`` is false the node's own ``name(...)``
      wrapper is dropped and only its children are emitted;
    * an ``_ambig`` node puts each child on its own line, indented two spaces
      deeper than ``indent``.

    Parameters
    ==========

    tree
        A Lark token or tree node.
    indent : int, optional
        Current indentation width, in spaces.
    show_expr : bool, optional
        Whether ``expression*`` nodes are shown explicitly.

    Returns
    =======

    str
        The rendered text.
    """
    # Tokens are the leaves of the tree.
    if isinstance(tree, _lark.Token):
        return tree.value

    label = str(tree.data)

    expr_node = label.startswith("expression")
    if expr_node:
        label = re.sub(r"^expression", "E", label)

    ambiguous = label == "_ambig"
    child_indent = indent + 2 if ambiguous else indent

    rendered_children = [
        _pretty_print_lark_trees(child, child_indent, show_expr)
        for child in tree.children
    ]

    if ambiguous:
        # One alternative per line, each prefixed by the deeper indentation.
        padding = " " * child_indent
        body = "\n" + "\n".join([padding + text for text in rendered_children])
    else:
        body = ",".join(rendered_children)

    # Hidden expression nodes contribute only their children.
    if expr_node and not show_expr:
        return body

    return str(label) + "(" + body + ")"
