"""Gradio demo: score two Java snippets for clone similarity.

Each snippet is normalized, parsed with javalang into an AST-derived graph,
embedded with CodeBERT, and passed through ``HybridCloneDetector``; the cosine
similarity of the two outputs is reported.
"""

import re
from pathlib import Path

import gradio as gr
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from transformers import AutoTokenizer, AutoModel

# Configuration
MAX_FILE_SIZE = 5000  # NOTE(review): not referenced anywhere in the visible code
MAX_AST_DEPTH = 50
EMBEDDING_DIM = 128
CLONE_THRESHOLD = 0.7  # similarity above this value is reported as a clone
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# AST Encoder
class ASTNodeEncoder:
    """Map AST node type names to fixed-size feature vectors."""

    def __init__(self):
        self.node_types = set()
        self.type_to_idx = {}

    def fit(self, ast_nodes):
        """Add ``ast_nodes`` (iterable of type names) to the vocabulary.

        Indices are reassigned from the sorted vocabulary on every call.
        """
        self.node_types.update(ast_nodes)
        self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}

    def encode(self, node_type):
        """Return the ``EMBEDDING_DIM``-sized vector for ``node_type``.

        Unknown types map to the zero vector. Known types get a one-hot
        component (index modulo ``EMBEDDING_DIM``) plus small noise. The noise
        is seeded by the type index so the same type always yields the same
        vector; unseeded noise made predictions differ between runs on
        identical input.
        """
        idx = self.type_to_idx.get(node_type)
        if idx is None:
            return torch.zeros(EMBEDDING_DIM)
        embedding = torch.zeros(EMBEDDING_DIM)
        embedding[idx % EMBEDDING_DIM] = 1
        generator = torch.Generator().manual_seed(idx)
        embedding += torch.randn(EMBEDDING_DIM, generator=generator) * 0.1
        return embedding


# Code Normalization
# One pass over the text: literals are matched before comments so that a
# "//" or "/*" inside a string literal is not mistaken for a comment.
_JAVA_LEXEME_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'      # string literal (escapes allowed)
    r"|'(?:\\.|[^'\\\n])*'"     # char literal (escapes allowed)
    r"|//[^\n]*"                # line comment
    r"|/\*[\s\S]*?\*/"          # block comment
)


def _replace_lexeme(match):
    """Blank out a literal, or turn a comment into a single space."""
    text = match.group(0)
    if text[0] == '"':
        return '""'
    if text[0] == "'":
        return "''"
    # A comment separates tokens, so it must not vanish entirely.
    return ' '


def normalize_java_code(code):
    """Strip comments, blank string/char literals, and collapse whitespace.

    Returns the normalized source as a single line.
    """
    code = _JAVA_LEXEME_RE.sub(_replace_lexeme, code)
    return ' '.join(code.split())


# AST Processing
def _iter_child_nodes(children):
    """Yield javalang nodes found in ``children``, flattening nested lists/tuples."""
    for child in children:
        if isinstance(child, javalang.ast.Node):
            yield child
        elif isinstance(child, (list, tuple)):
            yield from _iter_child_nodes(child)


def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
    """Collect root-to-leaf paths of node embeddings from a javalang AST.

    Args:
        node: AST node to descend from.
        encoder: object with ``encode(type_name)`` returning a tensor.
        current_path: embeddings of the ancestors (used by the recursion).
        paths: accumulator list (used by the recursion).
        depth: depth of ``node``; descent stops at ``MAX_AST_DEPTH``.

    Returns:
        List of tensors, one per path, each of shape
        ``(path_length, embedding_dim)``.
    """
    if current_path is None:
        current_path = []
    if paths is None:
        paths = []
    if depth > MAX_AST_DEPTH:
        return paths

    current_path.append(encoder.encode(type(node).__name__))

    # javalang nodes always expose ``children``, so a leaf is a node with no
    # child *nodes*, not one that lacks the attribute.
    if depth == MAX_AST_DEPTH:
        child_nodes = []
    else:
        child_nodes = list(_iter_child_nodes(getattr(node, 'children', None) or ()))

    if child_nodes:
        for child in child_nodes:
            extract_ast_paths(child, encoder, current_path, paths, depth + 1)
    else:
        paths.append(torch.stack(current_path))

    current_path.pop()
    return paths


def ast_to_graph_data(ast, encoder):
    """Convert an AST into a ``torch_geometric`` ``Data`` graph.

    Nodes are identified by their embedding values, so AST nodes with
    identical embeddings share one graph node. Consecutive path entries
    become directed edges.

    Returns:
        ``Data(x, edge_index)``, or ``None`` when the AST yields no edges.
    """
    paths = extract_ast_paths(ast, encoder)
    if not paths:
        return None

    edge_index = []
    node_features = []
    node_mapping = {}

    for path in paths:
        if len(path) < 2:
            # A single-node path contributes neither edges nor nodes.
            continue
        keys = [tuple(row.tolist()) for row in path]
        for key, row in zip(keys, path):
            if key not in node_mapping:
                node_mapping[key] = len(node_features)
                node_features.append(row)
        for src_key, dst_key in zip(keys, keys[1:]):
            edge_index.append([node_mapping[src_key], node_mapping[dst_key]])

    if not edge_index:
        return None

    node_features = torch.stack(node_features)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    return Data(x=node_features, edge_index=edge_index)


