|
import javalang |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch_geometric.data import Data |
|
import re |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModel |
|
from pathlib import Path |
|
|
|
|
|
# Size limit for submitted sources. Nothing in the code visible here reads it,
# and its unit (characters vs. bytes) is not established -- TODO confirm.
MAX_FILE_SIZE = 5_000

# Depth at which extract_ast_paths stops descending into an AST.
MAX_AST_DEPTH = 50

# Width of the node feature vectors; load_models also passes it as both the
# input and the hidden size of HybridCloneDetector.
EMBEDDING_DIM = 128

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
class ASTNodeEncoder:
    """Map AST node-type names to fixed-width feature vectors.

    A known type is encoded as a one-hot vector (position ``idx % dim``) plus a
    small perturbation; an unknown type is encoded as the all-zero vector.
    The perturbation is derived from the type's index, so encoding the same
    type twice yields exactly the same tensor and inference is repeatable.
    """

    def __init__(self, embedding_dim=None):
        """Create an empty encoder.

        Args:
            embedding_dim: Length of the vectors returned by ``encode``.
                ``None`` (the default) selects the module-level
                ``EMBEDDING_DIM``.
        """
        self.node_types = set()
        self.type_to_idx = {}
        # Resolved here rather than in the signature so the module constant is
        # only looked up when the caller did not supply a width.
        self.embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim

    def fit(self, ast_nodes):
        """Add ``ast_nodes`` to the vocabulary and rebuild the index table.

        Indices follow the sorted order of all type names seen so far, so
        fitting additional names can renumber existing ones.
        """
        self.node_types.update(ast_nodes)
        self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}

    def encode(self, node_type):
        """Return the feature vector for ``node_type``.

        Args:
            node_type: Node-type name, e.g. ``'MethodDeclaration'``.

        Returns:
            A 1-D float tensor of length ``embedding_dim``; all zeros when
            ``node_type`` was never passed to ``fit``.
        """
        dim = self.embedding_dim
        embedding = torch.zeros(dim)
        idx = self.type_to_idx.get(node_type)
        if idx is None:
            return embedding

        embedding[idx % dim] = 1
        # The noise keeps types whose indices collide modulo ``dim`` apart.
        # Seeding a private generator with the index makes it reproducible
        # without touching the global RNG state.
        generator = torch.Generator().manual_seed(idx)
        embedding += torch.randn(dim, generator=generator) * 0.1
        return embedding
|
|
|
|
|
# Literals and comments are recognised in a single left-to-right scan, so a
# comment marker inside a literal (``"http://x"``) or a quote inside a comment
# or char literal (``'"'``) is never misread as the start of something else.
_JAVA_LITERAL_OR_COMMENT = re.compile(
    r'(?P<string>"(?:\\.|[^"\\\n])*")'
    r"|(?P<char>'(?:\\.|[^'\\\n])*')"
    r'|(?P<line>//[^\n]*)'
    r'|(?P<block>/\*.*?\*/)',
    re.DOTALL,
)

# Replacement text per literal kind; comments are not listed and become a space.
_JAVA_PLACEHOLDERS = {'string': '"<STRING>"', 'char': "'<CHAR>'"}


def normalize_java_code(code):
    """Return ``code`` with comments removed, literals masked, whitespace collapsed.

    * ``//`` and ``/* ... */`` comments are replaced by a single space, so the
      tokens on either side of a comment are never fused together.
    * String literals become ``"<STRING>"`` and char literals ``'<CHAR>'``;
      backslash escapes inside them (``\\"``, ``\\'``) are honoured.
    * Every run of whitespace is collapsed to one space and the result is
      stripped.

    Unterminated literals and block comments are left as they are.

    Args:
        code: Java source text.

    Returns:
        The normalized source as a single line.
    """
    def _substitute(match):
        return _JAVA_PLACEHOLDERS.get(match.lastgroup, ' ')

    masked = _JAVA_LITERAL_OR_COMMENT.sub(_substitute, code)
    return ' '.join(masked.split())
|
|
|
|
|
def _iter_child_nodes(children):
    """Yield each javalang node in ``children``, descending into nested lists/tuples."""
    for child in children:
        if isinstance(child, javalang.ast.Node):
            yield child
        elif isinstance(child, (list, tuple)):
            yield from _iter_child_nodes(child)


