# CCD / app.py
# Author: rahideer — commit 20e20dc ("Create app.py"), 7.53 kB.
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Upper bound on accepted snippet size (unit assumed to be characters — TODO confirm).
MAX_FILE_SIZE = 5000
# Recursion cut-off used when walking the Java AST.
MAX_AST_DEPTH = 50
# Length of the node-type feature vectors.
EMBEDDING_DIM = 128
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# AST Encoder
class ASTNodeEncoder:
    """Map AST node-type names to fixed-length feature vectors.

    Each type in the fitted vocabulary gets a one-hot vector (position =
    vocabulary index modulo the vector length) plus a small Gaussian noise
    term. The noise is drawn from a generator seeded by the type's index, so
    a given type always maps to the same vector. Unknown types map to the
    zero vector.

    Args:
        embedding_dim: Length of the produced vectors. Defaults to the
            module-level ``EMBEDDING_DIM``.
        noise_std: Standard deviation of the added noise. ``0.1`` matches the
            original scale; ``0`` yields pure one-hot vectors.
        seed: Base seed for the per-type noise generators.
    """

    def __init__(self, embedding_dim=None, noise_std=0.1, seed=0):
        self.node_types = set()
        self.type_to_idx = {}
        self.embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        self.noise_std = noise_std
        self.seed = seed
        # node_type -> cached feature vector (filled lazily by encode()).
        self._cache = {}

    def fit(self, ast_nodes):
        """Add node-type names to the vocabulary and re-index it in sorted order."""
        self.node_types.update(ast_nodes)
        self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}
        # Indices may have shifted, so previously cached vectors are stale.
        self._cache.clear()

    def encode(self, node_type):
        """Return the feature vector for ``node_type`` as a fresh 1-D tensor."""
        idx = self.type_to_idx.get(node_type)
        if idx is None:
            return torch.zeros(self.embedding_dim)
        embedding = self._cache.get(node_type)
        if embedding is None:
            embedding = torch.zeros(self.embedding_dim)
            embedding[idx % self.embedding_dim] = 1
            if self.noise_std:
                # Seeded per type: the same type always gets the same noise.
                # Unseeded noise made predictions nondeterministic and
                # prevented identical types from being merged into one graph node.
                gen = torch.Generator().manual_seed(self.seed + idx)
                embedding += torch.randn(self.embedding_dim, generator=gen) * self.noise_std
            self._cache[node_type] = embedding
        # Hand out a copy so callers cannot corrupt the cached vector.
        return embedding.clone()
# Code Normalization
# Comments and literals are matched in ONE left-to-right scan. Comment
# markers inside strings (e.g. "http://...") and quotes inside comments are
# then consumed by the construct that actually contains them.
_JAVA_LEXEME_RE = re.compile(
    r'//[^\n]*'                 # line comment
    r'|/\*.*?\*/'               # block comment (non-greedy, may span lines)
    r'|"(?:\\.|[^"\\\n])*"'     # string literal, honouring backslash escapes
    r"|'(?:\\.|[^'\\\n])*'",    # char literal, honouring backslash escapes
    re.DOTALL,
)


def _replace_java_lexeme(match):
    """Substitution callback: blank out comments, mask string/char literals."""
    text = match.group(0)
    if text.startswith('/'):
        # A space (not '') so that `a/*x*/b` does not fuse into `ab`.
        return ' '
    if text.startswith('"'):
        return '"<STRING>"'
    return "'<CHAR>'"


def normalize_java_code(code):
    """Normalize Java source for parsing and embedding.

    Removes // and /* */ comments, replaces every string literal with
    "<STRING>" and every char literal with '<CHAR>', then collapses all
    whitespace runs (including newlines) into single spaces.

    Args:
        code: Java source text.

    Returns:
        The normalized single-line source string.
    """
    return ' '.join(_JAVA_LEXEME_RE.sub(_replace_java_lexeme, code).split())
# AST Processing
def _child_nodes(children):
    """Yield the javalang AST nodes contained in a node's ``children`` sequence.

    ``children`` may mix sub-nodes, lists/tuples of sub-nodes and other
    attribute values. Only ``javalang.ast.Node`` instances are yielded, and
    nested lists/tuples are searched recursively.
    """
    for child in children:
        if isinstance(child, javalang.ast.Node):
            yield child
        elif isinstance(child, (list, tuple)):
            yield from _child_nodes(child)


