Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import javalang
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch_geometric.data import Data
|
6 |
+
import re
|
7 |
+
import gradio as gr
|
8 |
+
from transformers import AutoTokenizer, AutoModel
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
# Size limit for submitted code. NOTE(review): not referenced by the code
# visible in this file, and the unit is not stated - confirm before enforcing.
MAX_FILE_SIZE = 5000

# Deepest AST level that path extraction descends to.
MAX_AST_DEPTH = 50

# Width of the node-type embedding vectors; also used as the model's hidden size.
EMBEDDING_DIM = 128

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
16 |
+
|
17 |
+
# AST Encoder
|
18 |
+
class ASTNodeEncoder:
    """Encode AST node type names as fixed-width vectors.

    Known types become a one-hot vector (index taken modulo the embedding
    width) plus Gaussian noise; unknown types become an all-zero vector.

    Args:
        embedding_dim: Width of the produced vectors. ``None`` (the default)
            uses the module-level ``EMBEDDING_DIM``.
        noise_std: Scale of the Gaussian noise added to known types. The
            default of 0.1 keeps the original behaviour, which means two
            calls to ``encode`` for the same type return different vectors.
            Pass ``0`` for deterministic one-hot output.
    """

    def __init__(self, embedding_dim=None, noise_std=0.1):
        # Resolve the default lazily so the constant is only needed when no
        # explicit width is given.
        self.embedding_dim = EMBEDDING_DIM if embedding_dim is None else embedding_dim
        self.noise_std = noise_std
        self.node_types = set()
        self.type_to_idx = {}

    def fit(self, ast_nodes):
        """Add ``ast_nodes`` to the vocabulary and rebuild the index map.

        Indices are assigned in sorted order over every type seen so far, so
        calling ``fit`` again with new types can shift existing indices.
        """
        self.node_types.update(ast_nodes)
        self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}

    def encode(self, node_type):
        """Return the embedding tensor of shape ``(embedding_dim,)`` for ``node_type``."""
        embedding = torch.zeros(self.embedding_dim)
        idx = self.type_to_idx.get(node_type)
        if idx is None:
            # Unknown types all share the zero vector.
            return embedding

        # Types beyond the embedding width wrap around and share a slot.
        embedding[idx % self.embedding_dim] = 1
        if self.noise_std:
            embedding += torch.randn(self.embedding_dim) * self.noise_std
        return embedding
|
35 |
+
|
36 |
+
# Code Normalization
|
37 |
+
# One alternation scanned left to right, so whichever construct starts first
# wins: a "//" inside a string literal is part of the string, and a quote
# inside a comment is part of the comment.
_JAVA_LITERAL_OR_COMMENT_RE = re.compile(
    r'"(?:\\.|[^"\\])*"'     # string literal, honouring backslash escapes
    r"|'(?:\\.|[^'\\])*'"    # character literal, honouring backslash escapes
    r'|//[^\n]*'             # line comment, up to the end of the line
    r'|/\*.*?\*/',           # block comment, possibly spanning lines
    re.DOTALL,
)


def _replace_literal_or_comment(match):
    """Return the placeholder for a matched literal, or a space for a comment."""
    first_char = match.group(0)[0]
    if first_char == '"':
        return '"<STRING>"'
    if first_char == "'":
        return "'<CHAR>'"
    # A comment separates tokens just like whitespace does, so it must not
    # be deleted outright ("int/**/x" has to stay two tokens).
    return ' '


def normalize_java_code(code):
    """Strip comments, mask literals and collapse whitespace in Java source.

    Args:
        code: Java source text.

    Returns:
        The source on a single line with comments removed, every string
        literal replaced by ``"<STRING>"``, every character literal replaced
        by ``'<CHAR>'`` and every whitespace run reduced to one space.
    """
    masked = _JAVA_LITERAL_OR_COMMENT_RE.sub(_replace_literal_or_comment, code)
    return ' '.join(masked.split())
|
43 |
+
|
44 |
+
# AST Processing
|
45 |
+
def _iter_child_nodes(children):
    """Yield the javalang nodes found in ``children``, flattening nested lists/tuples."""
    for child in children:
        if isinstance(child, javalang.ast.Node):
            yield child
        elif isinstance(child, (list, tuple)):
            yield from _iter_child_nodes(child)


def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
    """Collect the root-to-leaf paths of an AST as stacked node embeddings.

    Args:
        node: Root of the (sub)tree to walk.
        encoder: Object whose ``encode(type_name)`` returns the embedding
            tensor for a node type name.
        current_path: Embeddings of the ancestors of ``node``; used by the
            recursion, callers normally leave it as ``None``.
        paths: Accumulator for finished paths; used by the recursion,
            callers normally leave it as ``None``.
        depth: Depth of ``node`` below the original root.

    Returns:
        ``paths``: a list with one tensor per path, each the stacked
        embeddings from the root down to a leaf. A node counts as a leaf
        when it has no child nodes or when it sits at ``MAX_AST_DEPTH``,
        where the walk is cut off.
    """
    if current_path is None:
        current_path = []
    if paths is None:
        paths = []

    if depth > MAX_AST_DEPTH:
        return paths

    current_path.append(encoder.encode(type(node).__name__))

    # Children are only expanded above the depth limit; at the limit the
    # node is recorded as if it were a leaf.
    child_nodes = []
    if depth < MAX_AST_DEPTH:
        child_nodes = list(_iter_child_nodes(getattr(node, 'children', None) or ()))

    if child_nodes:
        for child in child_nodes:
            extract_ast_paths(child, encoder, current_path, paths, depth + 1)
    else:
        # javalang nodes always expose ``children``, so "is a leaf" has to
        # be decided by the absence of child *nodes*, not of the attribute.
        paths.append(torch.stack(current_path))

    current_path.pop()
    return paths