def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
    """Collect one encoded root-to-leaf path per leaf of the AST under ``node``.

    A node is treated as a leaf when it has no child nodes, or when ``depth``
    has reached ``MAX_AST_DEPTH``. The previous ``hasattr(node, 'children')``
    test could not detect leaves, because every javalang node exposes a
    ``children`` list even when it holds no nodes.

    Args:
        node: Current AST node (the root on the initial call).
        encoder: Object whose ``encode(type_name)`` returns a 1-D tensor.
        current_path: Embeddings of the ancestors of ``node``; managed by the
            recursion and restored to its incoming contents before returning.
        paths: Accumulator for finished paths; created on the initial call.
        depth: Depth of ``node`` below the root.

    Returns:
        ``paths``: a list of 2-D tensors, each the stacked embeddings of the
        nodes from the root down to one leaf.
    """
    if current_path is None:
        current_path = []
    if paths is None:
        paths = []

    if depth > MAX_AST_DEPTH:
        return paths

    current_path.append(encoder.encode(type(node).__name__))
    try:
        if depth == MAX_AST_DEPTH:
            # Depth cap reached: close the path here instead of descending.
            child_nodes = []
        else:
            child_nodes = list(_iter_child_nodes(getattr(node, 'children', None) or ()))

        if child_nodes:
            for child in child_nodes:
                extract_ast_paths(child, encoder, current_path, paths, depth + 1)
        else:
            paths.append(torch.stack(current_path))
    finally:
        # Always undo the append so the caller's path is left untouched.
        current_path.pop()

    return paths
|
|
|
def ast_to_graph_data(ast, encoder):
    """Turn the encoded root-to-leaf paths of ``ast`` into a graph ``Data`` object.

    Nodes are identified by the *values* of their embedding: two path entries
    with identical vectors map to the same graph node. One directed edge is
    appended for every consecutive pair in every path, so a pair that occurs
    in several paths is listed several times.

    Args:
        ast: Root AST node handed to ``extract_ast_paths``.
        encoder: Node encoder handed to ``extract_ast_paths``.

    Returns:
        ``Data(x=..., edge_index=...)`` with ``x`` the stacked node vectors and
        ``edge_index`` a ``2 x num_edges`` long tensor, or ``None`` when there
        are no paths or no path contains at least two nodes.
    """
    paths = extract_ast_paths(ast, encoder)
    if not paths:
        return None

    index_by_key = {}
    features = []
    edges = []

    def _node_index(vector):
        # Register the vector on first sight; ids are handed out in visit order.
        key = tuple(vector.tolist())
        if key not in index_by_key:
            index_by_key[key] = len(features)
            features.append(vector)
        return index_by_key[key]

    for path in paths:
        for parent, child in zip(path[:-1], path[1:]):
            source = _node_index(parent)
            target = _node_index(child)
            edges.append([source, target])

    if not edges:
        return None

    x = torch.stack(features)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)
|
|
|
|
|
class ASTGNN(nn.Module):
    """Encode a graph's node features into a single vector.

    Each node vector goes through two two-layer perceptrons and the results
    are max-pooled over the nodes. Despite the class name no message passing
    takes place: the graph's ``edge_index`` is never consulted.
    """

    def __init__(self, input_dim, hidden_dim):
        """Build the layers.

        Args:
            input_dim: Length of each node feature vector.
            hidden_dim: Width of the hidden layers and of the output vector.
        """
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.conv2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, data):
        """Return the pooled embedding of ``data``.

        Args:
            data: Object with an ``x`` attribute holding a 2-D tensor of shape
                ``(num_nodes, input_dim)``.

        Returns:
            A 1-D tensor of length ``hidden_dim``.
        """
        # Follow the module's own parameters instead of a global device, and
        # skip the device copy of ``edge_index``, which was never used.
        device = next(self.parameters()).device
        x = data.x.to(device)
        x = self.conv2(self.conv1(x))
        # (nodes, hidden) -> (1, hidden, nodes): pooling runs over the last axis.
        pooled = self.pool(x.t().unsqueeze(0))
        return pooled.squeeze(0).squeeze(-1)
