patchouli / backend /cwe_infer_helper.py
traveler514's picture
first commit
81a794d
import logging
from collections import defaultdict
from transformers import pipeline, AutoTokenizer
logger = logging.getLogger(__name__)
class CweInferHelper:
    """Load a text-classification model and use it to score CWE labels for a patch.

    Typical use: call ``load_model(name)`` once, then ``infer(diff, message)``
    as often as needed. ``name`` is either a key of ``MODEL_CONFIG`` or any
    identifier that ``transformers`` can resolve by itself (it is then used
    for both the model and the tokenizer).
    """

    # Number of highest-scoring labels returned by ``infer``.
    TOP_K = 5
    # Token budget for the classifier input; ``infer`` truncates to
    # ``MAX_LENGTH - 10`` tokens before handing the text to the pipeline.
    MAX_LENGTH = 1024

    class _ModelConfig(dict):
        """Mapping of model name -> load settings with a per-key fallback.

        ``collections.defaultdict`` cannot express this fallback because its
        ``default_factory`` is called without arguments; ``__missing__``
        receives the requested key. The fallback entry is returned but not
        stored, so the mapping keeps listing only the predefined models.
        """

        def __missing__(self, model):
            return {
                "model_name_or_path": model,
                "tokenizer_name": model,
            }

    MODEL_CONFIG = _ModelConfig({
        "patchouli-cwe-UniXcoder": {
            "model_name_or_path": "./backend/model/cwe-cls/patchouli-unixcoder",
            "tokenizer_name": "microsoft/unixcoder-base-nine",
        },
    })
    # Names of the models that ship with an explicit configuration.
    PREDEF_MODEL = list(MODEL_CONFIG.keys())

    def __init__(self):
        # Name of the currently loaded model; None until load_model() succeeds.
        self.model = None
        # transformers pipeline used by infer(); None until a model is loaded.
        self.classifier = None
        # Tokenizer used by infer() to truncate the input text.
        self.tokenizer = None

    def load_model(self, model):
        """Load the classifier and tokenizer for ``model``.

        Does nothing when ``model`` is already the loaded model.

        Args:
            model: A key of ``MODEL_CONFIG`` or any model identifier; unknown
                names are used as both the model path and the tokenizer name.

        Raises:
            Whatever ``transformers`` raises when the model or tokenizer
            cannot be loaded. In that case the previously loaded model (if
            any) stays active and the call can simply be retried.
        """
        if model == self.model:
            return
        logger.info(f"Loading CWE classify model: {model}")

        config = self.MODEL_CONFIG[model]
        model_name_or_path = config["model_name_or_path"]
        tokenizer_name = config["tokenizer_name"]

        classifier = pipeline(
            "text-classification",
            model=model_name_or_path,
            tokenizer=tokenizer_name,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Commit the new state only after both loads succeeded. Recording
        # ``self.model`` first would make a retry after a failed load return
        # early and leave infer() running against the wrong (or no) model.
        self.classifier = classifier
        self.tokenizer = tokenizer
        self.model = model

    def infer(self, diff_code, patch_message = None):
        """Return the ``TOP_K`` label scores for a patch.

        Args:
            diff_code: The patch text; placed after a ``[PATCH]`` marker.
            patch_message: Optional commit message; when it is neither
                ``None`` nor empty it is placed first, after a ``[MESSAGE]``
                marker.

        Returns:
            A dict mapping each returned label to its score.

        Raises:
            ValueError: If no model has been loaded yet.
        """
        if self.classifier is None:
            raise ValueError("Model is not loaded")

        input_text = ""
        if patch_message is not None and patch_message != "":
            input_text += f"[MESSAGE]\n{patch_message}\n"
        input_text += f"[PATCH]\n{diff_code}"

        logger.info("Classifying CWE for diff code")
        # Truncate at token level, then turn the ids back into text because
        # the pipeline tokenizes its input again. The 10-token margin
        # presumably leaves room for what that second pass adds.
        # NOTE(review): padding="max_length" followed by a plain decode()
        # looks like it writes padding/special tokens into the text that is
        # classified. Left unchanged in case the model was trained on inputs
        # prepared this way -- confirm before simplifying.
        input_ids = self.tokenizer(
            input_text,
            max_length=CweInferHelper.MAX_LENGTH - 10,
            padding="max_length",
            truncation=True,
        ).input_ids
        input_text = self.tokenizer.decode(input_ids)

        result = self.classifier(input_text, top_k = self.TOP_K)
        result = {item["label"]: item["score"] for item in result}
        return result
# Module-level helper instance. Creating it loads nothing; a model is only
# loaded once load_model() is called on it.
cwe_infer_helper = CweInferHelper()

if __name__ == "__main__":
    # Manual smoke test: classify one sample kernel patch and print the scores.
    code = """diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c
index 6bde12da2fe003..c37ac2d7bec44d 100644
--- a/net/netfilter/ipvs/ip_vs_ctl.c
+++ b/net/netfilter/ipvs/ip_vs_ctl.c
@@ -2077,6 +2077,10 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user *user, unsigned int len)
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX)
+ return -EINVAL;
+ if (len < 0 || len > MAX_ARG_LEN)
+ return -EINVAL;
if (len != set_arglen[SET_CMDID(cmd)]) {
pr_err("set_ctl: len %u != %u\n",
len, set_arglen[SET_CMDID(cmd)]);
@@ -2352,17 +2356,25 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len)
{
unsigned char arg[128];
int ret = 0;
+ unsigned int copylen;
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX)
+ return -EINVAL;
+
if (*len < get_arglen[GET_CMDID(cmd)]) {
pr_err("get_ctl: len %u < %u\n",
*len, get_arglen[GET_CMDID(cmd)]);
return -EINVAL;
}
- if (copy_from_user(arg, user, get_arglen[GET_CMDID(cmd)]) != 0)
+ copylen = get_arglen[GET_CMDID(cmd)];
+ if (copylen > 128)
+ return -EINVAL;
+
+ if (copy_from_user(arg, user, copylen) != 0)
return -EFAULT;
if (mutex_lock_interruptible(&__ip_vs_mutex))
"""
    # Use the name that is actually registered in MODEL_CONFIG; the bare
    # "patchouli" used before is not a configured model.
    cwe_infer_helper.load_model("patchouli-cwe-UniXcoder")
    result = cwe_infer_helper.infer(code)
    print(result)