|
74 |
+
|
75 |
+
def ast_to_graph_data(ast, encoder):
    """Convert an AST into a ``torch_geometric`` ``Data`` graph.

    Consecutive embeddings on every path returned by ``extract_ast_paths``
    become a directed parent -> child edge.

    NOTE: graph nodes are identified by the *values* of their embedding
    vector, so two AST nodes whose embeddings are identical (for example
    all-zero vectors) are merged into a single graph node.

    Args:
        ast: Root node of the parsed AST.
        encoder: Node-type encoder handed through to ``extract_ast_paths``.

    Returns:
        ``Data(x, edge_index)`` where ``x`` holds one feature row per
        distinct node and ``edge_index`` has shape ``(2, num_edges)``, or
        ``None`` when no path or no edge could be extracted.
    """
    paths = extract_ast_paths(ast, encoder)
    if not paths:
        return None

    node_ids = {}        # embedding values -> node index
    node_features = []   # feature row of each node, in index order
    edges = {}           # dict used as an insertion-ordered set of (src, dst)

    for path in paths:
        if len(path) < 2:
            # A lone node has no edge; nodes are only registered through edges.
            continue

        # Convert each row to a hashable key once instead of once per use.
        ids_on_path = []
        for row in path:
            key = tuple(row.tolist())
            node_id = node_ids.get(key)
            if node_id is None:
                node_id = len(node_features)
                node_ids[key] = node_id
                node_features.append(row)
            ids_on_path.append(node_id)

        # Paths share their prefixes, so the same edge shows up once per
        # leaf below it; keep only the first occurrence.
        for edge in zip(ids_on_path, ids_on_path[1:]):
            edges.setdefault(edge)

    if not edges:
        return None

    x = torch.stack(node_features)
    edge_index = torch.tensor(list(edges), dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)
