# patchouli/backend/section_infer_helper/online_llm_helper.py
import logging
import traceback
from collections import defaultdict

import httpx
from openai import OpenAI
from tqdm import tqdm

from backend.utils.data_process import split_to_file_diff, split_to_section
from backend.section_infer_helper.base_helper import BaseHelper

logger = logging.getLogger(__name__)


class OnlineLLMHelper(BaseHelper):
    # The prompt is truncated to MAX_LENGTH whitespace-separated tokens; the
    # model only has to answer "yes" or "no", so MAX_NEW_TOKENS stays small.
    MAX_LENGTH = 4096
    MAX_NEW_TOKENS = 16
    PREDEF_MODEL = ["gpt-3.5-turbo", "deepseek-chat", "qwen-coder-plus", "gpt-4-turbo", "gpt-4o", "gemini-1.5-pro-latest", "claude-3-5-sonnet-20241022"]
    # Every model falls back to the same default configuration.
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    SYSTEM_PROMPT = "You are an expert in code vulnerability and patch fixes."

    @staticmethod
    def generate_instruction(language, file_name, patch, section, message=None):
        """Build the prompt asking whether this patch section fixes a vulnerability."""
        instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        return instruction

    def __init__(self):
        self.model_name = None
        self.url = None
        self.key = None
        self.openai_client = None

    @staticmethod
    def generate_message(filename, patch, section, patch_message=None):
        """Assemble the system and user chat messages for one patch section."""
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        user_message = OnlineLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
        # Truncate the prompt to at most MAX_LENGTH whitespace-separated tokens.
        user_message = user_message.split(" ")
        user_message = user_message[:OnlineLLMHelper.MAX_LENGTH]
        user_message = " ".join(user_message)
        messages = [
            {
                "role": "system",
                "content": OnlineLLMHelper.SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": user_message
            }
        ]
        return messages

    def load_model(self, model_name, url, api_key):
        """Point the helper at an OpenAI-compatible endpoint."""
        self.model_name = model_name
        self.openai_client = OpenAI(
            base_url=url,
            api_key=api_key,
            timeout=httpx.Timeout(15.0)
        )

    def infer(self, diff_code, message=None, batch_size=1):
        """Split a diff into per-file sections and classify each one.

        Returns a dict mapping each changed file name to a list of
        {"section", "predict", "conf"} entries, where predict is 1 ("yes"),
        0 ("no"), or -1 when inference failed for that section.
        """
        if self.model_name is None:
            raise RuntimeError("Model is not loaded")
        results = {}
        input_list = []
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(OnlineLLMHelper.MODEL_CONFIGS[self.model_name]["supported_languages"]))
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                input_list.append(BaseHelper.InputData(file_name, section, section, message))
        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
            results[file_name].append({
                "section": section,
                "predict": -1 if output_text_i == "error" else 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i
            })
        return results

    def do_infer(self, input_list, batch_size=1):
        """Query the online model once per prepared section and collect the replies."""
        input_prompt = []
        for input_data in input_list:
            input_prompt.append(OnlineLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
        if len(input_prompt) > 0:
            logger.info("Example input prompt: %s", input_prompt[0])
        output_text = []
        for prompt, input_data in tqdm(zip(input_prompt, input_list), desc="Inferencing", unit="section", total=len(input_prompt)):
            try:
                response = self.openai_client.chat.completions.create(
                    messages=prompt,
                    model=self.model_name,
                    max_completion_tokens=OnlineLLMHelper.MAX_NEW_TOKENS
                )
                output_text.append(response.choices[0].message.content)
            except KeyboardInterrupt:
                logger.error("KeyboardInterrupt received, stopping inference")
                break
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Error inferencing: {input_data.filename} - {input_data.section}")
                logger.error(traceback.format_exc())
                output_text.append("error")
                continue
        # The chat API does not expose token probabilities here, so report full confidence.
        output_prob = [1.0] * len(output_text)
        return input_prompt, output_text, output_prob


online_llm_helper = OnlineLLMHelper()
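

# Minimal usage sketch: the endpoint URL, API key, and the diff below are
# placeholders, and the block is guarded so importing this module never
# triggers a network request.
if __name__ == "__main__":
    example_diff = (
        "diff --git a/src/main.c b/src/main.c\n"
        "--- a/src/main.c\n"
        "+++ b/src/main.c\n"
        "@@ -1,3 +1,3 @@\n"
        "-    strcpy(buf, input);\n"
        "+    strncpy(buf, input, sizeof(buf) - 1);\n"
    )
    online_llm_helper.load_model(
        model_name="gpt-4o",              # any entry from PREDEF_MODEL
        url="https://api.openai.com/v1",  # placeholder OpenAI-compatible base URL
        api_key="sk-...",                 # placeholder API key
    )
    print(online_llm_helper.infer(example_diff))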