import os, re


# Language name -> list of file extensions (written without the leading dot)
# associated with that language.
# NOTE(review): "cc" is listed under C, although it is more commonly used for
# C++ sources — confirm this grouping is intended.
LANGUAGE_EXT_MAP = {
    "C": ["c", "h", "cc"],
    "C++": ["cpp", "hpp", "cxx", "hxx", "c++", "h++"],
    "Java": ["java"],
    "Python": ["py", "pyx"],
}


_DIFF_GIT_PREFIX = "diff --git "
# Fallback matcher for headers whose two paths contain no whitespace.
_DIFF_GIT_HEADER_RE = re.compile(r"diff --git (\S+) (\S+)")


def _parse_diff_git_header(line: str):
    """Return ``(path_a, path_b)`` if *line* is a ``diff --git`` header, else ``None``.

    The common header form repeats the same path on both sides
    (``a/<path> b/<path>``, or ``<path> <path>`` when no prefixes are used).
    That form is detected by splitting the remainder of the line exactly in
    the middle, which keeps working when the path itself contains spaces.
    Any other header (e.g. a rename) falls back to splitting on whitespace,
    so renamed paths that contain spaces are still not parsed correctly.
    """
    if not line.startswith(_DIFF_GIT_PREFIX):
        return None

    rest = line[len(_DIFF_GIT_PREFIX):]
    half, is_odd = divmod(len(rest), 2)
    if is_odd and rest[half] == " ":
        left, right = rest[:half], rest[half + 1:]
        same_path = left == right or (
            left.startswith("a/")
            and right.startswith("b/")
            and left[2:] == right[2:]
        )
        # Reject empty or whitespace-padded halves so stray spaces in a
        # malformed header are never reported as part of a path.
        if left and same_path and left == left.strip():
            return left, right

    match = _DIFF_GIT_HEADER_RE.match(line)
    return match.groups() if match else None


def split_to_file_diff(diff_code: str, preserve_ext_list: list = None) -> list:
    """Split the diff of one commit into one diff per changed file.

    Args:
        diff_code: Text containing one or more sections, each introduced by
            a ``diff --git <a> <b>`` line. Lines that appear before the
            first such header are ignored.
        preserve_ext_list: File extensions to keep, written without the
            leading dot (e.g. ``['c', 'h']``). A section is kept only when
            the extensions of BOTH paths are in this list. ``None`` or an
            empty list disables filtering.

    Returns:
        A list of ``(file_a, file_b, diff_text)`` tuples in input order.
        ``diff_text`` starts with the ``diff --git`` line and has its lines
        joined with ``"\\n"`` (no trailing newline).
    """
    if not diff_code:
        return []

    allowed_exts = {f".{ext}" for ext in preserve_ext_list or ()}

    # Group the lines by the header that introduces them.
    sections = []
    for line in diff_code.splitlines():
        paths = _parse_diff_git_header(line)
        if paths is not None:
            sections.append((paths[0], paths[1], [line]))
        elif sections:
            sections[-1][2].append(line)

    def has_allowed_ext(path: str) -> bool:
        return os.path.splitext(path)[1] in allowed_exts

    return [
        (file_a, file_b, "\n".join(lines))
        for file_a, file_b, lines in sections
        if not allowed_exts
        or (has_allowed_ext(file_a) and has_allowed_ext(file_b))
    ]



# A hunk begins on a line that starts with "@@ ... @@" and extends up to (but
# not including) the next line that starts with "@@", or to the end of input.
# Both ends are anchored to a line start so that "@@" occurring inside the
# changed content does not cut a hunk short.
_HUNK_RE = re.compile(r"^@@.*?@@[\s\S]*?(?=^@@|\Z)", re.MULTILINE)


def split_to_section(file_diff: str) -> list:
    """Split the diff of a single file into its individual hunks.

    Args:
        file_diff: Diff text for one file. Everything before the first hunk
            header (a line starting with ``@@ ... @@``) is ignored.

    Returns:
        A list of hunk strings in input order. Each string starts with its
        ``@@ ... @@`` header line and contains all following text up to the
        next hunk header (including the line break that precedes it) or the
        end of the input. Returns an empty list when there are no hunks.
    """
    if not file_diff:
        return []
    return [match.group(0) for match in _HUNK_RE.finditer(file_diff)]



if __name__ == "__main__":
    # Demo / smoke test: split a sample commit diff into per-file diffs, then
    # split the first file's diff into hunks, printing every result.
    #
    # The sample is a raw string so that the "\n" escapes inside the C string
    # literals stay as two literal characters. In a normal string they would
    # turn into real line breaks (or, for a backslash at end of line, a line
    # continuation) and the hunks would no longer match their header counts.
    diff = \
r"""diff --git a/drivers/net/bonding/bond_main.c b/drivers/net/bonding/bond_main.c
index 71ba18efa15b..867664918715 100644
--- a/drivers/net/bonding/bond_main.c
+++ b/drivers/net/bonding/bond_main.c
@@ -1543,9 +1543,11 @@ int bond_enslave(struct net_device *bond_dev, struct net_device *slave_dev)
     bond_set_carrier(bond);
 
     if (USES_PRIMARY(bond->params.mode)) {
+        block_netpoll_tx();
         write_lock_bh(&bond->curr_slave_lock);
         bond_select_active_slave(bond);
         write_unlock_bh(&bond->curr_slave_lock);
+        unblock_netpoll_tx();
     }
 
     pr_info("%s: enslaving %s as a%s interface with a%s link.\n",
@@ -1571,10 +1573,12 @@ err_detach:
     if (bond->primary_slave == new_slave)
         bond->primary_slave = NULL;
     if (bond->curr_active_slave == new_slave) {
+        block_netpoll_tx();
         write_lock_bh(&bond->curr_slave_lock);
         bond_change_active_slave(bond, NULL);
         bond_select_active_slave(bond);
         write_unlock_bh(&bond->curr_slave_lock);
+        unblock_netpoll_tx();
     }
     slave_disable_netpoll(new_slave);
 
@@ -2864,9 +2868,12 @@ static int bond_slave_netdev_event(unsigned long event,
         pr_info("%s: Primary slave changed to %s, reselecting active slave.\n",
             bond->dev->name, bond->primary_slave ? slave_dev->name :
                                    "none");
+
+        block_netpoll_tx();
         write_lock_bh(&bond->curr_slave_lock);
         bond_select_active_slave(bond);
         write_unlock_bh(&bond->curr_slave_lock);
+        unblock_netpoll_tx();
         break;
     case NETDEV_FEAT_CHANGE:
         bond_compute_features(bond);
"""

    # Split the commit diff per file, keeping only files with a ".c" extension.
    changes = split_to_file_diff(diff, ['c'])
    for file_a, file_b, diff_content in changes:
        print(f"a: {file_a}, b: {file_b}")
        print(diff_content)
        print("=" * 50)

    # Split the first file's diff into its individual hunks.
    change_blocks = split_to_section(changes[0][2])
    for idx, block in enumerate(change_blocks):
        print(f"Change Block {idx + 1}:")
        print(block)
        print("-" * 50)