|
|
|
class HybridCloneDetector(nn.Module):
    """Classify a snippet from its AST embedding joined with a code embedding."""

    def __init__(self, ast_input_dim, hidden_dim, code_embedding_dim=None):
        """Build the AST encoder and the classifier head.

        Args:
            ast_input_dim: Length of each AST node feature vector.
            hidden_dim: Width of the AST embedding and of the hidden layer.
            code_embedding_dim: Length of the code embedding given to
                ``forward``. ``None`` keeps the earlier assumption that it
                equals ``hidden_dim``.
        """
        super().__init__()
        if code_embedding_dim is None:
            code_embedding_dim = hidden_dim
        self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
        self.classifier = nn.Sequential(
            # The input is the AST embedding concatenated with the code
            # embedding, so the width is their sum, not 2 * hidden_dim.
            nn.Linear(hidden_dim + code_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, ast_data, code_embedding):
        """Return the two-class logits for one snippet.

        Args:
            ast_data: Graph object accepted by ``ASTGNN.forward``.
            code_embedding: Tensor of shape ``(1, code_embedding_dim)`` or
                ``(code_embedding_dim,)``.

        Returns:
            Tensor of shape ``(1, 2)``.
        """
        ast_embed = self.ast_gnn(ast_data)
        combined = torch.cat([ast_embed, code_embedding.squeeze(0)], dim=0)
        return self.classifier(combined.unsqueeze(0))


def load_models():
    """Create the node encoder, the CodeBERT tokenizer/encoder and the classifier.

    Weights for the classifier are read from ``model.pth`` in the working
    directory when that file exists; otherwise the classifier keeps its random
    initialisation.

    Returns:
        Tuple ``(ast_encoder, tokenizer, code_model, model)``.
    """
    ast_encoder = ASTNodeEncoder()
    # NOTE(review): only three node types are registered, so every other node
    # type encodes to the all-zero vector -- confirm this vocabulary is intended.
    ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)

    # Size the classifier for the encoder's real output width; assuming it
    # equals EMBEDDING_DIM made every forward pass fail with a shape mismatch.
    model = HybridCloneDetector(
        EMBEDDING_DIM,
        EMBEDDING_DIM,
        code_embedding_dim=code_model.config.hidden_size,
    ).to(DEVICE)

    if Path('model.pth').exists():
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source. A checkpoint saved for the old classifier
        # width will be rejected here with a size-mismatch error.
        model.load_state_dict(torch.load('model.pth', map_location=DEVICE))

    return ast_encoder, tokenizer, code_model, model


# Models are loaded once at import time and shared by every prediction.
ast_encoder, tokenizer, code_model, model = load_models()
|
|
|
|
|
def _parse_java(source):
    """Parse ``source`` as a compilation unit, retrying bare members inside a class."""
    try:
        tokens = list(javalang.tokenizer.tokenize(source))
        return javalang.parser.Parser(tokens).parse()
    except javalang.parser.JavaSyntaxError as original_error:
        # A lone method or field is not a compilation unit; give it a host class.
        wrapped = 'class __Snippet__ { ' + source + ' }'
        try:
            tokens = list(javalang.tokenizer.tokenize(wrapped))
            return javalang.parser.Parser(tokens).parse()
        except javalang.parser.JavaSyntaxError:
            # Report the failure for the text the user actually submitted.
            raise original_error from None


def _snippet_logits(code, label):
    """Run one snippet through normalisation, both encoders and the classifier."""
    normalized = normalize_java_code(code)
    graph = ast_to_graph_data(_parse_java(normalized), ast_encoder)
    if graph is None:
        raise ValueError(f"{label}: could not build an AST graph (empty or trivial snippet)")

    inputs = tokenizer(normalized, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        # Mean-pool the token states into a single vector for the snippet.
        code_embed = code_model(**inputs).last_hidden_state.mean(dim=1)
        return model(graph.to(DEVICE), code_embed)


def predict_clone(code1, code2):
    """Score how similar two Java snippets are.

    Each snippet is normalized, parsed, embedded and passed through the
    classifier; the score is the cosine similarity of the two outputs.

    Args:
        code1: First Java snippet.
        code2: Second Java snippet.

    Returns:
        ``{"Similarity": "<score to 3 decimals>", "Clone": "Yes" | "No"}``,
        where "Yes" means the score exceeds 0.7, or ``{"Error": "<message>"}``
        when any step fails.
    """
    try:
        logits1 = _snippet_logits(code1, "First Java Code")
        logits2 = _snippet_logits(code2, "Second Java Code")
        # NOTE(review): this compares the two-way classifier outputs of the
        # snippets rather than their embeddings -- confirm that is intended.
        sim_score = F.cosine_similarity(logits1, logits2).item()
    except Exception as e:
        # UI boundary: report the failure in the response instead of raising.
        return {"Error": str(e)}

    return {
        "Similarity": f"{sim_score:.3f}",
        "Clone": "Yes" if sim_score > 0.7 else "No",
    }
|
|
|
|
|
# Snippet pairs offered as one-click examples; each row is [first, second].
_EXAMPLE_PAIRS = [
    [
        """public class Hello {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
    }
}""",
        """public class Greet {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
    }
}""",
    ],
    [
        """public int add(int a, int b) {
    return a + b;
}""",
        """public int sum(int x, int y) {
    return x + y;
}""",
    ],
]

# Two-textbox form wired to predict_clone; the result dict is rendered as JSON.
demo = gr.Interface(
    fn=predict_clone,
    inputs=[
        gr.Textbox(label="First Java Code", lines=10),
        gr.Textbox(label="Second Java Code", lines=10),
    ],
    outputs=gr.JSON(label="Prediction"),
    examples=_EXAMPLE_PAIRS,
    title="Java Code Clone Detector",
    description="Detect code clones between two Java code snippets using AST and neural embeddings",
)


# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()