rahideer commited on
Commit
20e20dc
·
verified ·
1 Parent(s): d015db1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import javalang
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch_geometric.data import Data
6
+ import re
7
+ import gradio as gr
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from pathlib import Path
10
+
11
+ # Configuration
12
+ MAX_FILE_SIZE = 5000
13
+ MAX_AST_DEPTH = 50
14
+ EMBEDDING_DIM = 128
15
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ # AST Encoder
18
+ class ASTNodeEncoder:
19
+ def __init__(self):
20
+ self.node_types = set()
21
+ self.type_to_idx = {}
22
+
23
+ def fit(self, ast_nodes):
24
+ self.node_types.update(ast_nodes)
25
+ self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}
26
+
27
+ def encode(self, node_type):
28
+ if node_type not in self.type_to_idx:
29
+ return torch.zeros(EMBEDDING_DIM)
30
+ idx = self.type_to_idx[node_type]
31
+ embedding = torch.zeros(EMBEDDING_DIM)
32
+ embedding[idx % EMBEDDING_DIM] = 1
33
+ embedding += torch.randn(EMBEDDING_DIM) * 0.1
34
+ return embedding
35
+
36
+ # Code Normalization
37
+ def normalize_java_code(code):
38
+ code = re.sub(r'//.*', '', code)
39
+ code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
40
+ code = re.sub(r'"[^"]*"', '"<STRING>"', code)
41
+ code = re.sub(r"'[^']*'", "'<CHAR>'", code)
42
+ return ' '.join(code.split())
43
+
44
+ # AST Processing
45
+ def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
46
+ if current_path is None:
47
+ current_path = []
48
+ if paths is None:
49
+ paths = []
50
+
51
+ if depth > MAX_AST_DEPTH:
52
+ return paths
53
+
54
+ node_type = str(type(node).__name__)
55
+ node_embedding = encoder.encode(node_type)
56
+ current_path.append(node_embedding)
57
+
58
+ if not hasattr(node, 'children') or depth == MAX_AST_DEPTH:
59
+ paths.append(torch.stack(current_path))
60
+ current_path.pop()
61
+ return paths
62
+
63
+ for child in node.children:
64
+ if isinstance(child, (javalang.ast.Node, list, tuple)):
65
+ if isinstance(child, (list, tuple)):
66
+ for c in child:
67
+ if isinstance(c, javalang.ast.Node):
68
+ extract_ast_paths(c, encoder, current_path, paths, depth+1)
69
+ elif isinstance(child, javalang.ast.Node):
70
+ extract_ast_paths(child, encoder, current_path, paths, depth+1)
71
+
72
+ current_path.pop()
73
+ return paths
74
+
75
+ def ast_to_graph_data(ast, encoder):
76
+ paths = extract_ast_paths(ast, encoder)
77
+ if not paths:
78
+ return None
79
+
80
+ edge_index = []
81
+ node_features = []
82
+ node_counter = 0
83
+ node_mapping = {}
84
+
85
+ for path in paths:
86
+ for i in range(len(path) - 1):
87
+ for j in [i, i+1]:
88
+ node_key = tuple(path[j].tolist())
89
+ if node_key not in node_mapping:
90
+ node_mapping[node_key] = node_counter
91
+ node_features.append(path[j])
92
+ node_counter += 1
93
+
94
+ src = node_mapping[tuple(path[i].tolist())]
95
+ dst = node_mapping[tuple(path[i+1].tolist())]
96
+ edge_index.append([src, dst])
97
+
98
+ if not edge_index:
99
+ return None
100
+
101
+ node_features = torch.stack(node_features)
102
+ edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
103
+ return Data(x=node_features, edge_index=edge_index)
104
+
105
+ # Model Architecture
106
+ class ASTGNN(nn.Module):
107
+ def __init__(self, input_dim, hidden_dim):
108
+ super().__init__()
109
+ self.conv1 = nn.Sequential(
110
+ nn.Linear(input_dim, hidden_dim),
111
+ nn.ReLU(),
112
+ nn.Linear(hidden_dim, hidden_dim)
113
+ )
114
+ self.conv2 = nn.Sequential(
115
+ nn.Linear(hidden_dim, hidden_dim),
116
+ nn.ReLU(),
117
+ nn.Linear(hidden_dim, hidden_dim)
118
+ )
119
+ self.pool = nn.AdaptiveMaxPool1d(1)
120
+
121
+ def forward(self, data):
122
+ x, edge_index = data.x.to(DEVICE), data.edge_index.to(DEVICE)
123
+ x = self.conv1(x)
124
+ x = self.conv2(x)
125
+ x = x.t().unsqueeze(0)
126
+ x = self.pool(x)
127
+ return x.squeeze(0).squeeze(-1)
128
+
129
+ class HybridCloneDetector(nn.Module):
130
+ def __init__(self, ast_input_dim, hidden_dim):
131
+ super().__init__()
132
+ self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
133
+ self.classifier = nn.Sequential(
134
+ nn.Linear(hidden_dim * 2, hidden_dim),
135
+ nn.ReLU(),
136
+ nn.Linear(hidden_dim, 2))
137
+
138
+ def forward(self, ast_data, code_embedding):
139
+ ast_embed = self.ast_gnn(ast_data)
140
+ combined = torch.cat([ast_embed, code_embedding.squeeze(0)], dim=0)
141
+ return self.classifier(combined.unsqueeze(0))
142
+
143
+ # Load Models
144
+ def load_models():
145
+ ast_encoder = ASTNodeEncoder()
146
+ ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])
147
+
148
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
149
+ code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
150
+
151
+ model = HybridCloneDetector(EMBEDDING_DIM, EMBEDDING_DIM).to(DEVICE)
152
+ if Path('model.pth').exists():
153
+ model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
154
+
155
+ return ast_encoder, tokenizer, code_model, model
156
+
157
+ ast_encoder, tokenizer, code_model, model = load_models()
158
+
159
+ # Prediction Function
160
+ def predict_clone(code1, code2):
161
+ try:
162
+ # Process first code
163
+ norm_code1 = normalize_java_code(code1)
164
+ tokens1 = list(javalang.tokenizer.tokenize(norm_code1))
165
+ parser = javalang.parser.Parser(tokens1)
166
+ ast1 = parser.parse()
167
+ ast_data1 = ast_to_graph_data(ast1, ast_encoder)
168
+
169
+ inputs1 = tokenizer(norm_code1, return_tensors="pt", truncation=True).to(DEVICE)
170
+ with torch.no_grad():
171
+ code_embed1 = code_model(**inputs1).last_hidden_state.mean(dim=1)
172
+
173
+ # Process second code
174
+ norm_code2 = normalize_java_code(code2)
175
+ tokens2 = list(javalang.tokenizer.tokenize(norm_code2))
176
+ parser = javalang.parser.Parser(tokens2)
177
+ ast2 = parser.parse()
178
+ ast_data2 = ast_to_graph_data(ast2, ast_encoder)
179
+
180
+ inputs2 = tokenizer(norm_code2, return_tensors="pt", truncation=True).to(DEVICE)
181
+ with torch.no_grad():
182
+ code_embed2 = code_model(**inputs2).last_hidden_state.mean(dim=1)
183
+
184
+ # Predict
185
+ with torch.no_grad():
186
+ logits1 = model(ast_data1.to(DEVICE), code_embed1)
187
+ logits2 = model(ast_data2.to(DEVICE), code_embed2)
188
+ sim_score = F.cosine_similarity(logits1, logits2).item()
189
+
190
+ return {
191
+ "Similarity": f"{sim_score:.3f}",
192
+ "Clone": "Yes" if sim_score > 0.7 else "No"
193
+ }
194
+ except Exception as e:
195
+ return {"Error": str(e)}
196
+
197
+ # Gradio Interface
198
+ demo = gr.Interface(
199
+ fn=predict_clone,
200
+ inputs=[
201
+ gr.Textbox(label="First Java Code", lines=10),
202
+ gr.Textbox(label="Second Java Code", lines=10)
203
+ ],
204
+ outputs=gr.JSON(label="Prediction"),
205
+ examples=[
206
+ ["""public class Hello {
207
+ public static void main(String[] args) {
208
+ System.out.println("Hello, World!");
209
+ }
210
+ }""",
211
+ """public class Greet {
212
+ public static void main(String[] args) {
213
+ System.out.println("Hello, World!");
214
+ }
215
+ }"""],
216
+ ["""public int add(int a, int b) {
217
+ return a + b;
218
+ }""",
219
+ """public int sum(int x, int y) {
220
+ return x + y;
221
+ }"""]
222
+ ],
223
+ title="Java Code Clone Detector",
224
+ description="Detect code clones between two Java code snippets using AST and neural embeddings"
225
+ )
226
+
227
+ if __name__ == "__main__":
228
+ demo.launch()