import logging
from collections import defaultdict

import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from tqdm import tqdm

from backend.section_infer_helper.base_helper import BaseHelper
from backend.utils.data_process import split_to_file_diff, split_to_section

logger = logging.getLogger(__name__)


class LocalLLMHelper(BaseHelper):
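    """Classify patch sections with a local (optionally PEFT-adapted) code LLM.

    For each section of a diff, the model is asked whether the section fixes a
    software vulnerability; its answer is mapped to a binary prediction plus a
    confidence score.
    """
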
    MAX_LENGTH = 4096   # maximum prompt length (in tokens) before truncation
    MAX_NEW_TOKEN = 16  # only a short "yes"/"no" answer is expected
    BATCH_SIZE = 4      # default inference batch size
    SYSTEM_PROMPT = "You are now an expert in code vulnerability and patch fixes."

    @staticmethod
    def generate_instruction(language, file_name, patch, section, message=None):
        instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        return instruction
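
    # A generated instruction looks roughly like the sketch below (the file
    # path and diff text are purely illustrative):
    #
    #   [TASK]
    #   Here is a patch in C language and a section of this patch for a source
    #   code file with path src/foo.c. Determine if the patch section fixes any
    #   software vulnerabilities. Output 'yes' or 'no' and do not output any other text.
    #   [Patch]
    #   @@ -10,7 +10,7 @@ ...
    #   [A section of this patch]
    #   @@ -10,7 +10,7 @@ ...
    #   [Message of the Patch]    <- only present when a commit message is given
    #   ...
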
    # Each config is keyed by a (base model, PEFT adapter) pair; unknown pairs
    # fall back to the default entry below.
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    MODEL_CONFIGS.update({
        ("Qwen/Qwen2.5-Coder-0.5B-Instruct", "backend/model/PEFT/patchouli-qwc2.5-0.5b"): {
            "supported_languages": ["C", "C++", "Java", "Python"],
        },
        ("Qwen/Qwen2.5-Coder-0.5B-Instruct", None): {
            "supported_languages": ["C", "C++", "Java", "Python"],
        },
        ("Qwen/Qwen2.5-Coder-7B-Instruct", None): {
            "supported_languages": ["C", "C++", "Java", "Python"],
        },
        ("deepseek-ai/deepseek-coder-7b-instruct-v1.5", None): {
            "supported_languages": ["C", "C++", "Java", "Python"],
        },
        ("codellama/CodeLlama-7b-Instruct-hf", None): {
            "supported_languages": ["C", "C++", "Java", "Python"],
        },
    })

    PREDEF_MODEL = []
    for model, peft in MODEL_CONFIGS.keys():
        if model not in PREDEF_MODEL:
            PREDEF_MODEL.append(model)

    MODEL_PEFT_MAP = defaultdict(lambda: [None])
    for model, peft in MODEL_CONFIGS.keys():
        if peft is not None:
            MODEL_PEFT_MAP[model].append(peft)
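
    # With the configs above, PREDEF_MODEL ends up listing the four base
    # checkpoints, and MODEL_PEFT_MAP maps "Qwen/Qwen2.5-Coder-0.5B-Instruct"
    # to [None, "backend/model/PEFT/patchouli-qwc2.5-0.5b"]; every other model
    # resolves to [None] via the defaultdict.
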
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name_or_path = None
        self.peft_name_or_path = None

    def __del__(self):
        if self.model is not None:
            self.release_model()

    def infer(self, diff_code, message=None, batch_size=BATCH_SIZE):
        if self.model is None:
            raise RuntimeError("Model is not loaded")
        results = {}
        input_list = []
        # MODEL_CONFIGS is keyed by (model, PEFT adapter) pairs; unknown pairs
        # fall back to the default set of supported languages.
        supported_languages = LocalLLMHelper.MODEL_CONFIGS[(self.model_name_or_path, self.peft_name_or_path)]["supported_languages"]
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(supported_languages))
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                # The same section text fills both the patch and the section
                # slots of BaseHelper.InputData.
                input_list.append(BaseHelper.InputData(file_name, section, section, message))
        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
            results[file_name].append({
                "section": section,
                "predict": 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i
            })
        return results
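
    # NOTE: infer() returns a dict keyed by file path, one entry per diff
    # section, e.g. (values illustrative):
    #   {"src/foo.c": [{"section": "@@ ... @@", "predict": 1, "conf": 0.93}, ...]}
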
    def load_model(self, model_name_or_path, peft_name_or_path=None):
        if model_name_or_path == self.model_name_or_path and peft_name_or_path == self.peft_name_or_path:
            return
        logger.info(f"Loading model {model_name_or_path}")
        if self.model is not None:
            self.release_model()
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float32, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
        if peft_name_or_path is not None and peft_name_or_path != "" and peft_name_or_path != "None":
            logger.info(f"Loading PEFT model {peft_name_or_path}")
            self.model = PeftModel.from_pretrained(self.model, peft_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(peft_name_or_path, padding_side="left")
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()
        self.model_name_or_path = model_name_or_path
        self.peft_name_or_path = peft_name_or_path
        logger.info("Model loaded")

    @staticmethod
    def generate_message(filename, patch, section, patch_message=None):
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        messages = [
            {
                "role": "system",
                "content": LocalLLMHelper.SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": LocalLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
            }
        ]
        return messages
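
    # generate_message() builds the chat-format [system, user] pair that
    # tokenizer.apply_chat_template() consumes in do_infer() below.
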
    def release_model(self):
        del self.model
        del self.tokenizer
        self.model = None
        self.tokenizer = None
        torch.cuda.empty_cache()
        logger.info(f"Model {self.model_name_or_path} released")
        self.model_name_or_path = None
        self.peft_name_or_path = None  # keep the recorded adapter path in sync with the released model

    def do_infer(self, input_list, batch_size=BATCH_SIZE):
        if type(input_list) is not list:
            input_list = [input_list]
        input_data_batches = [input_list[i:i + batch_size] for i in range(0, len(input_list), batch_size)]
        input_ids_list = []
        if len(input_list) > 0:
            logger.info("Example input prompt")
            logger.info(LocalLLMHelper.generate_message(input_list[0].filename, input_list[0].patch, input_list[0].section, input_list[0].patch_msg))
        for batch in tqdm(input_data_batches, desc="Tokenizing", unit="batch", total=len(input_data_batches)):
            message_list = []
            for input_data in batch:
                message_list.append(LocalLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
            input_ids_batch = self.tokenizer.apply_chat_template(
                message_list,
                add_generation_prompt=True,
                return_tensors="pt",
                max_length=LocalLLMHelper.MAX_LENGTH,
                truncation=True,
                padding=True)
            input_ids_list.append(input_ids_batch)
        input_prompt = []
        output_text = []
        output_prob = []
        for input_ids in tqdm(input_ids_list, desc="Generating", unit="batch", total=len(input_ids_list)):
            input_ids = input_ids.to(self.model.device)
            outputs = self.model.generate(input_ids, max_new_tokens=LocalLLMHelper.MAX_NEW_TOKEN,
                                          eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id,
                                          output_logits=True, return_dict_in_generate=True)
            input_prompt.extend(self.tokenizer.batch_decode(input_ids, skip_special_tokens=True))
            # Decode only the newly generated tokens (everything past the prompt).
            output_text.extend(self.tokenizer.batch_decode(outputs.sequences[:, len(input_ids[0]):], skip_special_tokens=True))
            # Confidence: the highest softmax probability of the first generated
            # token (expected to be the "yes"/"no" answer) for each batch item.
            batch_output_prob = softmax(outputs.logits[0], dim=-1).max(dim=-1).values
            output_prob.extend([float(p) for p in batch_output_prob])
        return input_prompt, output_text, output_prob


local_llm_helper = LocalLLMHelper()
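
# Example usage (a minimal sketch; the model and adapter paths come from the
# predefined configs above, while `diff_text` is a hypothetical placeholder):
#
#   local_llm_helper.load_model(
#       "Qwen/Qwen2.5-Coder-0.5B-Instruct",
#       "backend/model/PEFT/patchouli-qwc2.5-0.5b",
#   )
#   results = local_llm_helper.infer(diff_text, message="fix buffer overflow")
#   for file_name, entries in results.items():
#       for entry in entries:
#           print(file_name, entry["predict"], entry["conf"])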