|
104 |
+
|
105 |
+
# Model Architecture
|
106 |
+
class ASTGNN(nn.Module):
    """Embed an AST graph into a single vector of size ``hidden_dim``.

    Despite the name, no message passing over edges takes place: every node
    feature row goes through two MLP blocks independently and the results
    are max-pooled over the nodes.

    Args:
        input_dim: Width of the node feature rows.
        hidden_dim: Width of the hidden layers and of the output vector.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.conv2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, data):
        """Return the pooled graph embedding, shape ``(hidden_dim,)``.

        Args:
            data: Object with an ``x`` attribute holding the node features
                as a ``(num_nodes, input_dim)`` tensor. Any ``edge_index``
                it carries is not used.
        """
        # Follow the module's own parameters rather than a global device so
        # the layer works wherever it has been moved to.
        device = next(self.parameters()).device
        x = data.x.to(device)

        x = self.conv2(self.conv1(x))
        # (num_nodes, hidden) -> (1, hidden, num_nodes) so the 1-d pool
        # reduces over the node axis.
        x = self.pool(x.t().unsqueeze(0))
        return x.squeeze(0).squeeze(-1)
|
128 |
+
|
129 |
+
class HybridCloneDetector(nn.Module):
    """Classify a snippet from its AST graph and its code embedding.

    The pooled AST embedding and the code embedding are concatenated and
    passed through a two-layer classifier with two output logits.

    Args:
        ast_input_dim: Width of the AST node feature rows.
        hidden_dim: Hidden width of the AST encoder and the classifier.
        code_dim: Width of the code embedding given to ``forward``.
            ``None`` (the default) means ``hidden_dim``, which reproduces
            the original ``hidden_dim * 2`` classifier input.
    """

    def __init__(self, ast_input_dim, hidden_dim, code_dim=None):
        super().__init__()
        if code_dim is None:
            code_dim = hidden_dim
        self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
        self.classifier = nn.Sequential(
            # The input must be as wide as the concatenation in forward():
            # hidden_dim from the AST encoder plus the code embedding.
            nn.Linear(hidden_dim + code_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, ast_data, code_embedding):
        """Return the class logits, shape ``(1, 2)``.

        Args:
            ast_data: Graph object accepted by ``ASTGNN``.
            code_embedding: Code embedding tensor; a leading dimension of
                size 1 is dropped before concatenation (assumed to be
                ``(1, code_dim)`` - confirm against the caller).
        """
        ast_embed = self.ast_gnn(ast_data)
        combined = torch.cat([ast_embed, code_embedding.squeeze(0)], dim=0)
        return self.classifier(combined.unsqueeze(0))


# Load Models
def load_models():
    """Build the node encoder, the CodeBERT tokenizer/encoder and the detector.

    Weights are read from ``model.pth`` in the working directory when that
    file exists; otherwise the detector keeps its random initialisation.

    Returns:
        Tuple ``(ast_encoder, tokenizer, code_model, model)``.

    Raises:
        RuntimeError: If ``model.pth`` exists but does not fit the detector
            architecture.
    """
    ast_encoder = ASTNodeEncoder()
    # NOTE(review): only these three node types get a non-zero embedding;
    # every other AST node type is encoded as the zero vector.
    ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)

    # The classifier consumes the encoder output, so size it from the
    # encoder's hidden width instead of assuming it equals EMBEDDING_DIM.
    model = HybridCloneDetector(
        EMBEDDING_DIM,
        EMBEDDING_DIM,
        code_dim=code_model.config.hidden_size,
    ).to(DEVICE)

    if Path('model.pth').exists():
        try:
            model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
        except RuntimeError as err:
            raise RuntimeError(
                "model.pth does not match the HybridCloneDetector architecture "
                "(the classifier input is hidden_dim + code embedding width)"
            ) from err

    return ast_encoder, tokenizer, code_model, model


ast_encoder, tokenizer, code_model, model = load_models()
|
158 |
+
|
159 |
+
# Prediction Function
|
160 |
+
def _parse_java_snippet(source):
    """Parse Java source into a javalang AST, accepting class-less snippets.

    The source is first parsed as a full compilation unit. If that fails
    with a syntax error (for instance a bare method), it is parsed again
    wrapped in a synthetic class. When both attempts fail, the error of the
    first attempt is raised.
    """
    tokens = list(javalang.tokenizer.tokenize(source))
    try:
        return javalang.parser.Parser(tokens).parse()
    except javalang.parser.JavaSyntaxError as unit_error:
        wrapped_tokens = list(
            javalang.tokenizer.tokenize("class Snippet { " + source + " }")
        )
        try:
            return javalang.parser.Parser(wrapped_tokens).parse()
        except javalang.parser.JavaSyntaxError:
            raise unit_error from None


def _snippet_logits(code):
    """Run one snippet through normalisation, parsing, CodeBERT and the detector."""
    normalized = normalize_java_code(code)

    graph = ast_to_graph_data(_parse_java_snippet(normalized), ast_encoder)
    if graph is None:
        # Without this check the failure shows up as an AttributeError on None.
        raise ValueError("Could not build an AST graph from the provided code")

    inputs = tokenizer(normalized, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        # Mean over the token axis gives one embedding for the whole snippet.
        code_embedding = code_model(**inputs).last_hidden_state.mean(dim=1)
        return model(graph.to(DEVICE), code_embedding)


# Prediction Function
def predict_clone(code1, code2):
    """Compare two Java snippets and report whether they look like clones.

    Each snippet is run through the detector on its own; the score is the
    cosine similarity between the two resulting logit vectors.

    NOTE: the node encoder adds random noise to its embeddings, so the score
    for the same pair of snippets can differ between calls.

    Args:
        code1: First Java snippet.
        code2: Second Java snippet.

    Returns:
        ``{"Similarity": "<score, 3 decimals>", "Clone": "Yes" | "No"}`` with
        "Yes" for scores above 0.7, or ``{"Error": <message>}`` when any
        step fails.
    """
    try:
        logits1 = _snippet_logits(code1)
        logits2 = _snippet_logits(code2)
        sim_score = F.cosine_similarity(logits1, logits2).item()

        return {
            "Similarity": f"{sim_score:.3f}",
            "Clone": "Yes" if sim_score > 0.7 else "No"
        }
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app. Some
        # parser errors carry no message, so fall back to the error type.
        return {"Error": str(e) or type(e).__name__}
|
196 |
+
|
197 |
+
# Gradio Interface
|
198 |
+
# Example pair: two classes that differ only in their name.
_HELLO_CLASS = """public class Hello {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}"""
_GREET_CLASS = """public class Greet {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}"""

# Example pair: two methods that differ only in their identifiers.
_ADD_METHOD = """public int add(int a, int b) {
return a + b;
}"""
_SUM_METHOD = """public int sum(int x, int y) {
return x + y;
}"""

_EXAMPLE_PAIRS = [
    [_HELLO_CLASS, _GREET_CLASS],
    [_ADD_METHOD, _SUM_METHOD],
]

_CODE_INPUTS = [
    gr.Textbox(label="First Java Code", lines=10),
    gr.Textbox(label="Second Java Code", lines=10),
]

# Two code boxes in, the prediction dictionary out as JSON.
demo = gr.Interface(
    fn=predict_clone,
    inputs=_CODE_INPUTS,
    outputs=gr.JSON(label="Prediction"),
    examples=_EXAMPLE_PAIRS,
    title="Java Code Clone Detector",
    description="Detect code clones between two Java code snippets using AST and neural embeddings",
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|