Spaces:
Running
Running
import logging | |
from collections import defaultdict | |
from transformers import pipeline, AutoTokenizer | |
logger = logging.getLogger(__name__) | |
class _ModelConfigRegistry(defaultdict):
    """Model-name -> loading-config mapping with a key-aware fallback.

    ``defaultdict`` calls its ``default_factory`` with *no* arguments, so a
    factory that needs the missing key (as the fallback here does) can never
    work through ``default_factory``. Overriding ``__missing__`` gives the
    fallback access to the requested key instead.
    """

    def __missing__(self, model):
        # Names without a predefined entry are used verbatim for both the
        # model and the tokenizer. The generated entry is intentionally not
        # stored, so the registry keeps listing only predefined models.
        return {
            "model_name_or_path": model,
            "tokenizer_name": model,
        }


class CweInferHelper():
    """Classify a patch (diff plus optional commit message) into CWE labels.

    Wraps a ``transformers`` text-classification pipeline. Call
    :meth:`load_model` once, then :meth:`infer` for each patch.
    """

    # Number of highest-scoring labels returned by ``infer``.
    TOP_K = 5
    # Token budget used when truncating the classifier input.
    MAX_LENGTH = 1024

    # Predefined aliases; any other name falls back to
    # ``{"model_name_or_path": name, "tokenizer_name": name}``.
    MODEL_CONFIG = _ModelConfigRegistry()
    MODEL_CONFIG.update({
        "patchouli-cwe-UniXcoder": {
            "model_name_or_path": "./backend/model/cwe-cls/patchouli-unixcoder",
            "tokenizer_name": "microsoft/unixcoder-base-nine",
        },
    })
    # Names of the predefined models, captured at class-creation time.
    PREDEF_MODEL = list(MODEL_CONFIG.keys())

    def __init__(self):
        self.model = None
        self.classifier = None
        self.tokenizer = None

    def load_model(self, model):
        """Load the classification pipeline and tokenizer for ``model``.

        Does nothing when ``model`` is already the active model.

        Args:
            model: A key of ``MODEL_CONFIG``, or any other name, which is then
                used as both the model and the tokenizer identifier.

        Raises:
            Whatever ``pipeline`` / ``AutoTokenizer.from_pretrained`` raise
            when loading fails. In that case the previously loaded model (if
            any) stays active and a later call can retry.
        """
        if model == self.model:
            return

        logger.info("Loading CWE classify model: %s", model)
        config = self.MODEL_CONFIG[model]
        model_name_or_path = config["model_name_or_path"]
        tokenizer_name = config["tokenizer_name"]

        classifier = pipeline(
            "text-classification",
            model=model_name_or_path,
            tokenizer=tokenizer_name,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Commit state only after everything loaded, so a failed load does not
        # mark the model as "already loaded" and block retries.
        self.classifier = classifier
        self.tokenizer = tokenizer
        self.model = model

    def infer(self, diff_code, patch_message = None):
        """Return the ``TOP_K`` CWE labels predicted for a patch.

        Args:
            diff_code: The patch text, placed under a ``[PATCH]`` header.
            patch_message: Optional commit message; when non-empty it is
                placed before the patch under a ``[MESSAGE]`` header.

        Returns:
            A dict mapping each predicted label to its score.

        Raises:
            ValueError: If no model has been loaded yet.
        """
        if self.classifier is None:
            raise ValueError("Model is not loaded")

        input_text = ""
        if patch_message:
            input_text += f"[MESSAGE]\n{patch_message}\n"
        input_text += f"[PATCH]\n{diff_code}"

        logger.info("Classifying CWE for diff code")
        # Truncate by round-tripping through the tokenizer; the 10-token
        # headroom presumably leaves room for tokens added when the pipeline
        # tokenizes the decoded text again -- TODO confirm.
        # NOTE(review): padding="max_length" combined with a plain decode()
        # keeps pad/special tokens in the text handed to the classifier.
        # Left unchanged because altering it would change model inputs;
        # verify against how the model was trained.
        input_ids = self.tokenizer(
            input_text,
            max_length=self.MAX_LENGTH - 10,
            padding="max_length",
            truncation=True,
        ).input_ids
        input_text = self.tokenizer.decode(input_ids)

        predictions = self.classifier(input_text, top_k=self.TOP_K)
        return {item["label"]: item["score"] for item in predictions}
# Shared module-level helper instance; no model is loaded until
# load_model() is called on it.
cwe_infer_helper = CweInferHelper()

if __name__ == "__main__":
    # Manual smoke test: classify a sample patch to
    # net/netfilter/ipvs/ip_vs_ctl.c and print the label -> score mapping.
    # NOTE: the literal is not a raw string, so the "\n" inside the pr_err
    # lines becomes a real newline in the sample text.
    code = """diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c
index 6bde12da2fe003..c37ac2d7bec44d 100644
--- a/net/netfilter/ipvs/ip_vs_ctl.c
+++ b/net/netfilter/ipvs/ip_vs_ctl.c
@@ -2077,6 +2077,10 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user *user, unsigned int len)
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX)
+ return -EINVAL;
+ if (len < 0 || len > MAX_ARG_LEN)
+ return -EINVAL;
if (len != set_arglen[SET_CMDID(cmd)]) {
pr_err("set_ctl: len %u != %u\n",
len, set_arglen[SET_CMDID(cmd)]);
@@ -2352,17 +2356,25 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len)
{
unsigned char arg[128];
int ret = 0;
+ unsigned int copylen;
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX)
+ return -EINVAL;
+
if (*len < get_arglen[GET_CMDID(cmd)]) {
pr_err("get_ctl: len %u < %u\n",
*len, get_arglen[GET_CMDID(cmd)]);
return -EINVAL;
}
- if (copy_from_user(arg, user, get_arglen[GET_CMDID(cmd)]) != 0)
+ copylen = get_arglen[GET_CMDID(cmd)];
+ if (copylen > 128)
+ return -EINVAL;
+
+ if (copy_from_user(arg, user, copylen) != 0)
return -EFAULT;
if (mutex_lock_interruptible(&__ip_vs_mutex))
"""
    # Use the alias that actually exists in CweInferHelper.MODEL_CONFIG;
    # the bare name "patchouli" is not a predefined model.
    cwe_infer_helper.load_model("patchouli-cwe-UniXcoder")
    result = cwe_infer_helper.infer(code)
    print(result)