def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
    """Collect root-to-leaf paths of node-type embeddings from a javalang AST.

    A path ends at a node that has no AST-node children, or at
    ``MAX_AST_DEPTH``, whichever comes first.

    Args:
        node: Root ``javalang.ast.Node`` to walk.
        encoder: Object whose ``encode(type_name)`` returns a 1-D tensor
            (e.g. ``ASTNodeEncoder``).
        current_path: Internal accumulator for the recursion; callers
            normally leave it as ``None``.
        paths: Output accumulator; callers normally leave it as ``None``.
        depth: Depth of ``node``; callers normally leave it at 0.

    Returns:
        List of tensors, one per path. Each tensor has shape
        ``(path_length, D)``, where D is the length of the encoder's vectors.
    """
    if current_path is None:
        current_path = []
    if paths is None:
        paths = []
    if depth > MAX_AST_DEPTH:
        return paths
    current_path.append(encoder.encode(type(node).__name__))
    # Every javalang node has a `children` attribute, so leaves must be
    # detected by the absence of AST-node children instead.
    children = list(_child_nodes(getattr(node, 'children', ())))
    if not children or depth == MAX_AST_DEPTH:
        paths.append(torch.stack(current_path))
    else:
        for child in children:
            extract_ast_paths(child, encoder, current_path, paths, depth + 1)
    current_path.pop()
    return paths
