from typing import Union, Optional
from torch import Tensor

import math

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import MistralForCausalLM, LlamaForCausalLM, NemotronForCausalLM

from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.mistral.modeling_mistral import MistralMLP
from transformers.models.nemotron.modeling_nemotron import NemotronMLP

# Device on which the tensors and parameters created in this file are allocated;
# also passed as `device_map` when the pretrained models are loaded.
DEVICE = 'cuda'
# Floating-point dtype of the key parameters, masks and flippers created in this file;
# also passed as `torch_dtype` when the pretrained models are loaded.
DTYPE = torch.bfloat16


def tanh_with_cut_off(x: Tensor, cut_off: float) -> Tensor:
    """Apply tanh element-wise, saturating to exactly +/-1 outside `[-cut_off, cut_off]`.

    For every element: returns ``tanh(x)`` when ``abs(x) <= cut_off``, ``+1`` when
    ``x > cut_off`` and ``-1`` when ``x < -cut_off``.  Saturated elements are
    constants, so they receive a zero gradient.

    Args:
        x: Input tensor (any shape, any device).
        cut_off: Threshold beyond which the output is clamped to the sign of `x`.

    Returns:
        A tensor with the dtype and device of ``x.tanh()``.
    """
    fixed_pos = x > cut_off
    fixed_neg = x < -cut_off
    in_between = ~(fixed_pos | fixed_neg)
    soft = x.tanh()
    # Build the +/-1 plateau from the boolean masks in the dtype/device of the
    # input itself, instead of relying on a module-level dtype/device that may
    # not match `x`.
    plateau = fixed_pos.to(soft.dtype) - fixed_neg.to(soft.dtype)
    return plateau + soft * in_between

class LearnableKey(nn.Module):
    """Learnable sign key: a real-valued vector that is squashed into [-1, 1].

    Attributes:
        key: Trainable parameter of shape (key_size,), initialised to zeros.
        progress: Scalar read by `forward` in soft mode; the saturation
            threshold is ``max_cut_off * (1 - progress)``.
        binary: When True, `forward` returns hard +/-1 signs instead of the
            soft tanh relaxation.
        max_cut_off: Saturation threshold used when ``progress == 0``.
    """
    progress: float
    binary: bool
    max_cut_off: float

    def __init__(self, key_size: int, max_cut_off: float):
        super().__init__()
        # Assigning an nn.Parameter to an attribute already registers it on the
        # module, so no explicit register_parameter call is needed.
        self.key = torch.nn.Parameter(torch.zeros(key_size, dtype=DTYPE, device=DEVICE))
        self.max_cut_off = max_cut_off
        # Give the mode switches defined defaults so that forward() works even
        # if the caller never sets them explicitly.
        self.progress = 0.0
        self.binary = False

    def set_key(self, key: Tensor):
        "Overwrites the stored key values in place with those of `key`"
        self.key.data[:] = key.data

    def forward(self) -> Tensor:
        "Returns a tensor of shape (key_size,) with values between -1 and 1"
        if self.binary:
            # Hard signs: +1 where key > 0, -1 otherwise (zero maps to -1).
            return 2 * (self.key > 0) - 1
        # The threshold shrinks linearly as `progress` grows.
        cut_off = self.max_cut_off * (1 - self.progress)
        return tanh_with_cut_off(self.key, cut_off)

    def get_key_bits(self) -> list[bool]:
        "Returns key bits as a list of booleans (True where the key value is <= 0)"
        return (self.key <= 0).tolist()
    

