# File size: 7,528 Bytes
# 20e20dc
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
# Configuration
# Size limit for input files; not referenced by the code visible in this file.
MAX_FILE_SIZE = 5000
# Deepest AST level that extract_ast_paths will descend to.
MAX_AST_DEPTH = 50
# Length of each node embedding; load_models also uses it as the hidden size.
EMBEDDING_DIM = 128
# Run on the GPU when one is available, otherwise fall back to the CPU.
_DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_DEVICE_NAME)
# AST Encoder
class ASTNodeEncoder:
    """Map AST node type names to fixed-length embedding vectors.

    ``fit`` assigns an index to every known type name; ``encode`` turns a
    type name into a noisy one-hot vector of length ``EMBEDDING_DIM``.
    """

    def __init__(self):
        # The vocabulary stays empty until fit() is called.
        self.node_types = set()
        self.type_to_idx = {}

    def fit(self, ast_nodes):
        """Add ``ast_nodes`` to the vocabulary and renumber all known types.

        Indices follow the sorted order of the accumulated names, so an
        existing name's index can shift when a later call adds new names.
        """
        self.node_types.update(ast_nodes)
        ordered = sorted(self.node_types)
        self.type_to_idx = dict(zip(ordered, range(len(ordered))))

    def encode(self, node_type):
        """Return the embedding tensor for ``node_type``.

        Unknown types yield an all-zero vector. Known types yield a one-hot
        vector (slot ``idx % EMBEDDING_DIM``) with Gaussian noise scaled by
        0.1 added, so two calls for the same type return different values.
        """
        vector = torch.zeros(EMBEDDING_DIM)
        idx = self.type_to_idx.get(node_type)
        if idx is None:
            return vector
        vector[idx % EMBEDDING_DIM] = 1
        vector += torch.randn(EMBEDDING_DIM) * 0.1
        return vector
# Code Normalization
# Literals and comments are matched in ONE left-to-right scan so that whichever
# construct starts first wins: a "//" inside a string literal is part of the
# string, and a quote inside a comment is part of the comment.
_JAVA_LITERAL_OR_COMMENT = re.compile(
    r'(?P<string>"(?:\\.|[^"\\\n])*")'
    r"|(?P<char>'(?:\\.|[^'\\\n])*')"
    r'|(?P<comment>//[^\n]*|/\*.*?\*/)',
    re.DOTALL,
)

# Replacement text for each kind of match. Comments become a single space
# (not an empty string) so the tokens on either side are not fused together.
_JAVA_PLACEHOLDERS = {
    'string': '"<STRING>"',
    'char': "'<CHAR>'",
    'comment': ' ',
}


def normalize_java_code(code):
    """Return ``code`` with comments removed and literals replaced.

    Line and block comments are dropped, string literals become
    ``"<STRING>"``, character literals become ``'<CHAR>'``, and every run of
    whitespace (including newlines) is collapsed to a single space.
    Backslash-escaped quotes inside literals are treated as part of the
    literal.

    Args:
        code: Java source text.

    Returns:
        The normalized source as a single line.
    """
    scrubbed = _JAVA_LITERAL_OR_COMMENT.sub(
        lambda match: _JAVA_PLACEHOLDERS[match.lastgroup], code
    )
    return ' '.join(scrubbed.split())
# AST Processing
def _iter_child_nodes(node):
    """Yield the javalang nodes directly below ``node``, in source order.

    Entries of ``node.children`` that are lists or tuples are flattened
    (including nested ones); anything that is not a node, such as strings,
    sets or ``None``, is ignored.
    """
    pending = list(reversed(getattr(node, 'children', None) or ()))
    while pending:
        item = pending.pop()
        if isinstance(item, javalang.ast.Node):
            yield item
        elif isinstance(item, (list, tuple)):
            pending.extend(reversed(item))


def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
    """Collect root-to-leaf paths of node embeddings from an AST.

    Each path is a 2-D tensor whose rows are the embeddings of the nodes from
    the root down to a leaf. A node counts as a leaf when it has no child
    nodes or when it sits at ``MAX_AST_DEPTH``. The leaf test must look at the
    actual child nodes: javalang nodes expose a ``children`` attribute even
    when it holds no sub-nodes, so testing for the attribute alone would never
    record a path for an ordinary tree.

    Args:
        node: AST node to start from.
        encoder: Object whose ``encode(type_name)`` returns a 1-D tensor.
        current_path: Embeddings of the ancestors of ``node`` (used by the
            recursion; callers normally omit it).
        paths: List that receives the finished paths (used by the recursion;
            callers normally omit it).
        depth: Depth of ``node`` below the starting node.

    Returns:
        The list of path tensors (the same object as ``paths`` when given).
    """
    if current_path is None:
        current_path = []
    if paths is None:
        paths = []
    if depth > MAX_AST_DEPTH:
        return paths

    current_path.append(encoder.encode(type(node).__name__))
    try:
        # At the depth limit the node is treated as a leaf without looking
        # at its children.
        if depth == MAX_AST_DEPTH:
            children = []
        else:
            children = list(_iter_child_nodes(node))

        if children:
            for child in children:
                extract_ast_paths(child, encoder, current_path, paths, depth + 1)
        else:
            paths.append(torch.stack(current_path))
    finally:
        # Always undo the push so the shared path list stays balanced for the
        # caller, even if encoding or recursion raised.
        current_path.pop()
    return paths