# Model Architecture
class ASTGNN(nn.Module):
    """Per-node MLP followed by max pooling over all nodes.

    NOTE: ``forward`` reads only ``data.x``; ``data.edge_index`` is not used.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.conv2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, data):
        """Return a ``(hidden_dim,)`` embedding for the graph in ``data``."""
        x = data.x.to(DEVICE)
        x = self.conv2(self.conv1(x))
        # (nodes, hidden) -> (1, hidden, nodes) so the pool reduces over nodes.
        x = x.t().unsqueeze(0)
        return self.pool(x).squeeze(0).squeeze(-1)


class HybridCloneDetector(nn.Module):
    """Classifier over the concatenated AST and code embeddings."""

    def __init__(self, ast_input_dim, hidden_dim, code_embed_dim=None):
        """
        Args:
            ast_input_dim: size of each AST node feature vector.
            hidden_dim: hidden size of the AST encoder and classifier.
            code_embed_dim: size of the external code embedding. Defaults to
                ``hidden_dim`` (the original layout); pass the language
                model's hidden size when it differs.
        """
        super().__init__()
        if code_embed_dim is None:
            code_embed_dim = hidden_dim
        self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim + code_embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, ast_data, code_embedding):
        """Return ``(1, 2)`` logits for one snippet."""
        ast_embed = self.ast_gnn(ast_data)
        combined = torch.cat([ast_embed, code_embedding.squeeze(0)], dim=0)
        return self.classifier(combined.unsqueeze(0))


# Load Models
def load_models():
    """Build the AST encoder, CodeBERT tokenizer/model and the detector.

    Loads ``model.pth`` from the working directory when it exists.
    ``load_state_dict`` raises if the checkpoint does not match the layout.

    Returns:
        ``(ast_encoder, tokenizer, code_model, model)``.
    """
    ast_encoder = ASTNodeEncoder()
    # NOTE(review): only these three type names are in the vocabulary; every
    # other node type encodes to the zero vector.
    ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
    code_model.eval()

    # The classifier input must match the language model's embedding size,
    # which is not EMBEDDING_DIM.
    model = HybridCloneDetector(
        EMBEDDING_DIM,
        EMBEDDING_DIM,
        code_embed_dim=code_model.config.hidden_size,
    ).to(DEVICE)

    if Path('model.pth').exists():
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
    model.eval()

    return ast_encoder, tokenizer, code_model, model


ast_encoder, tokenizer, code_model, model = load_models()


# Prediction Function
def _parse_java(code):
    """Parse ``code`` as a compilation unit, else as a single class member.

    The fallback lets bare method snippets be analysed.

    Raises:
        ValueError: if neither form parses.
    """
    tokens = list(javalang.tokenizer.tokenize(code))
    try:
        return javalang.parser.Parser(tokens).parse()
    except javalang.parser.JavaSyntaxError as unit_error:
        try:
            return javalang.parser.Parser(tokens).parse_member_declaration()
        except javalang.parser.JavaSyntaxError:
            detail = getattr(unit_error, 'description', None) or 'syntax error'
            raise ValueError(f"Could not parse Java code: {detail}") from unit_error


def _snippet_logits(code):
    """Run one snippet through normalization, both encoders and the detector."""
    norm_code = normalize_java_code(code)
    ast_data = ast_to_graph_data(_parse_java(norm_code), ast_encoder)
    if ast_data is None:
        raise ValueError("Could not build an AST graph from the given Java code")

    inputs = tokenizer(norm_code, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        code_embed = code_model(**inputs).last_hidden_state.mean(dim=1)
        return model(ast_data.to(DEVICE), code_embed)


def predict_clone(code1, code2):
    """Compare two Java snippets.

    Returns:
        ``{"Similarity": "<3-decimal score>", "Clone": "Yes"|"No"}`` on
        success, or ``{"Error": "<message>"}`` if anything fails.
    """
    try:
        logits1 = _snippet_logits(code1)
        logits2 = _snippet_logits(code2)
        sim_score = F.cosine_similarity(logits1, logits2).item()
        return {
            "Similarity": f"{sim_score:.3f}",
            "Clone": "Yes" if sim_score > CLONE_THRESHOLD else "No"
        }
    except Exception as e:
        # UI boundary: report the failure instead of crashing the request.
        # Some parser errors have an empty message, so fall back to the type.
        return {"Error": str(e) or type(e).__name__}


# Gradio Interface
demo = gr.Interface(
    fn=predict_clone,
    inputs=[
        gr.Textbox(label="First Java Code", lines=10),
        gr.Textbox(label="Second Java Code", lines=10)
    ],
    outputs=gr.JSON(label="Prediction"),
    examples=[
        ["""public class Hello { public static void main(String[] args) { System.out.println("Hello, World!"); } }""",
         """public class Greet { public static void main(String[] args) { System.out.println("Hello, World!"); } }"""],
        ["""public int add(int a, int b) { return a + b; }""",
         """public int sum(int x, int y) { return x + y; }"""]
    ],
    title="Java Code Clone Detector",
    description="Detect code clones between two Java code snippets using AST and neural embeddings"
)

if __name__ == "__main__":
    demo.launch()