|
import torch |
|
from transformers import RobertaTokenizer, RobertaModel |
|
import numpy as np |
|
from scipy.special import softmax |
|
import gradio as gr |
|
import re |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class CodeClassifier(torch.nn.Module):
    """Transformer encoder followed by a two-layer linear classification head.

    The encoder's ``pooler_output`` is projected to ``reduced_size`` features
    and then to ``num_labels`` logits. The attribute names ``base``,
    ``reduction`` and ``classifier`` determine the state-dict keys, so they
    must not be renamed or previously saved checkpoints will no longer load.
    """

    def __init__(self, base_model, num_labels=6, hidden_size=None, reduced_size=512):
        """
        Args:
            base_model: Encoder module called with ``input_ids`` and
                ``attention_mask`` keyword arguments whose output exposes
                ``pooler_output`` (e.g. ``RobertaModel``).
            num_labels: Number of output classes.
            hidden_size: Width of ``pooler_output``. When ``None`` it is read
                from ``base_model.config.hidden_size`` if present, else 768
                (the width of base-size RoBERTa/CodeBERT).
            reduced_size: Width of the intermediate projection layer.
        """
        super().__init__()
        if hidden_size is None:
            hidden_size = getattr(getattr(base_model, "config", None), "hidden_size", 768)
        self.base = base_model
        # NOTE(review): there is no activation between these two Linear layers,
        # so together they form a single affine map. Kept as-is because the
        # saved checkpoint's parameter layout depends on it.
        self.reduction = torch.nn.Linear(hidden_size, reduced_size)
        self.classifier = torch.nn.Linear(reduced_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits.

        Assumes ``pooler_output`` is (batch, hidden_size), giving logits of
        shape (batch, num_labels).
        """
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CodeBERT tokenizer and encoder; the classifier head wraps this encoder and
# the fine-tuned weights are then loaded from the checkpoint below.
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
base_model = RobertaModel.from_pretrained('microsoft/codebert-base')

model = CodeClassifier(base_model)

checkpoint_path = hf_hub_download(repo_id="martynattakit/CodeSentinel-Model", filename="best_model.pt")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. The checkpoint comes from a fixed Hub repo; consider weights_only=True
# once the checkpoint is confirmed to contain only tensors/plain containers.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Accept either a training checkpoint that wraps the weights under
# 'model_state_dict' or a bare state dict.
model_state = checkpoint.get('model_state_dict', checkpoint)

# strict=False tolerates key mismatches, but any missing key means that
# parameter keeps its random initialisation -- report mismatches loudly
# instead of printing the model's own key list.
load_result = model.load_state_dict(model_state, strict=False)
if load_result.missing_keys:
    print("WARNING: keys missing from checkpoint (left randomly initialised):", load_result.missing_keys)
if load_result.unexpected_keys:
    print("WARNING: unexpected keys in checkpoint (ignored):", load_result.unexpected_keys)

print("Classifier weight shape:", model.classifier.weight.shape)

model.eval()
model.to(device)
|
|
|
|
|
# Classifier output index -> (short CWE identifier, human-readable description).
# Index i corresponds to logit column i of the model output; presumably this
# order matches the label order used in training -- confirm before editing.
label_map = dict(enumerate([
    ('none', 'No Vulnerability Detected'),
    ('cwe-121', 'Stack-based Buffer Overflow'),
    ('cwe-78', 'OS Command Injection'),
    ('cwe-190', 'Integer Overflow or Wraparound'),
    ('cwe-191', 'Integer Underflow'),
    ('cwe-122', 'Heap-based Buffer Overflow'),
]))
|
|
|
def load_c_file(file):
    """Return the text of an uploaded C/C++ source file.

    Args:
        file: Value produced by the ``gr.File`` component. Depending on the
            Gradio version/config this is either a filesystem path
            (``str``/``bytes``/``os.PathLike``) or a tempfile-like wrapper
            exposing the path as ``.name``. ``None`` when no file is selected.

    Returns:
        The file contents (undecodable bytes replaced with U+FFFD), ``""``
        when ``file`` is ``None``, or an ``"Error reading file: ..."`` message
        if the file cannot be opened or read.
    """
    import os

    try:
        if file is None:
            return ""
        # pathlib.Path also has a `.name` (the basename), so check for
        # path-like values first and only fall back to the wrapper's `.name`.
        path = file if isinstance(file, (str, bytes, os.PathLike)) else file.name
        # errors="replace" keeps non-UTF-8 sources (e.g. Latin-1 comments)
        # loadable instead of failing the whole upload.
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    except Exception as e:
        # UI boundary: show the problem in the textbox instead of crashing.
        return f"Error reading file: {str(e)}"
|
|
|
# Comment patterns stripped by clean_code: C block comments (non-greedy, may
# span lines) and C++ line comments (up to end of line).
_BLOCK_COMMENT_RE = re.compile(r'/\*.*?\*/', re.DOTALL)
_LINE_COMMENT_RE = re.compile(r'//.*$', re.MULTILINE)


def clean_code(code):
    """Strip C/C++ comments and collapse all whitespace runs to single spaces.

    Block comments are removed before line comments, then the text is
    re-joined on single spaces with leading/trailing whitespace dropped.

    NOTE(review): the patterns are not string-literal aware, so ``//`` or
    ``/*`` inside a string (e.g. a URL) is stripped as well. Presumably this
    mirrors the preprocessing used at training time -- confirm before changing.
    """
    no_block_comments = _BLOCK_COMMENT_RE.sub('', code)
    no_comments = _LINE_COMMENT_RE.sub('', no_block_comments)
    tokens = no_comments.split()
    return ' '.join(tokens)
|
|
|
def evaluate_code(code):
    """Classify a C/C++ snippet and return the predicted CWE label as text.

    The code is cleaned with ``clean_code`` (comments removed, whitespace
    collapsed), tokenized with the CodeBERT tokenizer and run through the
    classifier. Tokenization truncates to ``max_length=256`` tokens, so only
    the beginning of long inputs influences the prediction.

    Args:
        code: Source text from the UI textbox; may be ``None`` or empty.

    Returns:
        ``"<cwe-id> <description>"`` on success; otherwise a human-readable
        message (no code, code too large, or the prediction error).
    """
    no_code_msg = "No code to analyze. Please paste or upload some C/C++ code."
    try:
        if not code:
            return no_code_msg

        # Reject very large pastes before doing any regex/tokenizer work.
        if len(code) >= 1500000:
            return "Code too large"

        cleaned_code = clean_code(code)
        # Whitespace-only or comment-only input would otherwise be classified
        # as just the special tokens, producing a meaningless label.
        if not cleaned_code:
            return no_code_msg

        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True,
                           padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)

        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.cpu().numpy()
        print("Raw logits:", logits)
        probs = softmax(logits, axis=1)
        # Single input -> take row 0; cast the numpy integer to a plain int
        # for the label lookup.
        pred = int(np.argmax(probs, axis=1)[0])
        cwe, description = label_map[pred]
        return f"{cwe} {description}"

    except Exception as e:
        # UI boundary: surface the failure in the result box instead of crashing.
        return f"Error during prediction: {str(e)}"
|
|
|
# Gradio UI: code textbox and file upload side by side, predicted label below.
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="** C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")

    with gr.Row():
        gr.Markdown("### Result:")

    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    # Selecting (or clearing) a file replaces the textbox contents with the
    # file's text; "Check" then classifies whatever is in the textbox.
    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])


# Only start the server when run as a script, so importing this module (e.g.
# to reuse the model or helpers) does not block on launching the web app.
if __name__ == "__main__":
    web.launch()