def ast_to_graph_data(ast, encoder):
    """Turn an AST into a graph ``Data`` object built from its paths.

    Consecutive rows of every path returned by ``extract_ast_paths`` become a
    directed edge. Nodes are identified by the exact values of their
    embedding row, so rows with identical values share one graph node. An
    edge is appended once per path it occurs in.

    Args:
        ast: Root node handed to ``extract_ast_paths``.
        encoder: Node-type encoder handed to ``extract_ast_paths``.

    Returns:
        A ``Data`` with ``x`` (stacked node features) and ``edge_index``
        (shape ``2 x num_edges``), or ``None`` when there are no paths or no
        edges.
    """
    all_paths = extract_ast_paths(ast, encoder)
    if not all_paths:
        return None

    node_ids = {}
    features = []
    edges = []
    for path in all_paths:
        # A single-row path has no consecutive pair and contributes nothing.
        if len(path) < 2:
            continue
        # Hash each row once; the tuple of values is the node's identity.
        keys = [tuple(row.tolist()) for row in path]
        for key, row in zip(keys, path):
            if key not in node_ids:
                node_ids[key] = len(features)
                features.append(row)
        edges.extend(
            [node_ids[parent], node_ids[child]]
            for parent, child in zip(keys, keys[1:])
        )

    if not edges:
        return None

    x = torch.stack(features)
    # Transpose the list of [src, dst] pairs into the 2 x num_edges layout.
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)
# Model Architecture
class ASTGNN(nn.Module):
    """Encode a graph's node features into one pooled vector.

    Both "conv" stages are two-layer perceptrons applied to every node
    independently; the node dimension is then reduced with a max-pool.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = self._two_layer_mlp(input_dim, hidden_dim)
        self.conv2 = self._two_layer_mlp(hidden_dim, hidden_dim)
        self.pool = nn.AdaptiveMaxPool1d(1)

    @staticmethod
    def _two_layer_mlp(in_features, out_features):
        """Build Linear -> ReLU -> Linear with ``out_features`` hidden units."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, data):
        """Return a 1-D tensor of length ``hidden_dim`` for ``data``.

        ``data`` must provide ``x`` (one row per node) and ``edge_index``.
        """
        node_features = data.x.to(DEVICE)
        # The edge index is moved to the device but none of the layers below
        # consume it.
        _ = data.edge_index.to(DEVICE)

        hidden = self.conv2(self.conv1(node_features))
        # (nodes, hidden) -> (1, hidden, nodes) so the pool runs over nodes.
        pooled = self.pool(hidden.t().unsqueeze(0))
        return pooled.squeeze(0).squeeze(-1)
class HybridCloneDetector(nn.Module):
    """Classify a snippet from its AST graph plus a text embedding.

    The pooled AST vector and the code embedding are concatenated and passed
    through a two-layer classifier that emits two logits.

    Args:
        ast_input_dim: Length of each AST node feature vector.
        hidden_dim: Hidden size of the AST encoder and the classifier.
        code_embed_dim: Length of the code embedding given to ``forward``.
            Defaults to ``hidden_dim``, which reproduces the previous layer
            shapes; pass the text encoder's hidden size when it differs.
    """

    def __init__(self, ast_input_dim, hidden_dim, code_embed_dim=None):
        super().__init__()
        if code_embed_dim is None:
            code_embed_dim = hidden_dim
        self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
        self.classifier = nn.Sequential(
            # The first layer must accept the concatenation of both vectors.
            nn.Linear(hidden_dim + code_embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, ast_data, code_embedding):
        """Return logits of shape ``(1, 2)``.

        Args:
            ast_data: Graph object accepted by ``ASTGNN``.
            code_embedding: Code embedding; a leading dimension of size one
                is squeezed away before concatenation.

        Raises:
            ValueError: If the concatenated vector does not match the
                classifier's expected input width.
        """
        ast_embed = self.ast_gnn(ast_data)
        code_vec = code_embedding.squeeze(0)
        combined = torch.cat([ast_embed, code_vec], dim=0)

        expected = self.classifier[0].in_features
        if combined.shape[0] != expected:
            raise ValueError(
                f"classifier expects {expected} input features but received "
                f"{combined.shape[0]} ({ast_embed.shape[0]} from the AST "
                f"encoder + {code_vec.shape[0]} from the code embedding); "
                "set code_embed_dim to the code embedding size"
            )
        return self.classifier(combined.unsqueeze(0))


# Load Models
def load_models():
    """Build the node encoder, the CodeBERT tokenizer/encoder and the detector.

    The detector's classifier is sized from the text encoder's configured
    hidden size so it can take the mean-pooled hidden state directly. Weights
    are read from ``model.pth`` in the working directory when that file
    exists; a checkpoint saved with different layer shapes makes
    ``load_state_dict`` raise.

    Returns:
        Tuple ``(ast_encoder, tokenizer, code_model, model)``.
    """
    checkpoint_name = "microsoft/codebert-base"

    ast_encoder = ASTNodeEncoder()
    # Node types outside this list are encoded as all-zero vectors.
    ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    code_model = AutoModel.from_pretrained(checkpoint_name).to(DEVICE)
    code_model.eval()

    model = HybridCloneDetector(
        EMBEDDING_DIM,
        EMBEDDING_DIM,
        code_embed_dim=code_model.config.hidden_size,
    ).to(DEVICE)

    weights_path = Path('model.pth')
    if weights_path.exists():
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where supported).
        model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
    model.eval()
    return ast_encoder, tokenizer, code_model, model


