import logging
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from tqdm import tqdm
from collections import defaultdict
from backend.section_infer_helper.base_helper import BaseHelper
from backend.utils.data_process import split_to_file_diff, split_to_section
logger = logging.getLogger(__name__)
class LocalLLMHelper(BaseHelper):
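    """Infer whether patch sections fix a vulnerability using a locally hosted
    causal LM, optionally combined with a PEFT adapter."""
    # MAX_LENGTH caps the tokenized prompt length; MAX_NEW_TOKEN only needs to cover a short yes/no answer.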
MAX_LENGTH = 4096
MAX_NEW_TOKEN = 16
BATCH_SIZE = 4
SYSTEM_PROMPT = "You are now an expert in code vulnerability and patch fixes."
    @staticmethod
    def generate_instruction(language, file_name, patch, section, message = None):
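        """Build the user-turn instruction text for one patch section."""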
instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
instruction += "[Patch]\n{}\n".format(patch)
instruction += "[A section of this patch]\n{}\n".format(section)
if message is not None and message != "":
instruction += "[Message of the Patch]\n{}\n".format(message)
return instruction
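    # Per-(base model, PEFT adapter) configuration; unlisted pairs fall back to the defaultdict default below.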
MODEL_CONFIGS = defaultdict(lambda: {
"supported_languages": ["C", "C++", "Java", "Python"],
})
MODEL_CONFIGS.update({
("Qwen/Qwen2.5-Coder-0.5B-Instruct", "backend/model/PEFT/patchouli-qwc2.5-0.5b"): {
"supported_languages": ["C", "C++", "Java", "Python"],
},
("Qwen/Qwen2.5-Coder-0.5B-Instruct", None): {
"supported_languages": ["C", "C++", "Java", "Python"],
},
("Qwen/Qwen2.5-Coder-7B-Instruct", None): {
"supported_languages": ["C", "C++", "Java", "Python"],
},
("deepseek-ai/deepseek-coder-7b-instruct-v1.5", None): {
"supported_languages": ["C", "C++", "Java", "Python"],
},
("codellama/CodeLlama-7b-Instruct-hf", None): {
"supported_languages": ["C", "C++", "Java", "Python"],
},
})
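    # Unique base model names derived from MODEL_CONFIGS.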
PREDEF_MODEL = []
for model, peft in MODEL_CONFIGS.keys():
if model not in PREDEF_MODEL:
PREDEF_MODEL.append(model)
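    # PEFT adapters available per base model; None denotes the bare base model.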
MODEL_PEFT_MAP = defaultdict(lambda: [None])
for model, peft in MODEL_CONFIGS.keys():
if peft is not None:
MODEL_PEFT_MAP[model].append(peft)
def __init__(self):
self.model = None
self.tokenizer = None
self.model_name_or_path = None
self.peft_name_or_path = None
def __del__(self):
if self.model is not None:
self.release_model()
def infer(self, diff_code, message = None, batch_size = BATCH_SIZE):
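        """Split a unified diff into per-file sections and classify each one.
        Returns a dict mapping file name to a list of {"section", "predict", "conf"} entries."""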
if self.model is None:
raise RuntimeError("Model is not loaded")
results = {}
input_list = []
        supported_languages = LocalLLMHelper.MODEL_CONFIGS[(self.model_name_or_path, self.peft_name_or_path)]["supported_languages"]
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(supported_languages))
for file_a, _, file_diff in file_diff_list:
sections = split_to_section(file_diff)
file_name = file_a.removeprefix("a/")
results[file_name] = []
for section in sections:
input_list.append(BaseHelper.InputData(file_name, section, section, message))
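        # Run batched inference, then attach each yes/no verdict and its confidence to the originating section.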
input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
for i in range(len(input_list)):
file_name = input_list[i].filename
section = input_list[i].section
output_text_i = output_text[i].lower()
output_prob_i = output_prob[i]
results[file_name].append({
"section": section,
"predict": 1 if "yes" in output_text_i else 0,
"conf": output_prob_i
})
return results
def load_model(self, model_name_or_path, peft_name_or_path = None):
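        """Load (or reuse) the base model, tokenizer, and optional PEFT adapter."""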
if model_name_or_path == self.model_name_or_path and peft_name_or_path == self.peft_name_or_path:
return
logger.info(f"Loading model {model_name_or_path}")
if self.model is not None:
self.release_model()
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float32, device_map="auto")
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
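        # If a PEFT adapter path is given, wrap the base model and prefer the adapter's tokenizer.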
if peft_name_or_path is not None and peft_name_or_path != "" and peft_name_or_path != "None":
logger.info(f"Loading PEFT model {peft_name_or_path}")
self.model = PeftModel.from_pretrained(self.model, peft_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(peft_name_or_path, padding_side="left")
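        # Some checkpoints ship without a pad token; reuse EOS so left-padded batching works.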
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.eval()
self.model_name_or_path = model_name_or_path
self.peft_name_or_path = peft_name_or_path
logger.info(f"Model loaded")
    @staticmethod
    def generate_message(filename, patch, section, patch_message = None):
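        """Assemble the chat messages (system + user) for one patch section."""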
ext = filename.split(".")[-1]
language = BaseHelper._get_lang_by_ext(ext)
messages = [
{
"role": "system",
"content": LocalLLMHelper.SYSTEM_PROMPT
},
{
"role": "user",
"content": LocalLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
}
]
return messages
def release_model(self):
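        """Free the model and tokenizer and release cached GPU memory."""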
del self.model
del self.tokenizer
self.model = None
self.tokenizer = None
torch.cuda.empty_cache()
logger.info(f"Model {self.model_name_or_path} released")
        self.model_name_or_path = None
        self.peft_name_or_path = None
def do_infer(self, input_list, batch_size = BATCH_SIZE):
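        """Tokenize inputs in batches, generate, and return (input_prompt, output_text, output_prob) lists aligned with input_list."""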
        if not isinstance(input_list, list):
input_list = [input_list]
input_data_batches = [input_list[i:i+batch_size] for i in range(0, len(input_list), batch_size)]
input_ids_list = []
if len(input_list) > 0:
logger.info("Example input prompt")
logger.info(LocalLLMHelper.generate_message(input_list[0].filename, input_list[0].patch, input_list[0].section, input_list[0].patch_msg))
for batch in tqdm(input_data_batches, desc="Tokenizing", unit="batch", total=len(input_data_batches)):
message_list = []
for input_data in batch:
message_list.append(LocalLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
input_ids_batch = self.tokenizer.apply_chat_template(
message_list,
add_generation_prompt=True,
return_tensors="pt",
max_length=LocalLLMHelper.MAX_LENGTH,
truncation=True,
padding=True)
input_ids_list.append(input_ids_batch)
input_prompt = []
output_text = []
output_prob = []
for input_ids in tqdm(input_ids_list, desc="Generating", unit="batch", total=len(input_ids_list)):
input_ids = input_ids.to(self.model.device)
outputs = self.model.generate(input_ids, max_new_tokens=LocalLLMHelper.MAX_NEW_TOKEN,
eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id,
output_logits=True, return_dict_in_generate=True)
input_prompt.extend(self.tokenizer.batch_decode(input_ids, skip_special_tokens=True))
output_text.extend(self.tokenizer.batch_decode(outputs.sequences[:, len(input_ids[0]):], skip_special_tokens=True))
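            # Confidence: softmax probability of the most likely token at the first generated position (the yes/no token).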
batch_output_prob = softmax(outputs.logits[0], dim=-1).max(dim=-1).values
output_prob.extend([float(p) for p in batch_output_prob])
return input_prompt, output_text, output_prob
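# Shared module-level instance, intended to be imported and reused by backend callers.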
local_llm_helper = LocalLLMHelper()
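# Example usage (illustrative sketch; diff_code stands for a unified diff string):
#   local_llm_helper.load_model("Qwen/Qwen2.5-Coder-0.5B-Instruct",
#                               "backend/model/PEFT/patchouli-qwc2.5-0.5b")
#   results = local_llm_helper.infer(diff_code, message="commit message text")
#   local_llm_helper.release_model()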