def row_softmax_with_cut_off(x: Tensor, prob_cut_off: float, way_mask: Tensor) -> Tensor:
    """Masked softmax along dim 0 that snaps near-deterministic slices to one-hot.

    The softmax is taken over dim 0 (for a 2-D `x`, each column is normalised
    independently), multiplied by `way_mask`, and re-normalised along dim 0.
    Every dim-0 slice whose largest probability exceeds `prob_cut_off` is then
    replaced by the hard indicator ``y > prob_cut_off``; the other slices keep
    their soft probabilities.

    Args:
        x: Logits; normalisation runs over dim 0.
        prob_cut_off: Probability above which a slice is snapped to one-hot.
        way_mask: Multiplicative mask broadcastable to `x`; zero entries are
            excluded from the distribution.

    Returns:
        Tensor shaped like `x` holding soft probabilities or hard 0/1 values.
    """
    # Calculation is correct only when assuming prob_cut_off is larger than 0.5 (typically close to 1.0)
    y_raw = x.softmax(dim=0) * way_mask
    # Re-normalize after masking; the epsilon avoids 0/0 for fully masked slices.
    y = y_raw / (y_raw.sum(dim=0, keepdim=True) + 1e-6)
    # keepdim=True keeps the reduced axis, so the condition broadcasts along
    # the same dim 0 the maximum was taken over (not across the other axis).
    slice_max = y.max(dim=0, keepdim=True).values
    one_hot_condition = slice_max > prob_cut_off
    one_hot_value = (y > prob_cut_off).to(y.dtype)
    return torch.where(one_hot_condition, one_hot_value, y)

def get_way_mask(size: int, way: int) -> Tensor:
    """Returns a (size, size) block-diagonal mask with (way, way) blocks of 1s.

    Entries inside the diagonal blocks are 1 and all others are 0.  When
    `size` is not a multiple of `way`, the leftover trailing indices form one
    final, smaller diagonal block instead of being left entirely zero.

    Raises:
        ValueError: If `way` is not a positive integer.
    """
    if way <= 0:
        raise ValueError(f"way must be a positive integer, got {way}")
    mask = torch.zeros(size, size, dtype=DTYPE, device=DEVICE)
    for start in range(0, size, way):
        stop = start + way  # slicing clamps at `size` for the last partial block
        mask[start:stop, start:stop] = 1
    return mask

class LearnablePermutationalKey(nn.Module):
    """Learnable soft permutation: a square logit matrix relaxed via a masked softmax.

    Attributes:
        key_matrix: Trainable (key_size, key_size) logits, initialised to zeros.
        progress: Scalar read by `forward`; the logits are scaled by
            ``1 + 10 * progress`` before the softmax.
        binary: Hard mode switch; hard mode is not implemented.
        cut_off: Probability threshold handed to `row_softmax_with_cut_off`.
        way_mask: Block-diagonal mask built by `get_way_mask`; a single full
            block when `way` is None.
    """
    progress: float
    binary: bool
    cut_off: float
    way_mask: Tensor

    def __init__(self, key_size: int, cut_off: float, way: Optional[int] = None):
        super().__init__()
        # Assigning an nn.Parameter to an attribute already registers it on the
        # module, so no explicit register_parameter call is needed.
        self.key_matrix = torch.nn.Parameter(torch.zeros(key_size, key_size, dtype=DTYPE, device=DEVICE))
        self.cut_off = cut_off
        # Defined default so that forward() works before any caller sets it.
        self.progress = 0.0
        self.binary = False
        self.way_mask = get_way_mask(key_size, key_size if way is None else way)

    def set_key(self, key_matrix: Tensor):
        "Overwrites the stored logits in place with those of `key_matrix`"
        self.key_matrix.data[:, :] = key_matrix.data

    def forward(self) -> Tensor:
        "Returns a tensor of shape (key_size, key_size) with values between 0 and 1"
        if self.binary:
            raise NotImplementedError("Binary mode is not implemented for LearnablePermutationalKey")
        # Larger progress sharpens the softmax.
        scaler = 1 + 10 * self.progress  # assumes progress is between [0, 1)
        return row_softmax_with_cut_off(self.key_matrix * scaler, self.cut_off, self.way_mask)

    def get_permutation_key(self) -> list[int]:
        "Returns permutation key as a list of integers (argmax along dim 0 of the unscaled soft permutation)"
        soft_permutation = row_softmax_with_cut_off(self.key_matrix, self.cut_off, self.way_mask)
        return soft_permutation.argmax(dim=0).tolist()


