# mypy: allow-untyped-defs
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

from sympy import Integer

from .. import metrics
from ..runtime.hints import DeviceProperties
from ..scheduler import SchedulerNode
from ..utils import ceildiv, Placeholder
from ..virtualized import V
from .common import IndentedBuffer, Kernel
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


@dataclass
class PartitionState:
    """Accumulator used while splitting foreach sub-kernel nodes into partitions.

    Each entry is a node-info tuple ``(fused_nodes, tiling, numel, rnumel)``.
    ``cur_partition`` is the partition currently being filled, ``cur_count``
    is the running read/write count of that partition, and ``partitions``
    collects every partition that has been completed.
    """

    partitions: List[
        List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
    ]
    cur_partition: List[
        Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
    ]
    cur_count: int

    def finalize(self):
        """Flush the in-progress partition (if non-empty) into ``partitions``.

        The in-progress state is reset after flushing. A repeated call then
        does not append the same partition twice, and later additions cannot
        mutate a partition that has already been emitted (the flushed list
        object is no longer shared with ``cur_partition``).
        """
        if self.cur_partition:
            self.partitions.append(self.cur_partition)
            self.cur_partition = []
            self.cur_count = 0


class ForeachKernel(Kernel):
    """Triton kernel that horizontally fuses several independent sub-kernels.

    Each sub-kernel is a ``TritonKernel`` that shares this kernel's argument
    list. ``codegen_kernel`` emits a single ``@triton.jit`` function. That
    function splits the x program id (``xpid``) into contiguous ranges, one
    per sub-kernel, and uses an ``if``/``elif`` chain to run the body of the
    sub-kernel owning the current range. When 2D blocking is enabled, the y
    program id (``ypid``) is read as well.
    """

    MAX_NUM_ARGS = 250  # number where I would no longer get triton errors

    @staticmethod
    def _update_partition(partition_state, node_rw_count, node_info):
        """Add ``node_info`` to ``partition_state``, starting a new partition if needed.

        A new partition is started when adding the node would push the running
        read/write count past ``MAX_NUM_ARGS``. The current partition is only
        flushed when it is non-empty. A single node whose own count exceeds
        the limit therefore becomes the sole member of a partition, rather
        than leaving an empty partition in the result.

        Args:
            partition_state: the ``PartitionState`` being filled.
            node_rw_count: number of reads plus writes of the node.
            node_info: ``(fused_nodes, tiling, numel, rnumel)`` tuple for the node.
        """
        exceeds_limit = (
            partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS
        )
        if exceeds_limit and partition_state.cur_partition:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def horizontal_partition(subkernel_nodes, triton_scheduling):
        """Group sub-kernel nodes into partitions that can share one foreach kernel.

        Returns a list of lists of node-info tuples ``(fused_nodes, tiling,
        numel, rnumel)``. Nodes whose tiling is 1D (``tiling[1] == 1``) are
        partitioned together. 2D nodes are partitioned separately for each
        distinct ``tiling[0]`` (y extent). Within each group, a new partition
        is started whenever the accumulated read/write count would exceed
        ``MAX_NUM_ARGS``; a single node that alone exceeds the limit still
        gets its own partition.
        """
        assert len(subkernel_nodes) >= 1

        partition_state_1d = PartitionState([], [], 0)
        yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            fused_nodes = node.get_nodes()
            # Take the (numel, rnumel) group of a reduction node if one exists,
            # otherwise of the first node (max() returns the first maximal item).
            _, (numel, rnumel) = max(
                fused_nodes, key=lambda x: int(x.is_reduction())
            ).group
            tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
            node_info = fused_nodes, tiled_groups, numel, rnumel

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            if tiled_groups[1] == 1:
                ForeachKernel._update_partition(
                    partition_state_1d, read_write_count, node_info
                )
            else:
                y_elem = tiled_groups[0]
                partition_state_2d = yelem_to_partition_state_2d[y_elem]
                ForeachKernel._update_partition(
                    partition_state_2d, read_write_count, node_info
                )

        partition_state_1d.finalize()
        all_partitions = partition_state_1d.partitions
        for partition_state_2d in yelem_to_partition_state_2d.values():
            partition_state_2d.finalize()
            all_partitions.extend(partition_state_2d.partitions)

        return all_partitions

    def __init__(self):
        super().__init__()
        # Set by create_sub_kernel once a sub-kernel with 2D tiling is added.
        self.blocking_2d = False
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.num_warps = 8
        self.sub_kernels = []
        # Shared with every sub-kernel so their iteration variable names are unique.
        self.iter_vars_count = itertools.count()
        # Number of x program ids already assigned to emitted sub-kernels.
        self.x_block_count = 0
        self.y_block_count = 0

    def get_block_size(self):
        """Return the per-program block size for the current blocking mode."""
        if self.blocking_2d:
            return self.block_size_2d
        else:
            return self.block_size_1d

    @staticmethod
    def codegen_pid_offsets(code, block_count, lower_bound, prefix):
        """Emit ``{prefix}pid_offset``: the program id relative to ``lower_bound``.

        When ``block_count`` is 0 the raw program id is used unchanged.
        """
        if block_count == 0:
            code.splice(f"{prefix}pid_offset = {prefix}pid")
        else:
            code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")

    def codegen_pid_range(self, code, x_elems):
        """Emit the ``if``/``elif`` header selecting the next sub-kernel's xpid range.

        The range is ``[x_block_count, x_block_count + ceildiv(x_elems, block))``.
        The method also emits the local ``xpid_offset`` and advances
        ``self.x_block_count`` past the range.
        """
        num_x_blocks = ceildiv(x_elems, self.get_block_size())
        upper_bound_x_pid = self.x_block_count + num_x_blocks
        lower_bound_x_pid = self.x_block_count

        # The first range opens the dispatch chain; later ranges extend it.
        if self.x_block_count == 0:
            cond = "if"
        else:
            cond = "elif"

        x_pid_bounds_check = (
            f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
        )
        code.splice(f"{cond} {x_pid_bounds_check}:")

        with code.indent():
            ForeachKernel.codegen_pid_offsets(
                code, num_x_blocks, lower_bound_x_pid, "x"
            )
            self.x_block_count += num_x_blocks

    def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
        """Create a ``TritonKernel`` sub-kernel that shares this kernel's state.

        The sub-kernel reads ``tl.program_id(0)`` / ``tl.program_id(1)`` from
        the ``xpid_offset`` / ``ypid`` locals emitted by ``codegen_kernel``. It
        shares ``args``, ``iter_vars_count`` and the CSE buffer-id iterator
        with this kernel. A 3-element ``groups`` whose middle entry is not 1
        switches this kernel to 2D blocking. Once 2D blocking is on, every
        later sub-kernel must have three groups.
        """
        sub_kernel = TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={
                "tl.program_id(0)": "xpid_offset",
                "tl.program_id(1)": "ypid",
            },
            reduction_hint=reduction_hint,
        )
        if self.blocking_2d:
            assert len(groups) == 3

        # Check the arity before indexing so a short ``groups`` cannot raise IndexError.
        self.blocking_2d |= len(groups) == 3 and groups[1] != 1
        # NOTE(review): presumably undoes the count added by constructing the
        # TritonKernel above, so the fused kernel is counted only once — confirm.
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    def jit_lines(self):
        """Return the ``@triton_heuristics.foreach`` / ``@triton.jit`` decorator lines.

        32-bit size types are used only if every sub-kernel indexes with ``tl.int32``.
        """
        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
        _, _, signature, _ = self.args.python_argdefs()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=size_dtype),
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            **TritonKernel.inductor_meta_common(),
        }
        return f"""
            @triton_heuristics.foreach(
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
            )
            @triton.jit
        """

    def grid(self):
        """Return the ``(x, y, z)`` launch grid.

        x covers every sub-kernel's xpid range. In 2D mode, y is derived from
        the first sub-kernel's ``numels[0]``. This relies on
        ``horizontal_partition`` grouping 2D nodes by that extent.
        """
        return (
            self.x_block_count,
            ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
            if self.blocking_2d
            else 1,
            1,
        )

    def codegen_kernel(self, name=None):
        """Generate the source of the fused Triton kernel.

        Each sub-kernel's body is emitted under its own xpid-range branch of
        one ``if``/``elif`` chain, terminated by ``else: pass``.
        """
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        argdefs, _, _, _ = self.args.python_argdefs()
        code.splice(self.jit_lines())
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("xpid = tl.program_id(0)")
            if self.blocking_2d:
                code.splice("ypid = tl.program_id(1)")
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
                code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
            else:
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")

            for sub_kernel in self.sub_kernels:
                assert len(sub_kernel.numels) <= 3
                # TODO mlazos: support dynamic shapes
                # In 2D mode numels are (y, x, ...), so the x extent is at index 1.
                numel_ind = 0 if not self.blocking_2d else 1
                self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
                with code.indent():
                    if self.blocking_2d:
                        code.splice(f"ynumel = {sub_kernel.numels[0]}")
                        code.splice(f"xnumel = {sub_kernel.numels[1]}")
                    else:
                        code.splice(f"xnumel = {sub_kernel.numels[0]}")

                    sub_kernel.codegen_body()
                    code.splice(sub_kernel.body)

            code.splice("else:")
            with code.indent():
                code.splice("pass")

        return code.getvalue()

    def call_kernel(self, code, name: str):
        """Emit the launch of kernel ``name`` into the wrapper ``code``.

        With the C++ wrapper the launch goes through
        ``wrapper_code.generate_kernel_call``. Otherwise a ``{name}.run(...)``
        line is written, using a raw stream for the current device.
        """
        _, call_args, _, arg_types = self.args.python_argdefs()
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i, arg in enumerate(call_args):
            if V.graph.is_unspec_arg(arg):
                call_args[i] = arg + ".item()"
        current_device = V.graph.scheduler.get_current_device_or_throw()
        if V.graph.cpp_wrapper:
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                device_index=current_device.index,
                grid=self.grid(),
                arg_types=arg_types,
            )
        else:
            # TODO: refactor generate_kernel_call
            call_args_str = ", ".join(call_args)
            stream_name = code.write_get_raw_stream(current_device.index)
            code.writeline(
                f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
            )
