# mypy: allow-untyped-defs
import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import standard_kwargs, forward_helper, set_grad_sample_if_exists

from typing import List, Optional

@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs)
        if len(expanded_args[0].shape) == 1:
            raise RuntimeError(f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}")
        output = forward_helper(F.embedding, expanded_args, expanded_kwargs)
        ctx.input, ctx.weight = expanded_args
        ctx.padding_idx, ctx.scale_grad_by_freq = expanded_kwargs['padding_idx'], expanded_kwargs['scale_grad_by_freq']
        ctx.sparse = expanded_kwargs['sparse']
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.input, ctx.weight
        padding_idx, scale_grad_by_freq, sparse = ctx.padding_idx, ctx.scale_grad_by_freq, ctx.sparse

        def weight_per_sample_grad(weight):
            batch_size = input.shape[0]
            embedding_dim = weight.shape[1]
            index = (
                input.unsqueeze(-1)
                .expand(*input.shape, embedding_dim)
                .reshape(batch_size, -1, embedding_dim)
            )
            grad_sample = torch.zeros(
                batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
            )
            return grad_sample.scatter_add_(1, index, grad_output.reshape(batch_size, -1, embedding_dim))

        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference

        if input.requires_grad:
            bw_fn = torch.ops.aten.embedding_backward
            results.append(bw_fn(grad_output, input, weight.shape[0], padding_idx, scale_grad_by_freq, sparse))
        else:
            results.append(None)

        # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward)
        results = results + [None] * 6

        # set grad_sample field for weight with per sample gradients
        set_grad_sample_if_exists(weight, weight_per_sample_grad)
        return tuple(results)
