File size: 3,506 Bytes
81a794d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
from collections import defaultdict
from transformers import pipeline, AutoTokenizer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class _ModelConfigRegistry(dict):
    """Model alias -> loading configuration, with a per-key fallback.

    Looking up an alias that is not registered returns a configuration that
    uses the alias itself as both the model path and the tokenizer name.
    The fallback entry is not stored, so the registry keeps listing only the
    explicitly registered aliases.
    """

    def __missing__(self, model):
        # ``collections.defaultdict`` cannot express this: its factory is
        # called without arguments, so it never sees the missing key.
        return {
            "model_name_or_path": model,
            "tokenizer_name": model,
        }


class CweInferHelper():
    """Thin wrapper around a ``transformers`` text-classification pipeline.

    A model must be loaded with :meth:`load_model` before :meth:`infer` can
    be used to classify a patch.
    """

    # Number of highest-scoring labels requested from the classifier.
    TOP_K = 5
    # Token budget; ``infer`` truncates its input to ``MAX_LENGTH - 10`` tokens.
    MAX_LENGTH = 1024

    # Predefined aliases. Any other name resolves to a configuration that uses
    # the name itself as both model and tokenizer identifier.
    MODEL_CONFIG = _ModelConfigRegistry({
        "patchouli-cwe-UniXcoder": {
            "model_name_or_path": "./backend/model/cwe-cls/patchouli-unixcoder",
            "tokenizer_name": "microsoft/unixcoder-base-nine",
        },
    })

    # Aliases that ship with an explicit configuration.
    PREDEF_MODEL = list(MODEL_CONFIG.keys())

    def __init__(self):
        """Create a helper with no model loaded."""
        self.model = None
        self.classifier = None
        self.tokenizer = None

    def load_model(self, model):
        """Load the classifier and tokenizer for ``model``.

        Args:
            model: Either a key of ``MODEL_CONFIG`` or any other name, which
                is then passed unchanged to ``transformers`` as both the
                model and the tokenizer identifier.

        Does nothing when ``model`` is already loaded. If loading raises, the
        previously loaded model (if any) is left in place.
        """
        if model == self.model and self.classifier is not None:
            return
        logger.info("Loading CWE classify model: %s", model)
        config = self.MODEL_CONFIG[model]
        tokenizer_name = config["tokenizer_name"]
        classifier = pipeline(
            "text-classification",
            model=config["model_name_or_path"],
            tokenizer=tokenizer_name,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Commit the new state only after both loads succeeded, so a failed
        # load cannot leave ``self.model`` pointing at a model that is not
        # actually available (which would make a retry return early).
        self.classifier = classifier
        self.tokenizer = tokenizer
        self.model = model

    def infer(self, diff_code, patch_message=None):
        """Classify a patch and return the top-scoring labels.

        Args:
            diff_code: Patch text; placed after a ``[PATCH]`` marker.
            patch_message: Optional message; when non-empty it is placed
                before the patch, after a ``[MESSAGE]`` marker.

        Returns:
            A dict mapping each returned label to its score, for the
            ``TOP_K`` results requested from the classifier.

        Raises:
            ValueError: If no model has been loaded.
        """
        if self.classifier is None:
            raise ValueError("Model is not loaded")

        sections = []
        if patch_message:
            sections.append(f"[MESSAGE]\n{patch_message}\n")
        sections.append(f"[PATCH]\n{diff_code}")
        input_text = "".join(sections)

        logger.info("Classifying CWE for diff code")
        # Round-trip through the tokenizer so the text handed to the pipeline
        # is cut to MAX_LENGTH - 10 tokens.
        # NOTE(review): padding="max_length" together with decode() without
        # skip_special_tokens presumably leaves special/padding tokens in the
        # decoded text that the pipeline then tokenizes again. Kept as-is to
        # avoid changing model inputs; confirm whether this is intended.
        input_ids = self.tokenizer(
            input_text,
            max_length=self.MAX_LENGTH - 10,
            padding="max_length",
            truncation=True,
        ).input_ids
        input_text = self.tokenizer.decode(input_ids)
        predictions = self.classifier(input_text, top_k=self.TOP_K)
        return {item["label"]: item["score"] for item in predictions}


# Shared module-level instance; no model is loaded until load_model() is called.
cwe_infer_helper = CweInferHelper()

    
if __name__ == "__main__":
    # Manual smoke test: classify a sample kernel patch and print the scores.
    # The sample is a raw string so the ``\n`` escapes inside the C format
    # strings stay as written instead of turning into real line breaks that
    # would corrupt the diff.
    code = r"""diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c
index 6bde12da2fe003..c37ac2d7bec44d 100644
--- a/net/netfilter/ipvs/ip_vs_ctl.c
+++ b/net/netfilter/ipvs/ip_vs_ctl.c
@@ -2077,6 +2077,10 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user *user, unsigned int len)
 	if (!capable(CAP_NET_ADMIN))
 		return -EPERM;
 
+	if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX)
+		return -EINVAL;
+	if (len < 0 || len >  MAX_ARG_LEN)
+		return -EINVAL;
 	if (len != set_arglen[SET_CMDID(cmd)]) {
 		pr_err("set_ctl: len %u != %u\n",
 		       len, set_arglen[SET_CMDID(cmd)]);
@@ -2352,17 +2356,25 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len)
 {
 	unsigned char arg[128];
 	int ret = 0;
+	unsigned int copylen;
 
 	if (!capable(CAP_NET_ADMIN))
 		return -EPERM;
 
+	if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX)
+		return -EINVAL;
+
 	if (*len < get_arglen[GET_CMDID(cmd)]) {
 		pr_err("get_ctl: len %u < %u\n",
 		       *len, get_arglen[GET_CMDID(cmd)]);
 		return -EINVAL;
 	}
 
-	if (copy_from_user(arg, user, get_arglen[GET_CMDID(cmd)]) != 0)
+	copylen = get_arglen[GET_CMDID(cmd)];
+	if (copylen > 128)
+		return -EINVAL;
+
+	if (copy_from_user(arg, user, copylen) != 0)
 		return -EFAULT;
 
 	if (mutex_lock_interruptible(&__ip_vs_mutex))
"""
    # Use the alias that is actually registered in MODEL_CONFIG; the bare
    # name "patchouli" is not a predefined entry.
    cwe_infer_helper.load_model("patchouli-cwe-UniXcoder")
    result = cwe_infer_helper.infer(code)
    print(result)