# Loaded once at import time and shared by every prediction request.
ast_encoder, tokenizer, code_model, model = load_models()
# Prediction Function
def _parse_java(tokens):
    """Parse ``tokens`` as a full compilation unit or, failing that, as a
    single class member (so bare method snippets are accepted too)."""
    try:
        return javalang.parser.Parser(tokens).parse()
    except javalang.parser.JavaSyntaxError as unit_error:
        try:
            # A fresh parser is needed: the first one has consumed tokens.
            return javalang.parser.Parser(tokens).parse_member_declaration()
        except javalang.parser.JavaSyntaxError:
            # Report the failure of the primary (whole-file) parse.
            raise unit_error


def _snippet_logits(code):
    """Run one snippet through normalization, parsing and both encoders and
    return the detector's logits for it."""
    if not code or not code.strip():
        raise ValueError("code snippet is empty")

    normalized = normalize_java_code(code)
    tokens = list(javalang.tokenizer.tokenize(normalized))
    graph = ast_to_graph_data(_parse_java(tokens), ast_encoder)
    if graph is None:
        raise ValueError("could not build an AST graph from the snippet")

    inputs = tokenizer(normalized, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        # Mean over the token dimension gives one vector per snippet.
        code_embed = code_model(**inputs).last_hidden_state.mean(dim=1)
        return model(graph.to(DEVICE), code_embed)


def predict_clone(code1, code2):
    """Compare two Java snippets and report whether they look like clones.

    Each snippet is scored by the detector independently; the cosine
    similarity of the two logit vectors is the reported score, and a score
    above 0.7 is labelled a clone.

    Args:
        code1: First Java snippet.
        code2: Second Java snippet.

    Returns:
        ``{"Similarity": "<score>", "Clone": "Yes"|"No"}`` on success, or
        ``{"Error": "<type>: <message>"}`` when any step fails.
    """
    try:
        logits1 = _snippet_logits(code1)
        logits2 = _snippet_logits(code2)
        sim_score = F.cosine_similarity(logits1, logits2).item()
    except Exception as e:
        # UI boundary: report the failure instead of raising. The exception
        # type is included because some exceptions have an empty message.
        return {"Error": f"{type(e).__name__}: {e}"}

    return {
        "Similarity": f"{sim_score:.3f}",
        "Clone": "Yes" if sim_score > 0.7 else "No"
    }
# Gradio Interface
# Snippet pairs offered as one-click examples in the UI: two complete classes
# and two bare methods.
_EXAMPLE_PAIRS = [
    [
        """public class Hello {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
        """public class Greet {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
    ],
    [
        """public int add(int a, int b) {
return a + b;
}""",
        """public int sum(int x, int y) {
return x + y;
}""",
    ],
]

# Two text boxes in, one JSON panel out; predict_clone does the work.
_CODE_INPUTS = [
    gr.Textbox(label="First Java Code", lines=10),
    gr.Textbox(label="Second Java Code", lines=10),
]

demo = gr.Interface(
    fn=predict_clone,
    inputs=_CODE_INPUTS,
    outputs=gr.JSON(label="Prediction"),
    examples=_EXAMPLE_PAIRS,
    title="Java Code Clone Detector",
    description="Detect code clones between two Java code snippets using AST and neural embeddings",
)
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script.
    demo.launch()