import logging
import traceback
from collections import defaultdict

import httpx
from openai import OpenAI
from tqdm import tqdm

from backend.section_infer_helper.base_helper import BaseHelper
from backend.utils.data_process import split_to_file_diff, split_to_section

logger = logging.getLogger(__name__)


class OnlineLLMHelper(BaseHelper):
    """Classifies patch sections as vulnerability fixes via an OpenAI-compatible chat API."""

    MAX_LENGTH = 4096       # maximum number of whitespace-separated words kept in the user prompt
    MAX_NEW_TOKENS = 16     # the model only needs to answer "yes" or "no"
    PREDEF_MODEL = [
        "gpt-3.5-turbo",
        "deepseek-chat",
        "qwen-coder-plus",
        "gpt-4-turbo",
        "gpt-4o",
        "gemini-1.5-pro-latest",
        "claude-3-5-sonnet-20241022",
    ]
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    SYSTEM_PROMPT = "You are an expert in code vulnerability and patch fixes."

    @staticmethod
    def generate_instruction(language, file_name, patch, section, message=None):
        """Build the task instruction shown to the model for one patch section."""
        instruction = (
            "[TASK]\nHere is a patch in {} language and a section of this patch for a source code "
            "file with path {}. Determine if the patch section fixes any software vulnerabilities. "
            "Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        )
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        return instruction

    def __init__(self):
        self.model_name = None
        self.openai_client = None
        self.url = None
        self.key = None

    @staticmethod
    def generate_message(filename, patch, section, patch_message=None):
        """Assemble the chat messages (system + user) for one patch section."""
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        user_message = OnlineLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
        # Crude length control: keep at most MAX_LENGTH whitespace-separated words.
        user_message = " ".join(user_message.split(" ")[:OnlineLLMHelper.MAX_LENGTH])
        return [
            {"role": "system", "content": OnlineLLMHelper.SYSTEM_PROMPT},
            {"role": "user", "content": user_message},
        ]

    def load_model(self, model_name, url, api_key):
        """Configure the OpenAI-compatible client for the given model, endpoint, and key."""
        self.model_name = model_name
        self.openai_client = OpenAI(
            base_url=url,
            api_key=api_key,
            timeout=httpx.Timeout(15.0),
        )

    def infer(self, diff_code, message=None, batch_size=1):
        """Split a unified diff into per-file sections and classify each section."""
        if self.model_name is None:
            raise RuntimeError("Model is not loaded")
        results = {}
        input_list = []
        file_diff_list = split_to_file_diff(
            diff_code,
            BaseHelper._get_lang_ext(OnlineLLMHelper.MODEL_CONFIGS[self.model_name]["supported_languages"]),
        )
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                # The section itself doubles as the patch context passed to the model.
                input_list.append(BaseHelper.InputData(file_name, section, section, message))
        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
            results[file_name].append({
                "section": section,
                # -1 marks a failed request, 1 a predicted vulnerability fix, 0 everything else.
                "predict": -1 if output_text_i == "error" else 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i,
            })
        return results

    def do_infer(self, input_list, batch_size=1):
        """Send one chat completion per section; batch_size is accepted for interface parity but requests are sequential."""
        input_prompt = []
        for input_data in input_list:
            input_prompt.append(OnlineLLMHelper.generate_message(
                input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
        if len(input_prompt) > 0:
            logger.info("Example input prompt: %s", input_prompt[0])
        output_text = []
        for prompt, input_data in tqdm(zip(input_prompt, input_list), desc="Inferencing", unit="section", total=len(input_prompt)):
            try:
                response = self.openai_client.chat.completions.create(
                    messages=prompt,
                    model=self.model_name,
                    max_completion_tokens=OnlineLLMHelper.MAX_NEW_TOKENS,
                )
                output_text.append(response.choices[0].message.content)
            except KeyboardInterrupt:
                logger.error("KeyboardInterrupted")
                break
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Error inferencing: {input_data.filename} - {input_data.section}")
                logger.error(traceback.format_exc())
                output_text.append("error")
                continue
        # Pad with "error" so the length assertion in infer() still holds if the loop
        # was interrupted before every prompt was sent.
        output_text.extend(["error"] * (len(input_prompt) - len(output_text)))
        # The chat API response is not queried for token probabilities, so every answer gets confidence 1.0.
        output_prob = [1.0] * len(output_text)
        return input_prompt, output_text, output_prob


online_llm_helper = OnlineLLMHelper()
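

# --- Example usage (illustrative sketch, not part of the backend flow) ---
# The endpoint URL, API key, and sample diff below are placeholders introduced for
# illustration; in the real application load_model()/infer() are driven by the backend
# rather than by running this module directly.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_diff = (
        "diff --git a/src/example.c b/src/example.c\n"
        "--- a/src/example.c\n"
        "+++ b/src/example.c\n"
        "@@ -10,7 +10,7 @@ int copy_input(char *dst, const char *src) {\n"
        "-    strcpy(dst, src);\n"
        "+    strncpy(dst, src, BUF_SIZE - 1);\n"
        "     return 0;\n"
        " }\n"
    )

    online_llm_helper.load_model(
        model_name="gpt-4o",                # any entry from PREDEF_MODEL
        url="https://api.openai.com/v1",    # placeholder endpoint
        api_key="sk-...",                   # placeholder key
    )
    predictions = online_llm_helper.infer(sample_diff, message="Fix buffer overflow in copy_input")
    for path, sections in predictions.items():
        for entry in sections:
            print(path, entry["predict"], entry["conf"])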