class LockableMLP(nn.Module):
    """Wrapper around a Llama/Mistral/Nemotron MLP whose `down_proj` can be locked with a key.

    Before `lock` is called the module computes the same function as the
    wrapped MLP.  `lock` rewrites the projection weights in place: the chosen
    intermediate channels are moved to the front, and the matching `down_proj`
    columns are rotated, sign-flipped and optionally permuted.  After locking,
    `forward` applies a rotation, a flipper and (optionally) a permutation to
    those channels before `down_proj`; the flipper and permutation come either
    from the learnable key modules or from runtime keys set by the caller.
    """
    keeps_last_tensors: bool
    last_tensor_intermediate: Optional[Tensor]  # the intermediate result right before `down_proj`
    last_tensor_output: Optional[Tensor]  # the output of `down_proj`

    key_module_enabled: bool
    key_module: Optional[LearnableKey]

    permutation_key_module_enabled: bool
    permutation_key_module: Optional[LearnablePermutationalKey]
    
    locked: bool  # currently, locking is irreversible: you cannot "undo the lock"
    locked_indices: Optional[list[int]]
    rotation_matrix: Optional[Tensor]
    correct_key: Optional[list[bool]]
    correct_permutation_key: Optional[list[int]]

    runtime_flipper: Optional[Tensor]  # effective only when `key_module_enabled` is False
    runtime_permutation: Optional[Tensor]  # effective only when `permutation_key_module_enabled` is False

    has_gate_proj: bool

    def __init__(self, source: Union[LlamaMLP, MistralMLP, NemotronMLP]):
        """Wraps `source`, sharing (not copying) its projection layers.

        Raises:
            ValueError: If `source` is not exactly a LlamaMLP, MistralMLP or NemotronMLP.
        """
        super().__init__()
        # Llama/Mistral MLPs are gated (gate_proj + up_proj); Nemotron has up_proj only.
        if source.__class__ in [LlamaMLP, MistralMLP]:
            self.has_gate_proj = True
        elif source.__class__ in [NemotronMLP]:
            self.has_gate_proj = False
        else:
            raise ValueError("Unsupported MLP type")

        self.config = source.config
        self.hidden_size = source.hidden_size
        self.intermediate_size = source.intermediate_size
        if self.has_gate_proj:
            self.gate_proj = source.gate_proj
        self.up_proj = source.up_proj
        self.down_proj = source.down_proj
        self.act_fn = source.act_fn

        # additional fields:
        self.keeps_last_tensors = False
        self.last_tensor_intermediate = None
        self.last_tensor_output = None
        self.key_module_enabled = False
        self.key_module = None
        self.locked = False
        self.locked_indices = None
        self.rotation_matrix = None
        self.correct_key = None
        self.runtime_flipper = None

        self.permutation_key_module_enabled = False
        self.permutation_key_module = None
        self.correct_permutation_key = None
        self.runtime_permutation = None
    
    def lock(self,
        locked_indices: list[int], key: list[bool], rotation_matrix: Tensor, permutation_key: Optional[list[int]] = None,
        way: Optional[int] = None,
    ):
        """Locks the MLP in place and attaches fresh learnable key modules.

        Args:
            locked_indices: Distinct intermediate-channel indices to lock; they
                are moved to the first ``n = len(key)`` positions.
            key: One bit per locked channel; True flips the sign of the
                corresponding `down_proj` column.
            rotation_matrix: (n, n) matrix right-multiplied onto the first n
                columns of `down_proj`.  NOTE(review): presumably orthogonal,
                so that the rotation applied in `forward` cancels it - confirm
                with the caller.
            permutation_key: Optional permutation of ``range(n)`` applied to the
                first n columns of `down_proj` after rotation and flipping.
            way: Block size forwarded to `LearnablePermutationalKey`; only used
                when `permutation_key` is given.
        """
        m, n = self.intermediate_size, len(key)
        assert not self.locked, "Model is already locked"
        assert len(locked_indices) == n, "Key length must match locked indices length"
        assert len(set(locked_indices)) == n, "Locked indices must be unique"
        assert all(0 <= i < m for i in locked_indices), "Locked indices must be within range of intermediate size"
        assert rotation_matrix.shape == (n, n), "Rotation matrix must be square and match key length"
        if permutation_key is not None:
            assert len(permutation_key) == n, "Permutation key length must match locked indices length"
            # Duplicate or out-of-range entries would silently duplicate or drop
            # down_proj columns below, which could never be undone by any key.
            assert sorted(permutation_key) == list(range(n)), "Permutation key must be a permutation of range(n)"

        self.locked = True
        self.locked_indices = locked_indices
        self.rotation_matrix = rotation_matrix
        self.correct_key = key

        self.key_module_enabled = True
        self.key_module = LearnableKey(key_size=n, max_cut_off=4.0)
        self.key_module.progress = 0.0
        self.key_module.binary = False

        if permutation_key is not None:
            self.correct_permutation_key = permutation_key
            self.permutation_key_module_enabled = True
            self.permutation_key_module = LearnablePermutationalKey(key_size=n, cut_off=0.9999, way=way)
            # Initialise the schedule state like the sign-key module above, so
            # forward() works right after locking.
            self.permutation_key_module.progress = 0.0
            self.permutation_key_module.binary = False

        # apply permutation: locked channels first, remaining channels in their original order
        skip = set(locked_indices)
        permutation = list(locked_indices)
        permutation.extend(i for i in range(m) if i not in skip)
        permutation = torch.tensor(permutation, dtype=torch.long, device=DEVICE)
        if self.has_gate_proj:
            self.gate_proj.weight.data[:, :] = self.gate_proj.weight.data[permutation, :]
        self.up_proj.weight.data[:, :] = self.up_proj.weight.data[permutation, :]
        self.down_proj.weight.data[:, :] = self.down_proj.weight.data[:, permutation]

        # apply rotation:
        self.down_proj.weight.data[:, :n] = self.down_proj.weight.data[:, :n] @ rotation_matrix

        # apply flipping: a True bit negates the column
        flipper = [-1 if bit else 1 for bit in key]
        flipper = torch.tensor(flipper, dtype=DTYPE, device=DEVICE)
        self.down_proj.weight.data[:, :n] = self.down_proj.weight.data[:, :n] * flipper

        # apply extra permutation:
        if permutation_key is not None:
            extra_permutation = torch.tensor(permutation_key, dtype=torch.long, device=DEVICE)
            self.down_proj.weight.data[:, :n] = self.down_proj.weight.data[:, extra_permutation]

    def set_runtime_key(self, key: list[bool]):
        "Stores a fixed +/-1 flipper (True -> -1) used by `forward` when the key module is disabled"
        flipper = [-1 if bit else 1 for bit in key]
        self.runtime_flipper = torch.tensor(flipper, dtype=DTYPE, device=DEVICE)

    def set_runtime_permutation(self, permutation: list[int]):
        "Stores a fixed index permutation used by `forward` when the permutation key module is disabled"
        self.runtime_permutation = torch.tensor(permutation, dtype=torch.long, device=DEVICE)

    def forward(self, x: Tensor) -> Tensor:
        """Runs the MLP; when locked, applies rotation, flipper and permutation before `down_proj`.

        Only the last dimension of the intermediate activation is indexed, so
        any number of leading batch dimensions is accepted.
        """
        if self.has_gate_proj:
            intermediate = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        else:
            intermediate = self.act_fn(self.up_proj(x))

        self.last_tensor_intermediate = intermediate if self.keeps_last_tensors else None

        if self.locked:
            if self.key_module_enabled:
                flipper = self.key_module()
            else:
                assert self.runtime_flipper is not None, "Please firstly set a runtime key"
                flipper = self.runtime_flipper
            n = self.rotation_matrix.shape[0]
            # Only the first n (locked) channels are transformed.
            rotated_intermediate = (intermediate[..., :n] @ self.rotation_matrix) * flipper

            if self.correct_permutation_key is not None:
                if self.permutation_key_module_enabled:
                    # Soft permutation: a matrix product with the learnable (n, n) matrix.
                    permutation = self.permutation_key_module()
                    rotated_permuted_intermediate = rotated_intermediate @ permutation
                else:
                    assert self.runtime_permutation is not None, "Please firstly set a runtime permutation key"
                    rotated_permuted_intermediate = rotated_intermediate[..., self.runtime_permutation]
            else:
                rotated_permuted_intermediate = rotated_intermediate
            final_intermediate = torch.concat([rotated_permuted_intermediate, intermediate[..., n:]], dim=-1)
        else:
            final_intermediate = intermediate

        output = self.down_proj(final_intermediate)
        self.last_tensor_output = output if self.keeps_last_tensors else None
        return output


