# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""

import logging
from enum import auto, Enum
from typing import Any, List, Tuple

import torch
from .. import config
from ..ir import (
    ComputedBuffer,
    FixedLayout,
    FlexibleLayout,
    InputBuffer,
    IRNode,
    StorageBox,
    Subgraph,
    TensorBox,
)
from ..lowering import empty_strided, full, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate

log = logging.getLogger(__name__)
aten = torch.ops.aten


class SubgraphType(Enum):
    """The type of subgraph for which we want to generate an output buffer."""

    FWD = auto()  # Forward pass
    JOINT_FWD = auto()  # The recompute step of the fwd graph, used by the bwd kernel
    JOINT_BWD = auto()  # The bwd pass of the joint


def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta):
    """How is this kernel parallelized?
    We create a grid of (ceil_div(n_queries, query_block_size), batch_size * num_heads, 1).
    Each block is responsible for iterating over blocks of keys and values, calculating
    the final attention output.
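
    For example (illustrative sizes): with num_queries=2048, BLOCK_M=128,
    batch_size=4, and num_heads=8, the grid is (16, 32, 1).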
    """
    import triton

    return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1)


def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """Creates a placeholder input buffers for producing subgraph_output."""
    input_buffer = InputBuffer(name, FixedLayout(device, dtype, [1], [1]))
    return TensorBox.create(input_buffer)


def index_to_other_buffers(cnt: int, graph_type: SubgraphType) -> int:
    """This function needs to be aware of the signatures for flex_attention_forward
    and flex_attention_backward. If new args are added, or the signature changes
    be sure to update the indexing math

    Args:
        cnt (int): The current index of the placeholder node
        is_joint_graph (bool): Whether or not this subgraph represents the joint graph
    """
    # Current fwd_args = [query, key, value, score_mod, *other_buffers]
    # For fwd graphs we have 5 placeholder (dummy) values, so when the first
    # lifted arg is seen cnt == 5, while other_buffers starts at args[4];
    # thus we subtract 1 from the current cnt.
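    # Example (using the signature above): with a single extra bias tensor,
    # fwd_args = [query, key, value, score_mod, bias]; the bias placeholder is
    # processed when cnt == 5, and 5 - 1 == 4 is indeed the index of bias in args.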
    if graph_type == SubgraphType.FWD:
        return cnt - 1

    # Current bwd_args = [q, k, v, out, lse, grad_out, fw_graph, joint_graph, *other_buffers]
    # We have 5 placeholder (dummy) values, but other_buffers starts at index 8,
    # so we add 3.
    if graph_type == SubgraphType.JOINT_FWD:
        return cnt + 3

    # Same bwd args, but now with 6 placeholder values, while other_buffers still
    # starts at index 8, so we add 2.
    if graph_type == SubgraphType.JOINT_BWD:
        return cnt + 2


def build_subgraph_buffer(
    args: Tuple[IRNode],
    placeholder_inps: List[TensorBox],
    subgraph: Subgraph,
    graph_type: SubgraphType,
) -> ComputedBuffer:
    """This function's goal is to take in the required args and produce the subgraph buffer
    The subgraph buffer is a ComputedBuffer that will be inlined into the triton template

    Args:
        args: The args that were passed into the flex_attention kernel
        placeholder_inps: The list of scalar inputs; these were created on the fly through `create_placeholder`
        subgraph: The Subgraph ir for which to produce the output node
        graph_type: The type of subgraph for which we want to produce the output node, see enum above for details
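
    As an illustrative example, a score_mod along the lines of

        def score_mod(score, b, h, m, n):
            return score + bias[m, n]  # `bias` is a hypothetical lifted buffer

    is traced into a subgraph with five scalar placeholders (score, b, h, m, n)
    followed by one lifted placeholder for `bias`, which this function resolves
    against the TensorBoxes passed in via `args`.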
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        # There are two classes of placeholder inputs that we need
        # to handle differently. For the first n_scalar_inps inputs
        # we expect that these placeholders were generated by the make_fx call
        # in the flex_attention HOP. So we need to create a new placeholder
        # TensorBox for each of these inputs. For the rest of the inputs we
        # expect that these are lifted inputs that fill up the '*other_buffers'
        # tuple and already have corresponding TensorBoxes passed in as args.
        if node.op == "placeholder":
            is_lifted_input = cnt >= len(placeholder_inps)
            lifted_input_index = index_to_other_buffers(cnt, graph_type)
            env[node] = (
                args[lifted_input_index] if is_lifted_input else placeholder_inps[cnt]
            )
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args
            from torch.utils._pytree import tree_map

            lowered_args, lowered_kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*lowered_args, **lowered_kwargs)
        elif node.op == "output":
            # For the output node we need to create a ComputedBuffer
            # which represents the actual score modification
            # The joint_graph's output should be of the form [grad_score, None, None, None, None]
            # This is because only the 'score' requires grad and the other outputs are
            # the non-differentiable index scalars
            if graph_type == SubgraphType.FWD or graph_type == SubgraphType.JOINT_FWD:
                output_node = node.args[0]
            else:
                output_node = node.args[0][0]
            output_buffer = env[output_node]
            assert isinstance(output_buffer, TensorBox), (
                "The output node for flex attention's subgraph must be a TensorBox, "
                f"but got: {type(output_buffer)}"
            )
            assert isinstance(output_buffer.data, StorageBox), (
                "The output node for the flex attention subgraph must be a StorageBox, "
                f"but got: {type(output_buffer.data)}"
            )
            # Create the ComputedBuffer directly that will be inlined into the modification block
            subgraph_buffer = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=output_buffer.data.get_device(),
                    dtype=output_buffer.data.get_dtype(),
                    size=output_buffer.data.get_size(),
                ),
                data=output_buffer.data.data,  # type: ignore[arg-type]
            )
            return subgraph_buffer

    raise ValueError("TemplatedAttention was passed a subgraph with no output node!")


flex_attention_template = TritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=r"""
{{def_kernel("Q", "K", "V", "LSE")}}
    # Sub notation for this kernel:
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # (Modifiable) Config options:
    # BLOCK_M
    # BLOCK_N
    # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
    # change of base out of the loop
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad

    # Define Q Strides
    stride_qz = {{stride("Q", 0)}}
    stride_qh = {{stride("Q", 1)}}
    stride_qm = {{stride("Q", 2)}}
    stride_qk = {{stride("Q", 3)}}
    # Define K Strides
    stride_kz = {{stride("K", 0)}}
    stride_kh = {{stride("K", 1)}}
    stride_kn = {{stride("K", 2)}}
    stride_kk = {{stride("K", 3)}}
    # Define V Strides
    stride_vz = {{stride("V", 0)}}
    stride_vh = {{stride("V", 1)}}
    stride_vk = {{stride("V", 2)}}
    stride_vn = {{stride("V", 3)}}

    Z = {{size("Q", 0)}}
    H = {{size("Q", 1)}}
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    qk_scale = 1.0
    MATMUL_PRECISION = Q.dtype.element_ty

    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    q_offset = off_hz * stride_qh
    kv_offset = off_hz * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(Q_LEN, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset,
        shape=(BLOCK_DMODEL, KV_LEN),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(KV_LEN, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    q = tl.load(Q_block_ptr)
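    # 1.44269504 is log2(e): the kernel evaluates softmax with tl.math.exp2,
    # so scores are scaled by log2(e), either folded into q here (when the
    # score_mod is linear) or applied to the post-modification scores below.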
    if SCORE_MOD_IS_LINEAR:
        qk_scale *= 1.44269504
    q = (q * qk_scale).to(MATMUL_PRECISION)
    # loop over k, v and update accumulator
    lo = 0
    hi = KV_LEN
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
        # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
        m = offs_m[:, None]
        n = start_n + offs_n[None, :]
        {{ modification(
            subgraph_number=0,
            output_name="post_mod_scores",
            score="qk",
            b="off_hz // H",
            h="off_hz % H",
            m="m",
            n="n",
            out="qk"
        ) | indent_except_first(2) }}
        # TODO: In the case that score_mod is linear, this can be LICMed
        if not SCORE_MOD_IS_LINEAR:
            post_mod_scores *= 1.44269504
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # -- compute scaling constant ---
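        # Online softmax: m_i holds the running row max and l_i the running
        # normalizer; alpha rescales the previous partial results to the new max.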
        row_max = tl.max(post_mod_scores, 1)
        m_i_new = tl.maximum(m_i, row_max)

        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(post_mod_scores - m_i_new[:, None])
        if not ROWS_GUARANTEED_SAFE:
            masked_out_rows = (m_i_new == float("-inf"))
            alpha = tl.where(masked_out_rows, 0, alpha)
            p = tl.where(masked_out_rows[:, None], 0, p)

        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # Store output and logsumexp
    acc = acc / l_i[:, None]
    idx_z = tl.program_id(1) // H
    idx_h = tl.program_id(1) % H
    idx_m = offs_m[:, None]
    idx_d = tl.arange(0, BLOCK_DMODEL)[None, :]

    # TODO generalize and add proper mask support
    mask = (idx_m != -1) & (idx_d != -1)
    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}

    # TODO don't want to write this if we don't require grad
    if OUTPUT_LOGSUMEXP:
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        tl.store(l_ptrs, lse)
 """,
)


_h100_default_config = {
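    # Each entry maps (dtype, head_dim) -> (BLOCK_M, BLOCK_N, num_warps, num_stages);
    # see the unpacking in the autotuning loops below.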
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (32, 64, 4, 3),
    (torch.float32, 256): (32, 32, 4, 3),
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (64, 32, 4, 3),
    (torch.bfloat16, 256): (64, 32, 4, 3),
}

_a100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (128, 32, 4, 3),
    (torch.float32, 256): (64, 16, 4, 3),
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (128, 32, 4, 3),
    (torch.bfloat16, 256): (32, 64, 4, 3),
}


def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
    dtype = query.get_dtype()
    head_dim = query.get_size()[-1]
    default_config = None

    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _h100_default_config.get((dtype, head_dim), default_config)
    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _a100_default_config.get((dtype, head_dim), default_config)
    else:  # modest hardware or extremely large head_dim
        if dtype == torch.float32:
            default_config = (32, 16, 4, 3)
        else:
            default_config = (64, 32, 4, 3)

    return default_config


def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
    head_dim = query.get_size()[-1]
    dtype = query.get_dtype()

    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if dtype == torch.float32:
            return (64, 64, 4, 1)
        return (128, 128, 4, 3)
    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        return (32, 32, 4, 1)
    else:  # modest hardware or extremely large head_dim
        return (16, 16, 4, 1)


# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(*args, **kwargs):
    query, key, value, subgraph, *other_buffers = args
    for buf in [query, key, value]:
        buf.realize()
    placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("score", query.get_dtype()),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
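    # These placeholders mirror the scalar arguments of the user's score_mod
    # subgraph, which is traced with a signature along the lines of
    #   score_mod(score, b, h, m, n) -> modified_score
    # where b/h/m/n are the batch, head, query-index, and key/value-index scalars.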
    subgraph_buffer = build_subgraph_buffer(
        args, placeholder_inps, subgraph, graph_type=SubgraphType.FWD
    )
    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        query.get_size(),
        FlexibleLayout.contiguous_strides(query.get_size()),
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = query.get_size()[:-1]  # [B, H, M]
    logsumexp = empty_strided(
        logsumexp_shape,
        None,
        dtype=torch.float32,  # The logsumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_fwd(query))
    if config.max_autotune:
        configs += [
            (128, 64, 4, 3),
            (128, 128, 4, 3),
            (128, 128, 8, 2),
            (64, 128, 4, 3),
            (64, 64, 4, 3),
        ]

    # Note: we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function.
    # We do need to explicitly pass them in for autotuning, though.
    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=[query, key, value, logsumexp],
            layout=layout,
            subgraphs=[
                subgraph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
            ],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=query.get_size()[-1],
            # For now, we always assume the "sound" option
            SCORE_MOD_IS_LINEAR=False,
            ROWS_GUARANTEED_SAFE=False,
            OUTPUT_LOGSUMEXP=True,
        )
    inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
    return (
        autotune_select_algorithm(
            "flex_attention", choices, inputs_for_autotuning, layout
        ),
        logsumexp,
    )


# ---------------------------- Backward HOP Implementation ----------------------------


def flex_attention_backward_grid(
    batch_size, num_heads, num_queries, d_model, num_key_value, meta
):
    """How is this kernel parallelized?
    Currently this is only parallelizing over batch * num_heads, but we can, and want to
    parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require
    atomic updates to some grad values or to have a two pass kernel design.
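
    For example (illustrative sizes): with num_queries = num_key_value = 2048,
    BLOCK_M2 = BLOCK_N1 = 64, batch_size = 4, and num_heads = 8, the grid is
    (32 + 32, 1, 32) = (64, 1, 32).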
    """
    import triton

    return (
        triton.cdiv(num_queries, meta["BLOCK_M2"])
        + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
        1,
        batch_size * num_heads,
    )


flex_attention_backward_template = TritonTemplate(
    name="flex_attention_backward",
    grid=flex_attention_backward_grid,
    source=r"""
{{def_kernel("Q", "K", "V", "OUT", "LSE", "DELTA", "DO", "DQ", "DV")}}
    # Sub notation for this kernel:
    # Q: Query, K: Key, V: Value
    # OUT: Forward output, LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT * DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, which is written to via the store_output call due to
    # some limitations with inductor codegen
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # (Modifiable) Config options:
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
    # change of base out of the loop

    # Define Q Strides
    stride_qz = {{stride("Q", 0)}}
    stride_qh = {{stride("Q", 1)}}
    stride_qm = {{stride("Q", 2)}}
    stride_qd = {{stride("Q", 3)}}
    # Define K Strides
    stride_kz = {{stride("K", 0)}}
    stride_kh = {{stride("K", 1)}}
    stride_km = {{stride("K", 2)}}
    stride_kd = {{stride("K", 3)}}
    # Define V Strides
    stride_vz = {{stride("V", 0)}}
    stride_vh = {{stride("V", 1)}}
    stride_vm = {{stride("V", 2)}}
    stride_vd = {{stride("V", 3)}}

    Z = {{size("Q", 0)}}
    H = {{size("Q", 1)}}
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0)
    NUM_KV_BLOCKS = KV_LEN // BLOCK_N1
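    # Programs with pid < NUM_KV_BLOCKS each handle one K/V block (computing dK/dV);
    # the remaining programs each handle one Q block (computing dQ).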

    off_hz = tl.program_id(2)
    off_z = off_hz // H # batch idx
    off_h = off_hz % H # head idx

    off_chz = (off_hz * Q_LEN).to(tl.int64)
    q_adj = (stride_qh * off_h + stride_qz * off_z).to(tl.int64)
    k_adj = (stride_kh * off_h + stride_kz * off_z).to(tl.int64)
    v_adj = (stride_vh * off_h + stride_vz * off_z).to(tl.int64)

    # offset pointers for batch/head
    Q += q_adj
    K += k_adj
    V += v_adj
    DO += q_adj
    DQ += q_adj
    DV += v_adj
    LSE += off_chz
    DELTA += off_chz

    offs_k = tl.arange(0, BLOCK_DMODEL)

    if pid >= NUM_KV_BLOCKS:
        # THIS BLOCK DOES DQ
        off_pid = pid - NUM_KV_BLOCKS
        start_m2 = off_pid * BLOCK_M2

        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
        dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
        do = tl.load(DO + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)

        lse = tl.load(LSE + offs_m2)
        lse = lse[:, None]

        start_n2 = 0
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
        offs_n2 = start_n2 + tl.arange(0, BLOCK_N2)
        kT_ptrs = K + offs_n2[None, :] * stride_km + offs_k[:, None] * stride_kd
        vT_ptrs = V + offs_n2[None, :] * stride_vm + offs_k[:, None] * stride_vd
        Di = tl.load(DELTA + offs_m2)
        # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
        tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

        curr_n = start_n2
        num_steps = KV_LEN // BLOCK_N2
        for blk_idx in range(num_steps):
            offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
            kT = tl.load(kT_ptrs)
            vT = tl.load(vT_ptrs)
            qk = tl.dot(q, kT)
            # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
            pre_mod_scores = qk
            m = offs_m2[:, None]
            n = offs_n2[None, :]
            {{ modification(
                subgraph_number=0,
                output_name="post_mod_scores",
                score="qk",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
                out="qk"
            ) | indent_except_first(3) }}
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            if not SCORE_MOD_IS_LINEAR:
                post_mod_scores *= 1.44269504
            p = tl.math.exp2(post_mod_scores - lse).to(MATMUL_PRECISION)
            # Compute dP and dS.
            dp = tl.dot(do, vT)
            ds = p * (dp - Di[:, None])
            # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
            {{ modification(
                subgraph_number=1,
                output_name = "grad_scores",
                score="pre_mod_scores",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
                grad_score_mod="ds"
            ) | indent_except_first(3) }}
            ds = grad_scores
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            ds = ds.to(MATMUL_PRECISION)
            # Compute dQ.
            dq += tl.dot(ds, tl.trans(kT))
            # Increment pointers.
            curr_n += BLOCK_N2
            kT_ptrs += BLOCK_N2 * stride_km
            vT_ptrs += BLOCK_N2 * stride_vm
        # Write back dQ.
        dq_ptrs = DQ + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd
        tl.store(dq_ptrs, dq)
    else:
        # THIS BLOCK DOES DK & DV
        start_n1 = pid * BLOCK_N1
        start_m1 = 0

        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = tl.load(K + offs_n1[:, None] * stride_km + offs_k[None, :] * stride_kd)
        v = tl.load(V + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd)

        offs_m1 = start_m1 + tl.arange(0, BLOCK_M1)
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
        qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
        do_ptrs = DO + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd
        # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
        tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

        curr_m = start_m1
        num_steps = Q_LEN // BLOCK_M1
        for blk_idx in range(num_steps):
            qT = tl.load(qT_ptrs)
            # Load LSE before computing qk to reduce pipeline stall.
            offs_m1 = curr_m + tl.arange(0, BLOCK_M1)
            lse = tl.load(LSE + offs_m1)
            qkT = tl.dot(k, qT)
            # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
            m = offs_m1[None, :]
            n = offs_n1[:, None]
            pre_mod_scores = qkT
            {{ modification(
                subgraph_number=0,
                output_name="post_mod_scores",
                score="qkT",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
                out="qkT"
            ) | indent_except_first(3) }}
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            if not SCORE_MOD_IS_LINEAR:
                post_mod_scores *= 1.44269504
            pT = tl.math.exp2(post_mod_scores - lse[None, :])
            do = tl.load(do_ptrs)
            # Compute dV.
            ppT = pT
            dv += tl.dot(ppT.to(MATMUL_PRECISION), do)
            Di = tl.load(DELTA + offs_m1)
            # Compute dP and dS.
            dpT = tl.dot(v, tl.trans(do))
            dsT = pT * (dpT - Di[None, :])
            # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
            m = offs_m1[None, :]
            n = offs_n1[:, None]
            {{ modification(
                subgraph_number=1,
                output_name = "grad_scores",
                score="pre_mod_scores",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
                grad_score_mod="dsT"
            ) | indent_except_first(3) }}
            dsT = grad_scores
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
            # Increment pointers.
            curr_m += BLOCK_M1
            qT_ptrs += BLOCK_M1 * stride_qm
            do_ptrs += BLOCK_M1 * stride_qm

        dv_ptrs = DV + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd
        tl.store(dv_ptrs, dv)

        # Write back dK.
        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        # TODO generalize and add proper mask support
        mask = (index_n != -1) & (index_k != -1)
        {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
 """,
)


# TODO: We probably also need a layout constraint?
@register_lowering(
    torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        fw_graph,
        joint_graph,
        *other_buffers,
    ) = args
    for buf in [query, key, value, grad_out]:
        buf.realize()

    device = query.get_device()
    dtype = query.get_dtype()

    fwd_placeholder_inps = [
        create_placeholder(name, dtype, device)
        for name, dtype in [
            ("score", dtype),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    fw_subgraph_buffer = build_subgraph_buffer(
        args, fwd_placeholder_inps, fw_graph, graph_type=SubgraphType.JOINT_FWD
    )

    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
    ]
    joint_subgraph_buffer = build_subgraph_buffer(
        args, joint_placeholder_inps, joint_graph, graph_type=SubgraphType.JOINT_BWD
    )

    layout_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key.get_size(),
        FlexibleLayout.contiguous_strides(key.get_size()),
    )

    # Create delta, which is needed for the bwd kernel
    mul_delta = lowerings[aten.mul](out, grad_out)
    delta = lowerings[aten.sum](mul_delta, axis=-1)
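    # i.e. delta[b, h, m] = sum_d(out[b, h, m, d] * grad_out[b, h, m, d]),
    # which is the DELTA input expected by the backward template.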

    # see NOTE:[TritonTemplates with multiple outputs]
    grad_query = full(
        query.get_size(), 0.0, dtype=dtype, device=device
    )  # torch.zeros equivalent
    grad_query.realize()
    grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device)

    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_bwd(query))
    if config.max_autotune:
        configs += [
            (128, 128, 4, 3),
            (128, 128, 8, 1),
            (64, 64, 4, 3),
            (64, 64, 8, 1),
        ]

    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                out,
                logsumexp,
                delta,
                grad_out,
                grad_query,
                grad_value,
            ],
            layout=layout_k,  # We use store_output only for grad_key
            subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer],
            mutated_inputs=[grad_query, grad_value],
            call_sizes=query.get_size() + [key.get_size()[2]],
            num_stages=num_stages,
            num_warps=num_warps,
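            # The dK/dV pass tiles the problem as (BLOCK_M1 x BLOCK_N1) and the dQ
            # pass as (BLOCK_M2 x BLOCK_N2); the autotuned (BLOCK_M, BLOCK_N) pair
            # is reused in swapped order for the dQ pass.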
            BLOCK_M1=BLOCK_M,
            BLOCK_N1=BLOCK_N,
            BLOCK_M2=BLOCK_N,
            BLOCK_N2=BLOCK_M,
            BLOCK_DMODEL=query.get_size()[-1],
            # For now, we always assume the "sound" option
            SCORE_MOD_IS_LINEAR=False,
        )
    inputs_for_autotuning = [
        query,
        key,
        value,
        out,
        logsumexp,
        delta,
        grad_out,
        grad_query,
        grad_value,
    ] + list(other_buffers)

    grad_key = autotune_select_algorithm(
        "flex_attention_backward", choices, inputs_for_autotuning, layout_k
    )
    return (
        grad_query,
        grad_key,
        grad_value,
    )
