"""Helper that classifies a patch with a transformers text-classification pipeline."""
import logging
from collections import defaultdict

from transformers import pipeline, AutoTokenizer

logger = logging.getLogger(__name__)


class _FallbackModelConfig(defaultdict):
    """Model-config mapping that resolves unknown keys to themselves.

    ``defaultdict`` invokes its ``default_factory`` with no arguments, so a
    factory can never see the key being looked up.  Overriding
    ``__missing__`` gives access to the key, which lets an unregistered model
    name be used directly as both the model path and the tokenizer name.
    """

    def __missing__(self, model):
        # Deliberately not stored: a lookup must not grow the shared,
        # class-level mapping.
        return {
            "model_name_or_path": model,
            "tokenizer_name": model,
        }


class CweInferHelper():
    """Load a text-classification model and score a patch with it."""

    # Number of highest-scoring labels requested from the classifier.
    TOP_K = 5
    # Token budget; ``infer`` truncates its input to ``MAX_LENGTH - 10`` tokens.
    MAX_LENGTH = 1024

    # Registered models.  Any name that is not registered here falls back to
    # being used verbatim as model path and tokenizer name.
    MODEL_CONFIG = _FallbackModelConfig()
    MODEL_CONFIG.update({
        "patchouli-cwe-UniXcoder": {
            "model_name_or_path": "./backend/model/cwe-cls/patchouli-unixcoder",
            "tokenizer_name": "microsoft/unixcoder-base-nine"
        }
    })

    # Names of the registered models, captured at class-definition time.
    PREDEF_MODEL = list(MODEL_CONFIG.keys())

    def __init__(self):
        """Create a helper with no model loaded."""
        self.model = None
        self.classifier = None
        self.tokenizer = None

    def load_model(self, model):
        """Load the classifier pipeline and tokenizer for ``model``.

        Does nothing when ``model`` is the model that is already loaded.

        Args:
            model: A key of ``MODEL_CONFIG``, or any other name, which is then
                used as both the model path and the tokenizer name.
        """
        if model == self.model:
            return

        logger.info("Loading CWE classify model: %s", model)
        config = self.MODEL_CONFIG[model]
        model_name_or_path = config["model_name_or_path"]
        tokenizer_name = config["tokenizer_name"]

        classifier = pipeline(
            "text-classification",
            model=model_name_or_path,
            tokenizer=tokenizer_name,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Commit state only after both loads succeeded; otherwise a failed
        # load would leave ``self.model`` set and make a retry with the same
        # name return early without a usable classifier.
        self.classifier = classifier
        self.tokenizer = tokenizer
        self.model = model

    def infer(self, diff_code, patch_message = None):
        """Classify a patch and return the top label scores.

        Args:
            diff_code: Patch text; placed after a ``[PATCH]`` marker.
            patch_message: Optional message; when not ``None`` and not empty
                it is placed before the patch after a ``[MESSAGE]`` marker.

        Returns:
            A dict mapping each returned label to its score, built from the
            classifier output requested with ``top_k=TOP_K``.

        Raises:
            ValueError: If no model has been loaded.
        """
        if self.classifier is None:
            raise ValueError("Model is not loaded")

        input_text = ""
        if patch_message is not None and patch_message != "":
            input_text += f"[MESSAGE]\n{patch_message}\n"
        input_text += f"[PATCH]\n{diff_code}"

        logger.info("Classifying CWE for diff code")

        # Round-trip through the tokenizer so the text handed to the pipeline
        # is cut to the token budget.
        # NOTE(review): padding="max_length" followed by a plain decode() puts
        # special/pad tokens into the text that is classified - presumably
        # this mirrors how the model was trained; confirm before changing.
        input_ids = self.tokenizer(
            input_text,
            max_length=CweInferHelper.MAX_LENGTH-10,
            padding="max_length",
            truncation=True,
        ).input_ids
        input_text = self.tokenizer.decode(input_ids)

        predictions = self.classifier(input_text, top_k = self.TOP_K)
        return {item["label"]: item["score"] for item in predictions}


# Shared module-level instance.
cwe_infer_helper = CweInferHelper()

if __name__ == "__main__":
    code = """diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c index 6bde12da2fe003..c37ac2d7bec44d 100644 --- a/net/netfilter/ipvs/ip_vs_ctl.c +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -2077,6 +2077,10 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user 
*user, unsigned int len) if (!capable(CAP_NET_ADMIN)) return -EPERM; + if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX) + return -EINVAL; + if (len < 0 || len > MAX_ARG_LEN) + return -EINVAL; if (len != set_arglen[SET_CMDID(cmd)]) { pr_err("set_ctl: len %u != %u\n", len, set_arglen[SET_CMDID(cmd)]); @@ -2352,17 +2356,25 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) { unsigned char arg[128]; int ret = 0; + unsigned int copylen; if (!capable(CAP_NET_ADMIN)) return -EPERM; + if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX) + return -EINVAL; + if (*len < get_arglen[GET_CMDID(cmd)]) { pr_err("get_ctl: len %u < %u\n", *len, get_arglen[GET_CMDID(cmd)]); return -EINVAL; } - if (copy_from_user(arg, user, get_arglen[GET_CMDID(cmd)]) != 0) + copylen = get_arglen[GET_CMDID(cmd)]; + if (copylen > 128) + return -EINVAL; + + if (copy_from_user(arg, user, copylen) != 0) return -EFAULT; if (mutex_lock_interruptible(&__ip_vs_mutex)) """
    # NOTE(review): "patchouli" is not a key of MODEL_CONFIG, so it is used
    # verbatim as model path and tokenizer name; the registered key is
    # "patchouli-cwe-UniXcoder" - confirm which one was intended.
    cwe_infer_helper.load_model("patchouli")
    result = cwe_infer_helper.infer(code)
    print(result)