class LlamaMinitron4B_Adapter:
    """Holds the frozen `nvidia/Llama-3.1-Minitron-4B-Width-Base` model and its tokenizer."""
    model: LlamaForCausalLM
    tokenizer: AutoTokenizer
    model_id: str

    def __init__(self):
        checkpoint = "nvidia/Llama-3.1-Minitron-4B-Width-Base"
        self.model_id = checkpoint
        self.model = LlamaForCausalLM.from_pretrained(checkpoint, torch_dtype=DTYPE, device_map=DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Freeze every pretrained weight of the loaded model.
        self.model.requires_grad_(False)

    def get_mlp_list(self) -> list[Union[LlamaMLP, LockableMLP]]:
        "Returns the MLP of every decoder layer, in layer order"
        mlps = []
        for decoder_layer in self.model.model.layers:
            mlps.append(decoder_layer.mlp)
        return mlps

    def replace_nth_mlp_into_lockable_mlp(self, n: int):
        "Wraps the MLP of layer `n` in a LockableMLP, in place"
        layers = self.model.model.layers
        layers[n].mlp = LockableMLP(layers[n].mlp)


class MistralNeMo8B_Adapter:
    """Holds the frozen `nvidia/Mistral-NeMo-Minitron-8B-Base` model and its tokenizer."""
    model: MistralForCausalLM
    tokenizer: AutoTokenizer
    model_id: str

    def __init__(self):
        checkpoint = "nvidia/Mistral-NeMo-Minitron-8B-Base"
        self.model_id = checkpoint
        self.model = MistralForCausalLM.from_pretrained(checkpoint, torch_dtype=DTYPE, device_map=DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Freeze every pretrained weight of the loaded model.
        self.model.requires_grad_(False)

    def get_mlp_list(self) -> list[Union[MistralMLP, LockableMLP]]:
        "Returns the MLP of every decoder layer, in layer order"
        mlps = []
        for decoder_layer in self.model.model.layers:
            mlps.append(decoder_layer.mlp)
        return mlps

    def replace_nth_mlp_into_lockable_mlp(self, n: int):
        "Wraps the MLP of layer `n` in a LockableMLP, in place"
        layers = self.model.model.layers
        layers[n].mlp = LockableMLP(layers[n].mlp)


class Minitron4B_Adapter:
    """Holds the frozen `nvidia/Minitron-4B-Base` model and its tokenizer."""
    model: NemotronForCausalLM
    tokenizer: AutoTokenizer
    model_id: str

    def __init__(self):
        checkpoint = "nvidia/Minitron-4B-Base"
        self.model_id = checkpoint
        self.model = NemotronForCausalLM.from_pretrained(checkpoint, torch_dtype=DTYPE, device_map=DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Freeze every pretrained weight of the loaded model.
        self.model.requires_grad_(False)

    def get_mlp_list(self) -> list[Union[NemotronMLP, LockableMLP]]:
        "Returns the MLP of every decoder layer, in layer order"
        mlps = []
        for decoder_layer in self.model.model.layers:
            mlps.append(decoder_layer.mlp)
        return mlps

    def replace_nth_mlp_into_lockable_mlp(self, n: int):
        "Wraps the MLP of layer `n` in a LockableMLP, in place"
        layers = self.model.model.layers
        layers[n].mlp = LockableMLP(layers[n].mlp)


class Minitron8B_Adapter:
    """Holds the frozen `nvidia/Minitron-8B-Base` model and its tokenizer."""
    model: NemotronForCausalLM
    tokenizer: AutoTokenizer
    model_id: str

    def __init__(self):
        checkpoint = "nvidia/Minitron-8B-Base"
        self.model_id = checkpoint
        self.model = NemotronForCausalLM.from_pretrained(checkpoint, torch_dtype=DTYPE, device_map=DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Freeze every pretrained weight of the loaded model.
        self.model.requires_grad_(False)

    def get_mlp_list(self) -> list[Union[NemotronMLP, LockableMLP]]:
        "Returns the MLP of every decoder layer, in layer order"
        mlps = []
        for decoder_layer in self.model.model.layers:
            mlps.append(decoder_layer.mlp)
        return mlps

    def replace_nth_mlp_into_lockable_mlp(self, n: int):
        "Wraps the MLP of layer `n` in a LockableMLP, in place"
        layers = self.model.model.layers
        layers[n].mlp = LockableMLP(layers[n].mlp)


# Type alias covering every model adapter class defined in this file.
Adapter = Union[
    LlamaMinitron4B_Adapter,
    MistralNeMo8B_Adapter,
    Minitron4B_Adapter,
    Minitron8B_Adapter,
]