def ast_to_graph_data(ast, encoder):
paths = extract_ast_paths(ast, encoder)
if not paths:
return None
edge_index = []
node_features = []
node_counter = 0
node_mapping = {}
for path in paths:
for i in range(len(path) - 1):
for j in [i, i+1]:
node_key = tuple(path[j].tolist())
if node_key not in node_mapping:
node_mapping[node_key] = node_counter
node_features.append(path[j])
node_counter += 1
src = node_mapping[tuple(path[i].tolist())]
dst = node_mapping[tuple(path[i+1].tolist())]
edge_index.append([src, dst])
if not edge_index:
return None
node_features = torch.stack(node_features)
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
return Data(x=node_features, edge_index=edge_index)
# Model Architecture
class ASTGNN(nn.Module):
    """Encode an AST graph into a single ``hidden_dim`` vector.

    Despite their ``conv`` names (kept so saved state_dict keys still match),
    both stages are per-node MLPs. ``edge_index`` is not used, so the output
    depends only on the node features. The node embeddings are reduced to
    one vector by a max over the node axis.

    Args:
        input_dim: Width of each node feature vector.
        hidden_dim: Width of the hidden layers and of the output vector.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.conv2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, data):
        """Embed a graph.

        Args:
            data: Graph object exposing ``x`` of shape
                ``(num_nodes, input_dim)``. Its ``edge_index`` is ignored.

        Returns:
            Tensor of shape ``(hidden_dim,)``.
        """
        # Follow the module's own device rather than a global constant, so
        # the model works wherever it has been moved.
        device = next(self.parameters()).device
        x = self.conv1(data.x.to(device))
        x = self.conv2(x)
        # (num_nodes, hidden) -> (1, hidden, num_nodes) for the 1-D pool.
        x = self.pool(x.t().unsqueeze(0))
        return x.squeeze(0).squeeze(-1)
class HybridCloneDetector(nn.Module):
    """Classify a snippet from its AST-graph embedding plus a code embedding.

    Args:
        ast_input_dim: Width of the AST node features.
        hidden_dim: Width of the AST-graph embedding and of the (projected)
            code embedding. The classifier consumes their concatenation
            (``hidden_dim * 2``).
        code_input_dim: Width of the incoming code embedding. When ``None``
            (default) or equal to ``hidden_dim``, no projection is applied
            and the state_dict keys are exactly the original ones. Otherwise
            a linear layer maps the code embedding to ``hidden_dim``.
    """

    def __init__(self, ast_input_dim, hidden_dim, code_input_dim=None):
        super().__init__()
        self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
        if code_input_dim is None or code_input_dim == hidden_dim:
            self.code_proj = nn.Identity()
        else:
            self.code_proj = nn.Linear(code_input_dim, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, ast_data, code_embedding):
        """Return classification logits.

        Args:
            ast_data: Graph passed to ``ASTGNN``.
            code_embedding: Code embedding with a leading batch dimension of
                size 1 (presumably ``(1, code_input_dim)`` — the mean-pooled
                encoder output used by ``predict_clone``).

        Returns:
            Logits of shape ``(1, 2)``.
        """
        ast_embed = self.ast_gnn(ast_data)
        code_embed = self.code_proj(code_embedding.squeeze(0))
        combined = torch.cat([ast_embed, code_embed], dim=0)
        return self.classifier(combined.unsqueeze(0))


# Load Models
def load_models():
    """Build the AST encoder, the CodeBERT tokenizer/encoder and the detector.

    Detector weights are read from ``model.pth`` when that file exists;
    otherwise the detector keeps its random initialisation.

    Returns:
        ``(ast_encoder, tokenizer, code_model, model)``
    """
    ast_encoder = ASTNodeEncoder()
    ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
    code_model.eval()
    # CodeBERT's hidden size (read from its config) need not equal
    # EMBEDDING_DIM. Without a projection, the concatenated vector would not
    # match the classifier's hidden_dim * 2 input.
    model = HybridCloneDetector(
        EMBEDDING_DIM,
        EMBEDDING_DIM,
        code_input_dim=code_model.config.hidden_size,
    ).to(DEVICE)
    weights = Path('model.pth')
    if weights.exists():
        # A state_dict only needs tensors/containers, so restrict unpickling.
        model.load_state_dict(torch.load(weights, map_location=DEVICE, weights_only=True))
    model.eval()
    return ast_encoder, tokenizer, code_model, model


ast_encoder, tokenizer, code_model, model = load_models()
# Prediction Function
def _parse_java(code):
    """Parse (normalized) Java source with javalang.

    A snippet that is not a full compilation unit, such as a lone method, is
    retried wrapped in a class body so it can still be parsed.
    """
    try:
        return javalang.parser.Parser(list(javalang.tokenizer.tokenize(code))).parse()
    except javalang.parser.JavaSyntaxError:
        wrapped = f"class __Snippet {{ {code} }}"
        return javalang.parser.Parser(list(javalang.tokenizer.tokenize(wrapped))).parse()


def _encode_snippet(code):
    """Return ``(ast_graph, code_embedding)`` for one Java snippet.

    Raises:
        ValueError: the snippet exceeds ``MAX_FILE_SIZE`` characters, or no
            AST graph could be built from it.
    """
    if len(code) > MAX_FILE_SIZE:
        raise ValueError(f"snippet is {len(code)} characters; limit is {MAX_FILE_SIZE}")
    norm_code = normalize_java_code(code)
    ast_data = ast_to_graph_data(_parse_java(norm_code), ast_encoder)
    if ast_data is None:
        raise ValueError("could not build an AST graph from the snippet")
    inputs = tokenizer(norm_code, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        code_embed = code_model(**inputs).last_hidden_state.mean(dim=1)
    return ast_data, code_embed


def predict_clone(code1, code2):
    """Compare two Java snippets and report whether they look like clones.

    Each snippet goes through the hybrid detector. The cosine similarity of
    the two logit vectors is reported, and scores above 0.7 are labelled as
    clones.

    Returns:
        ``{"Similarity": "<score to 3 d.p.>", "Clone": "Yes" | "No"}`` on
        success, or ``{"Error": "<ExceptionType>: <message>"}`` on failure.
    """
    try:
        ast_data1, code_embed1 = _encode_snippet(code1)
        ast_data2, code_embed2 = _encode_snippet(code2)
        with torch.no_grad():
            logits1 = model(ast_data1.to(DEVICE), code_embed1)
            logits2 = model(ast_data2.to(DEVICE), code_embed2)
        sim_score = F.cosine_similarity(logits1, logits2).item()
    except Exception as e:  # UI boundary: surface any failure to the user.
        # Include the type name so the user sees something even when str(e) is empty.
        return {"Error": f"{type(e).__name__}: {e}"}
    return {
        "Similarity": f"{sim_score:.3f}",
        "Clone": "Yes" if sim_score > 0.7 else "No",
    }
# Gradio Interface
# Two code boxes in, one JSON verdict out; the examples pre-fill both boxes.
demo = gr.Interface(
    fn=predict_clone,
    title="Java Code Clone Detector",
    description="Detect code clones between two Java code snippets using AST and neural embeddings",
    inputs=[
        gr.Textbox(lines=10, label="First Java Code"),
        gr.Textbox(lines=10, label="Second Java Code"),
    ],
    outputs=gr.JSON(label="Prediction"),
    examples=[
        [
            """public class Hello {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
            """public class Greet {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
        ],
        [
            """public int add(int a, int b) {
return a + b;
}""",
            """public int sum(int x, int y) {
return x + y;
}""",
        ],
    ],
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()