import contextlib
import weakref
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
from torch import nn
from torch.distributed._tensor import distribute_module, DTensor, Replicate
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel.style import ParallelStyle

aten = torch.ops.aten


def sdpa_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # extract local tensors and sharding info into an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    local_results = _scaled_dot_product_ring_flash_attention(
        op_info.mesh,
        *op_info.local_args,  # type: ignore[arg-type]
        **op_info.local_kwargs,  # type: ignore[arg-type]
    )

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)


def _merge_sdpa(
    chunks: List[torch.Tensor], logsumexps: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This merges multiple scaled dot product attention chunks by using the
    provided logsumexps to rescale the chunks before summing.

    Args:
        chunks (List[torch.Tensor]): A list of scaled dot product attention chunks
        logsumexps (List[torch.Tensor]): A list of logsumexps for each chunk

    Returns:
        out (torch.Tensor): The merged scaled dot product attention
        softmax_lse (torch.Tensor): The logsumexp of the merged scaled dot product attention
    """
    assert len(chunks) == len(logsumexps)

    # The LSE may be padded in the sequence dimension (e.g. by memory-efficient
    # attention), so trim it to the chunks' sequence length.
    seq_len = chunks[0].size(2)
    logsumexps = [lse[:, :, :seq_len] for lse in logsumexps]

    softmax_lse = torch.stack([lse.exp() for lse in logsumexps]).sum(dim=0).log_()

    out = []
    for chunk, chunk_lse in zip(chunks, logsumexps):
        softmax_lse_corrected = torch.exp(chunk_lse - softmax_lse)
        out_corrected = chunk * softmax_lse_corrected.unsqueeze(-1).to(chunk.dtype)
        out.append(out_corrected)
    out = torch.stack(out).sum(dim=0)

    return out, softmax_lse
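

# Illustrative sketch (not part of the public API): merging two manually computed
# attention chunks with ``_merge_sdpa`` should reproduce attention over the full
# key/value sequence. Shapes and the local ``_chunk`` helper are arbitrary and
# chosen only for this example.
def _example_merge_sdpa() -> None:
    torch.manual_seed(0)
    b, h, s, d = 2, 3, 8, 16
    query = torch.randn(b, h, s, d)
    key = torch.randn(b, h, 2 * s, d)
    value = torch.randn(b, h, 2 * s, d)
    scale = d**-0.5

    def _chunk(k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # attention output and logsumexp over a single key/value chunk
        scores = torch.einsum("bhqd,bhkd->bhqk", query, k) * scale
        out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)
        return out, torch.logsumexp(scores, dim=-1)

    chunks, lses = zip(
        *(_chunk(k, v) for k, v in zip(key.chunk(2, dim=2), value.chunk(2, dim=2)))
    )
    merged, _ = _merge_sdpa(list(chunks), list(lses))

    expected = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    assert torch.allclose(merged, expected, atol=1e-5)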


def _scaled_dot_product_ring_flash_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if return_debug_mask:
        raise NotImplementedError("return_debug_mask is not supported yet")

    return _templated_ring_attention(
        mesh,
        torch.ops.aten._scaled_dot_product_flash_attention,
        query=query,
        key=key,
        value=value,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )


def _scaled_dot_product_ring_efficient_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    compute_log_sumexp: bool = True,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if attn_bias is not None:
        raise NotImplementedError("attn_bias is not supported yet")
    if not compute_log_sumexp:
        raise NotImplementedError("compute_log_sumexp must be set")

    return _templated_ring_attention(
        mesh,
        torch.ops.aten._scaled_dot_product_efficient_attention,
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        compute_log_sumexp=compute_log_sumexp,
    )


def _scaled_dot_product_ring_cudnn_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if attn_bias is not None:
        raise NotImplementedError("attn_bias is not supported yet")
    if return_debug_mask:
        raise NotImplementedError("return_debug_mask is not supported yet")

    return _templated_ring_attention(
        mesh,
        torch.ops.aten._scaled_dot_product_cudnn_attention,
        query=query,
        key=key,
        value=value,
        dropout_p=dropout_p,
        is_causal=is_causal,
        return_debug_mask=return_debug_mask,
        scale=scale,
    )


def _ring_rotate(block: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
    input_split_sizes = [0] * size
    input_split_sizes[(rank + 1) % size] = len(block)
    output_split_sizes = [0] * size
    output_split_sizes[(rank - 1) % size] = len(block)

    # all_to_all_single_autograd takes output split sizes before input split sizes
    out = ft_c.all_to_all_single_autograd(
        block, output_split_sizes, input_split_sizes, pg
    )
    return out
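

# Illustrative sketch of the split-size layout ``_ring_rotate`` builds, shown here
# for rank 1 of a hypothetical 4-rank group with an 8-element block: only the
# (rank + 1) input slot and the (rank - 1) output slot are non-zero, so every rank
# sends its whole block to the next rank and receives the previous rank's block.
def _example_ring_rotate_splits() -> None:
    rank, size, numel = 1, 4, 8
    input_split_sizes = [0] * size
    input_split_sizes[(rank + 1) % size] = numel
    output_split_sizes = [0] * size
    output_split_sizes[(rank - 1) % size] = numel
    assert input_split_sizes == [0, 0, 8, 0]
    assert output_split_sizes == [8, 0, 0, 0]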


class AttentionOp(Protocol):
    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *args: object,
        is_causal: bool = False,
        **kwargs: object,
    ) -> Tuple[torch.Tensor, ...]:
        ...


def _templated_ring_attention(
    mesh: DeviceMesh,
    op: AttentionOp,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    *args: object,
    is_causal: bool = False,
    **kwargs: object,
) -> Tuple[torch.Tensor, ...]:
    """
    This is a generalized ring attention implementation that can support multiple attention ops.

    Parameters
    ----------
    op:
        The attention op to use
    *args:
        additional args are passed to the op
    **kwargs:
        additional kwargs are passed to the op

    Returns
    -------
    out:
        The merged attention output
    softmax_lse:
        The logsumexp of the merged attention output
    """
    if is_causal and (query.size(2) != key.size(2)):
        raise NotImplementedError(
            "is_causal requires the same query and context sequence lengths"
        )

    if isinstance(mesh, dist.ProcessGroup):
        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
    else:
        pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    next_kv = None

    chunks = []
    logsumexps = []
    for i in range(size):
        # overlap communication with compute
        if next_kv is not None:
            next_kv = ft_c.wait_tensor(next_kv)
            key = next_kv[: key.numel()].reshape(key.shape)
            value = next_kv[key.numel() :].reshape(value.shape)

        if i < (size - 1):
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = _ring_rotate(next_kv, pg)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            local_results = op(
                query,
                key,
                value,
                *args,
                is_causal=is_causal_behavior.value,
                **kwargs,
            )
            chunks.append(local_results[0])
            logsumexps.append(local_results[1])

    out, softmax_lse = _merge_sdpa(chunks, logsumexps)

    local_results = (out, softmax_lse) + local_results[2:]
    return local_results
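

# Illustrative sketch: the flatten/concat packing used above to ship key and value
# through a single collective round-trips losslessly. Shapes are arbitrary.
def _example_kv_packing() -> None:
    key = torch.randn(2, 8, 16, 64)
    value = torch.randn(2, 8, 16, 64)
    packed = torch.cat([key.flatten(), value.flatten()])
    unpacked_key = packed[: key.numel()].reshape(key.shape)
    unpacked_value = packed[key.numel() :].reshape(value.shape)
    assert torch.equal(unpacked_key, key)
    assert torch.equal(unpacked_value, value)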


def _scaled_dot_product_chunk_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    size: int,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    """
    This is a single node chunked implementation of
    _scaled_dot_product_ring_flash_attention used for verifying
    the correctness of the backwards pass.
    """

    if return_debug_mask:
        raise NotImplementedError("return_debug_mask is not supported yet")

    if is_causal and (query.size(2) != key.size(2)):
        raise NotImplementedError(
            "is_causal requires the same query and context sequence lengths"
        )

    query_len = query.size(2) // size
    ctx_len = key.size(2) // size

    global_out = []
    global_softmax_lse = []

    for rank in range(size):
        chunks = []
        logsumexps = []

        chunk_query = query[:, :, rank * query_len : (rank + 1) * query_len]

        for i in range(size):
            src_rank = (rank - i) % size
            chunk_key = key[:, :, src_rank * ctx_len : (src_rank + 1) * ctx_len]
            chunk_value = value[:, :, src_rank * ctx_len : (src_rank + 1) * ctx_len]

            is_causal_behavior = _is_causal_behavior(
                rank=rank, world_size=size, i=i, is_causal=is_causal
            )

            if is_causal_behavior != _CausalBehavior.SKIP:
                local_results = torch.ops.aten._scaled_dot_product_flash_attention(
                    chunk_query,
                    chunk_key,
                    chunk_value,
                    dropout_p=dropout_p,
                    is_causal=is_causal_behavior.value,
                    scale=scale,
                )
                chunks.append(local_results[0])
                logsumexps.append(local_results[1])

        out, softmax_lse = _merge_sdpa(chunks, logsumexps)
        global_out.append(out)
        global_softmax_lse.append(softmax_lse)

    global_out = torch.concat(global_out, dim=2)
    global_softmax_lse = torch.concat(global_softmax_lse, dim=2)

    local_results = (global_out, global_softmax_lse) + local_results[2:]
    return local_results
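

# Illustrative check (an assumption: requires a CUDA device where the flash backend
# supports these shapes and fp16): the single-node chunked implementation should
# match a direct flash attention call. Tolerances are loose and only illustrative.
def _example_chunk_flash_attention() -> None:
    query, key, value = (
        torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    chunked_out, chunked_lse = _scaled_dot_product_chunk_flash_attention(
        query, key, value, size=4, is_causal=True
    )[:2]
    expected_out, expected_lse = torch.ops.aten._scaled_dot_product_flash_attention(
        query, key, value, is_causal=True
    )[:2]
    assert torch.allclose(chunked_out, expected_out, atol=1e-2, rtol=1e-2)
    assert torch.allclose(chunked_lse, expected_lse, atol=1e-2, rtol=1e-2)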


def sdpa_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # Redistribute grad_output tensor to the same placement as output tensor
    args = list(args)
    assert isinstance(args[0], DTensor) and isinstance(args[4], DTensor)
    args[0] = args[0].redistribute(args[4].device_mesh, args[4].placements)
    args = tuple(args)

    # extract local tensors and sharding info into an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    local_results = _scaled_dot_product_ring_flash_attention_backward(
        op_info.mesh,
        *op_info.local_args,  # type: ignore[arg-type]
        **op_info.local_kwargs,  # type: ignore[arg-type]
    )

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)


def _scaled_dot_product_ring_flash_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cum_seq_q: torch.Tensor,
    cum_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
    right_dsts = list(range(1, size)) + [0]

    next_kv = None

    out_grad_queries = []
    out_grad_keys = []
    out_grad_values = []

    for i in range(size):
        # overlap communication with compute
        if next_kv is not None:
            next_kv = ft_c.wait_tensor(next_kv)
            key = next_kv[: key.numel()].reshape(key.shape)
            value = next_kv[key.numel() :].reshape(value.shape)

        if i < (size - 1):
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = ft_c.permute_tensor(next_kv, right_dsts, pg)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            # We rerun the forward pass since we don't have a good way to save
            # the per-chunk output/logsumexp from the original forward.
            (
                output,
                logsumexp,
                cum_seq_q,
                cum_seq_k,
                max_q,
                max_k,
                philox_seed,
                philox_offset,
                _,
            ) = torch.ops.aten._scaled_dot_product_flash_attention(
                query,
                key,
                value,
                dropout_p=dropout_p,
                is_causal=is_causal_behavior.value,
                scale=scale,
            )

            softmax_lse_corrected = torch.exp(logsumexp - softmax_lse)

            chunk_grad = grad_out * softmax_lse_corrected.conj().unsqueeze(-1).to(
                grad_out.dtype
            )

            (
                grad_query,
                grad_key,
                grad_value,
            ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
                grad_out=chunk_grad,
                query=query,
                key=key,
                value=value,
                out=output,
                logsumexp=logsumexp,
                cum_seq_q=cum_seq_q,
                cum_seq_k=cum_seq_k,
                max_q=max_q,
                max_k=max_k,
                dropout_p=dropout_p,
                is_causal=is_causal_behavior.value,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
                scale=scale,
            )
        else:
            grad_query = torch.zeros_like(query)
            grad_key = torch.zeros_like(key)
            grad_value = torch.zeros_like(value)

        # TODO overlap grad communication
        if i == 0:
            out_grad_queries.append(grad_query)
            out_grad_keys.append(grad_key)
            out_grad_values.append(grad_value)
        else:
            # Route the gradient for this KV block back to the rank that owns it:
            # at iteration i, rank j holds the block that originated at rank
            # (j - i) % size, so that is where its KV gradient must be summed.
            grad_dsts = [(j - i) % size for j in range(size)]

            grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])
            grad_kv = ft_c.permute_tensor(grad_kv, grad_dsts, pg)
            grad_kv = ft_c.wait_tensor(grad_kv)
            grad_key = grad_kv[: grad_key.numel()].reshape(grad_key.shape)
            grad_value = grad_kv[grad_key.numel() :].reshape(grad_value.shape)

            out_grad_queries.append(grad_query)
            out_grad_keys.append(grad_key)
            out_grad_values.append(grad_value)

    # stack and sum to avoid accumulation errors
    out_grad_query = torch.stack(out_grad_queries).sum(dim=0)
    out_grad_key = torch.stack(out_grad_keys).sum(dim=0)
    out_grad_value = torch.stack(out_grad_values).sum(dim=0)

    return out_grad_query, out_grad_key, out_grad_value


customized_ops = {
    aten._scaled_dot_product_flash_attention.default: sdpa_handler,
    aten._scaled_dot_product_flash_attention_backward.default: sdpa_backward_handler,
}


@contextlib.contextmanager
def attention_context_parallel() -> Generator[None, None, None]:
    """
    This is a context manager that forcibly enables attention context parallel
    optimizations for all scaled_dot_product_attention ops.

    This currently only supports ring attention and the
    SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel.

    Non-flash attention backends will produce incorrect results.
    """
    old_handlers = DTensor._op_dispatcher._custom_op_handlers
    DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}

    try:
        yield
    finally:
        DTensor._op_dispatcher._custom_op_handlers = old_handlers
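

# Illustrative usage sketch, not a tested recipe: it assumes two CUDA ranks launched
# with torchrun, fp16 inputs so the flash backend is eligible, and the sequence
# dimension sharded across the mesh. The shapes are arbitrary.
def _example_ring_attention_usage() -> None:
    from torch.distributed._tensor import Shard
    from torch.distributed.device_mesh import init_device_mesh
    from torch.nn.attention import sdpa_kernel, SDPBackend

    mesh = init_device_mesh("cuda", (2,))
    query, key, value = (
        torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    # each rank's local tensor becomes its shard of the global sequence
    dquery, dkey, dvalue = (
        DTensor.from_local(t, mesh, [Shard(2)], run_check=False)
        for t in (query, key, value)
    )
    with attention_context_parallel(), sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(
            dquery, dkey, dvalue, is_causal=True
        )
    assert out.to_local().shape == (2, 8, 1024, 64)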


class AttentionContextParallel(ParallelStyle):
    """
    Applies context parallel optimizations to the attention layer.

    This will work for nn.MultiheadAttention and custom attention layers that
    call F.scaled_dot_product_attention with a similar signature.

    This expects the `forward` method to consume either:

    * a single tensor for self attention
    * one argument for each of: query, key, value

    This currently only supports ring attention and the
    SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel.

    Non-flash attention backends will produce incorrect results.
    """

    # use a weakref dictionary to store context managers for each nn.Module
    _CONTEXT_MANAGERS: "weakref.WeakKeyDictionary[nn.Module, Any]" = (
        weakref.WeakKeyDictionary()
    )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if not isinstance(device_mesh, DeviceMesh):
            raise ValueError(
                f"{type(device_mesh)} is not supported by {type(self)} yet."
            )

        if device_mesh.ndim != 1:
            raise ValueError(f"{type(self)} only supports a 1D device mesh.")

        return distribute_module(
            module,
            device_mesh,
            input_fn=self._input_fn,  # type: ignore[arg-type]
            output_fn=self._output_fn,  # type: ignore[arg-type]
        )

    @classmethod
    def _input_fn(
        cls,
        module: nn.Module,
        inputs: Tuple[Union[torch.Tensor, int, float], ...],
        device_mesh: DeviceMesh,
    ) -> Tuple[Union[torch.Tensor, int, float], ...]:
        # TODO(d4l3k): this should be Shard(2), need to fix Linear layer rules
        placement = [Replicate()]

        def backward_hook(grad: torch.Tensor) -> None:
            if module in cls._CONTEXT_MANAGERS:
                cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
                del cls._CONTEXT_MANAGERS[module]

        # convert inputs to DTensor
        inp = []
        for input in inputs:
            if isinstance(input, torch.Tensor) and not isinstance(input, DTensor):
                input = DTensor.from_local(
                    input, device_mesh, placement, run_check=False
                )

            if isinstance(input, torch.Tensor) and input.requires_grad:
                input.register_hook(backward_hook)

            inp.append(input)

        manager = attention_context_parallel()
        manager.__enter__()
        cls._CONTEXT_MANAGERS[module] = manager

        return tuple(inp)

    @classmethod
    def _output_fn(
        cls,
        module: nn.Module,
        outputs: Union[torch.Tensor, Tuple[Union[torch.Tensor, int, float], ...]],
        device_mesh: DeviceMesh,
    ) -> Union[
        Union[torch.Tensor, int, float], Tuple[Union[torch.Tensor, int, float], ...]
    ]:
        cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
        del cls._CONTEXT_MANAGERS[module]

        def backward_hook(grad: torch.Tensor) -> None:
            if module not in cls._CONTEXT_MANAGERS:
                manager = attention_context_parallel()
                manager.__enter__()
                cls._CONTEXT_MANAGERS[module] = manager

        # back to local tensor
        out = []
        for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
            output = output.to_local() if isinstance(output, DTensor) else output

            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(backward_hook)

            out.append(output)

        if isinstance(outputs, torch.Tensor):
            return out[0]

        return tuple(out)
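

# Illustrative sketch of applying the parallel style with ``parallelize_module``.
# ``model.layers[*].attention`` is a hypothetical module layout; substitute the
# attention submodules of the model actually being parallelized.
def _example_parallelize_attention(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    for block in model.layers:
        parallelize_module(block.attention, mesh, AttentionContextParallel())
    return model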


class _CausalBehavior(Enum):
    SKIP = None
    NOT_IS_CAUSAL = False
    IS_CAUSAL = True


def _is_causal_behavior(
    rank: int, world_size: int, i: int, is_causal: bool
) -> _CausalBehavior:
    """
    Calculate the is_causal behavior for each KV block. The attention can be
    computed in full, with the causal mask applied, or skipped entirely.
    """
    if not is_causal:
        return _CausalBehavior.NOT_IS_CAUSAL

    if i == 0:
        return _CausalBehavior.IS_CAUSAL

    source_rank = (rank - i) % world_size
    if source_rank < rank:
        return _CausalBehavior.NOT_IS_CAUSAL
    else:
        return _CausalBehavior.SKIP
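

# Illustrative sketch: the schedule ``_is_causal_behavior`` produces for rank 2 of a
# hypothetical 4-rank causal ring. Iteration 0 attends to the local block with the
# causal mask; later iterations attend in full to blocks from lower ranks and skip
# blocks from higher ranks.
def _example_is_causal_schedule() -> None:
    schedule = [
        _is_causal_behavior(rank=2, world_size=4, i=i, is_causal=True)
        for i in range(4)
    ]
    assert schedule == [
        _CausalBehavior.IS_CAUSAL,
        _CausalBehavior.NOT_IS_CAUSAL,
        _CausalBehavior.NOT_IS_CAUSAL,
        _CausalBehavior.SKIP,
    ]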
