import logging
from openai import OpenAI
from tqdm import tqdm
from collections import defaultdict
import traceback
import httpx
from backend.utils.data_process import split_to_file_diff, split_to_section
from backend.section_infer_helper.base_helper import BaseHelper

logger = logging.getLogger(__name__)

class OnlineLLMHelper(BaseHelper):
    """Classify patch sections as vulnerability fixes via an OpenAI-compatible chat API."""

    MAX_LENGTH = 4096
    MAX_NEW_TOKENS = 16
    PREDEF_MODEL = ["gpt-3.5-turbo", "deepseek-chat", "qwen-coder-plus", "gpt-4-turbo", "gpt-4o", "gemini-1.5-pro-latest", "claude-3-5-sonnet-20241022"]
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    SYSTEM_PROMPT = "You are an expert in code vulnerability and patch fixes."

    @staticmethod
    def generate_instruction(language, file_name, patch, section, message=None):
        instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        return instruction

    def __init__(self):
        self.model_name = None
        self.url = None
        self.key = None
        # Set by load_model(); initialized here so the attribute always exists.
        self.openai_client = None

    @staticmethod
    def generate_message(filename, patch, section, patch_message=None):
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        user_message = OnlineLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
        # Crude length guard: keep at most MAX_LENGTH whitespace-separated tokens.
        user_message = user_message.split(" ")
        user_message = user_message[:OnlineLLMHelper.MAX_LENGTH]
        user_message = " ".join(user_message)
        messages = [
            {
                "role": "system",
                "content": OnlineLLMHelper.SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": user_message
            }
        ]
        return messages

    def load_model(self, model_name, url, api_key):
        self.model_name = model_name
        self.openai_client = OpenAI(
            base_url=url,
            api_key=api_key,
            timeout=httpx.Timeout(15.0)
        )

    def infer(self, diff_code, message=None, batch_size=1):
        if self.model_name is None:
            raise RuntimeError("Model is not loaded")
        results = {}
        input_list = []
        # Split the diff into per-file diffs for supported languages, then into sections.
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(OnlineLLMHelper.MODEL_CONFIGS[self.model_name]["supported_languages"]))
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                input_list.append(BaseHelper.InputData(file_name, section, section, message))
        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
            results[file_name].append({
                "section": section,
                # -1 marks a failed request, 1 a predicted vulnerability fix, 0 otherwise.
                "predict": -1 if output_text_i == "error" else 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i
            })
        return results

    def do_infer(self, input_list, batch_size=1):
        input_prompt = []
        for input_data in input_list:
            input_prompt.append(OnlineLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
        if len(input_prompt) > 0:
            logger.info("Example input prompt: %s", input_prompt[0])
        output_text = []
        for prompt, input_data in tqdm(zip(input_prompt, input_list), desc="Inferencing", unit="section", total=len(input_prompt)):
            try:
                response = self.openai_client.chat.completions.create(
                    messages=prompt,
                    model=self.model_name,
                    max_completion_tokens=OnlineLLMHelper.MAX_NEW_TOKENS
                )
                output_text.append(response.choices[0].message.content)
            except KeyboardInterrupt:
                logger.error("KeyboardInterrupted")
                # Pad the remaining entries so the caller's length assertion still holds.
                output_text.extend(["error"] * (len(input_prompt) - len(output_text)))
                break
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Error inferencing: {input_data.filename} - {input_data.section}")
                logger.error(traceback.format_exc())
                output_text.append("error")
                continue
        # The chat API is not asked for token probabilities here, so confidence defaults to 1.0.
        output_prob = [1.0] * len(output_text)
        return input_prompt, output_text, output_prob


online_llm_helper = OnlineLLMHelper()
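
# Minimal usage sketch (assumptions: the model name, base URL, API key, and patch file
# below are illustrative placeholders, not values shipped with this module).
if __name__ == "__main__":
    with open("example.patch") as f:  # hypothetical unified diff for a supported language
        diff_text = f.read()
    online_llm_helper.load_model("gpt-4o", "https://api.openai.com/v1", "sk-...")
    predictions = online_llm_helper.infer(diff_text, message="Fix buffer overflow in parser")
    for fname, sections in predictions.items():
        for entry in sections:
            print(fname, entry["predict"], entry["conf"])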