File size: 8,608 Bytes
81a794d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import json
import logging
import gradio as gr

from backend.inference import section_infer, cwe_infer, PREDEF_MODEL_MAP, LOCAL_MODEL_PEFT_MAP, PREDEF_CWE_MODEL

APP_TITLE = "PATCHOULI"

# HTML banner shown at the top of the page. The accent-colored fragments
# (PATCH, O, U, l, i) spell out the PATCHOULI acronym.
STYLE_APP_TITLE = (
    '<div style="text-align: center; font-weight: bold; font-family: Arial, sans-serif; font-size: 44px;">'
    '<span style="color: #14e166">PATCH</span> '
    '<span style="color: #14e166">O</span>bserving '
    'and '
    '<span style="color: #14e166">U</span>ntang<span style="color: #14e166">l</span>ing '
    'Eng<span style="color: #14e166">i</span>ne'
    '</div>'
)

# Confidence palettes for the per-section highlights. Each holds 41 shades
# ordered from the 0.00-confidence end to the 1.00-confidence end;
# generate_color_map() picks a shade in proportion to the confidence.

# "non-vul-fixing" sections: shades of green ending in #14e166.
# Index 0 repeats index 1.
NONVUL_GRADIENT_COLORS = [
    "#d3f8d6",
    "#d3f8d6", "#d0f8d3", "#ccf7d0", "#c9f7cd", "#c6f6cb", "#c2f6c8", "#bff5c5", "#bcf5c2", "#b8f4bf", "#b5f4bc",
    "#b1f3ba", "#aef2b7", "#aaf2b4", "#a7f1b1", "#a3f1ae", "#9ff0ab", "#9cf0a9", "#98efa6", "#94efa3", "#90eea0",
    "#8ced9d", "#88ed9a", "#84ec98", "#80ec95", "#7ceb92", "#78ea8f", "#73ea8c", "#6fe989", "#6ae886", "#65e883",
    "#60e781", "#5ae67e", "#55e67b", "#4fe578", "#48e475", "#41e472", "#39e36f", "#30e26c", "#25e269", "#14e166",
]

# "vul-fixing" sections: shades of red ending in #de4a4a.
# Index 0 repeats index 1, mirroring the green palette, so every entry is a
# red shade (low-confidence vul-fixing sections must not render green).
VUL_GRADIENT_COLORS = [
    "#fdcfc9",
    "#fdcfc9", "#fdccc5", "#fcc9c2", "#fcc5bf", "#fcc2bb", "#fbbfb8", "#fbbcb4", "#fab9b1", "#fab5ad", "#f9b2aa",
    "#f8afa7", "#f8aca3", "#f7a8a0", "#f7a59c", "#f6a299", "#f59f96", "#f59c92", "#f4988f", "#f3958c", "#f29288",
    "#f18e85", "#f18b82", "#f0887f", "#ef847c", "#ee8178", "#ed7e75", "#ec7a72", "#eb776f", "#ea736c", "#e97068",
    "#e86c65", "#e76962", "#e6655f", "#e5615c", "#e45e59", "#e35a56", "#e25653", "#e05250", "#df4e4d", "#de4a4a",
]


# Root logging: INFO level with timestamped, logger-named records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Raise the "httpx" logger to WARNING, presumably to keep its per-request
# INFO messages out of the application log.
logging.getLogger("httpx").setLevel(logging.WARNING)


def generate_color_map(nonvul_colors=None, vul_colors=None):
    """Build the HighlightedText color map for every displayable section label.

    Labels have the form ``"<category>: <confidence>"`` with the confidence
    formatted to two decimals, for each of 0.00, 0.01, ..., 1.00 and for both
    categories ("non-vul-fixing" and "vul-fixing"). Each label maps to a shade
    of its category's palette, scaled so that 0.00 picks the first shade and
    1.00 picks the last.

    Args:
        nonvul_colors: Non-empty palette for "non-vul-fixing" labels; defaults
            to NONVUL_GRADIENT_COLORS.
        vul_colors: Non-empty palette for "vul-fixing" labels; defaults to
            VUL_GRADIENT_COLORS.

    Returns:
        dict mapping each label string to a color string (202 entries).
    """
    if nonvul_colors is None:
        nonvul_colors = NONVUL_GRADIENT_COLORS
    if vul_colors is None:
        vul_colors = VUL_GRADIENT_COLORS

    # Highest usable index of each palette; integer scaling keeps the index in
    # range for any palette length (for 41 shades this equals int(i * 0.4)).
    nonvul_last = len(nonvul_colors) - 1
    vul_last = len(vul_colors) - 1

    color_map = {}
    for i in range(101):
        conf = f"{i / 100:0.2f}"
        color_map[f"non-vul-fixing: {conf}"] = nonvul_colors[i * nonvul_last // 100]
        color_map[f"vul-fixing: {conf}"] = vul_colors[i * vul_last // 100]
    return color_map


def on_submit(diff_code, patch_message, cwe_model, section_model_type, progress=gr.Progress(track_tqdm=True), *model_config):
    """Classify a git diff per section and, if needed, by vulnerability type.

    Shared handler for both "Run" buttons. The trailing inputs of the clicked
    tab arrive in ``*model_config`` and are forwarded unchanged to
    ``section_infer``.

    NOTE: keep ``progress`` positional and before ``*model_config``. Gradio
    detects the ``gr.Progress`` default and inserts the tracker at that
    position in the argument list.

    Args:
        diff_code: git diff text from the code box (may be None or blank).
        patch_message: optional patch/commit message.
        cwe_model: selected vulnerability-type classifier; empty/None means
            no classifier is selected.
        section_model_type: label of the active model-type tab.
        progress: Gradio progress tracker (injected by Gradio).
        *model_config: tab-specific model settings passed to section_infer.

    Returns:
        ``(patch_category_label, section_results, cwe_cls_result)``, or three
        ``gr.skip()`` values when there is no diff to classify.

    Raises:
        gr.Error: if section or CWE inference fails; the message is shown
            in the UI.
    """
    # Nothing to classify: leave all three outputs untouched.
    if not diff_code or not diff_code.strip():
        return gr.skip(), gr.skip(), gr.skip()

    try:
        section_results = section_infer(diff_code, patch_message, section_model_type, *model_config)
    except Exception as e:
        logging.exception("Section inference failed")
        raise gr.Error(f"Error: {e}") from e

    # The whole patch is vul-fixing if any section of any file is predicted 1.
    is_vul_fixing = any(
        item["predict"] == 1
        for file_results in section_results.values()
        for item in file_results
    )
    if is_vul_fixing:
        label_text, color = "Vul-fixing patch", "#de4a4a"
    else:
        label_text, color = "Non-vul-fixing patch", "#14e166"
    patch_category_label = gr.Label(value=label_text, color=color)

    if not cwe_model:
        cwe_cls_result = "No model selected"
    elif not is_vul_fixing:
        cwe_cls_result = "No vulnerability found"
    else:
        try:
            cwe_cls_result = cwe_infer(diff_code, patch_message, cwe_model)
        except Exception as e:
            logging.exception("CWE inference failed")
            raise gr.Error(f"Error: {e}") from e

    return patch_category_label, section_results, cwe_cls_result


with gr.Blocks(title = APP_TITLE, fill_width=True) as demo:

    # Shared state: per-file section predictions (rendered by display_result)
    # and the vulnerability-type result (shown in vul_type_label).
    section_results_state = gr.State({})
    cls_results_state = gr.State({})

    title = gr.HTML(STYLE_APP_TITLE)

    with gr.Row() as main_block:

        # Left column: inputs, model selection and example loaders.
        with gr.Column(scale=1) as input_block:
            diff_codebox = gr.Code(label="Input git diff here", max_lines=10)

            with gr.Accordion("Patch message (optional)", open=False):
                message_textbox = gr.Textbox(label="Patch message", placeholder="Enter patch message here", container=False, lines=2, max_lines=5)

            cwe_model_selector = gr.Dropdown(PREDEF_CWE_MODEL, label="Select vulnerability type classifier", allow_custom_value=True)

            # One tab per section-model type: the first key of PREDEF_MODEL_MAP
            # drives the local tab, the second the online (API) tab.
            with gr.Tabs(selected=0) as model_type_tabs:
                MODEL_TYPE_NAMES = list(PREDEF_MODEL_MAP.keys())
                with gr.Tab(MODEL_TYPE_NAMES[0]) as local_llm_tab:
                    local_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[0]], label="Select model", allow_custom_value=True)
                    local_peft_selector = gr.Dropdown(LOCAL_MODEL_PEFT_MAP[local_model_selector.value], label="Select PEFT model (optional)", allow_custom_value=True)
                    local_submit_btn = gr.Button("Run", variant="primary")
                with gr.Tab(MODEL_TYPE_NAMES[1]) as online_llm_tab:
                    online_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[1]], label="Select model", allow_custom_value=True)
                    online_api_url_textbox = gr.Textbox(label="API URL")
                    online_api_key_textbox = gr.Textbox(label="API Key", placeholder="We won't store your API key", value=os.getenv("ONLINE_API_KEY"), type="password")
                    online_submit_btn = gr.Button("Run", variant="primary")

            # Label of the active model-type tab; starts on the first tab and is
            # updated by model_type_tabs.select below.
            section_model_type = gr.State(model_type_tabs.children[0].label)

            with gr.Accordion("Load examples", open=False):
                # NOTE(review): the path is relative to the working directory,
                # so the app must be launched from the directory holding backend/.
                with open("./backend/examples.json", "r", encoding="utf-8") as f:
                    examples = json.load(f)
                # Each example is written to (diff_codebox, message_textbox), so
                # presumably it is a [diff, message] pair — confirm against the JSON.
                gr.Button("Load example 1", size='sm').click(lambda : examples[0], outputs=[diff_codebox, message_textbox])
                gr.Button("Load example 2", size='sm').click(lambda : examples[1], outputs=[diff_codebox, message_textbox])
                gr.Button("Load example 3", size='sm').click(lambda : examples[2], outputs=[diff_codebox, message_textbox])

        # Middle column: per-file, per-section predictions.
        with gr.Column(scale=2) as section_result_block:
            @gr.render(inputs=section_results_state, triggers=[section_results_state.change, demo.load])
            def display_result(section_results):
                """Render one HighlightedText tab per file in ``section_results``.

                ``section_results`` maps file name -> list of items, each with
                "section" (highlighted text), "predict" (-1 error,
                0 non-vul-fixing, 1 vul-fixing) and "conf" (formatted with two
                decimals in the label).
                """
                if not section_results:
                    with gr.Tab("File tabs"):
                        gr.Markdown("No results")
                    return

                predict_names = {-1: 'error', 0: 'non-vul-fixing', 1: 'vul-fixing'}
                # Built once per render instead of once per file.
                full_color_map = generate_color_map()
                for file_name, file_results in section_results.items():
                    with gr.Tab(file_name):
                        highlighted_results = []
                        this_color_map = {}
                        for item in file_results:
                            # Unknown predict codes are shown as errors.
                            category = predict_names.get(item["predict"], 'error')
                            text_label = f"{category}: {item['conf']:0.2f}"
                            # Only vul/non-vul labels with a confidence in
                            # [0.00, 1.00] have a palette entry. Other labels
                            # (e.g. "error: ...") are left out of color_map so
                            # Gradio falls back to its own coloring, instead
                            # of raising KeyError.
                            color = full_color_map.get(text_label)
                            if color is not None:
                                this_color_map[text_label] = color
                            highlighted_results.append((item["section"], text_label))
                        gr.HighlightedText(
                            highlighted_results,
                            label="Results",
                            color_map=this_color_map
                        )

        # Right column: whole-patch verdict and vulnerability type.
        with gr.Column(scale=1) as result_block:
            patch_category_label = gr.Label(value = "No results", label = "Result of the whole patch")
            def update_vul_type_label(cls_results):
                return gr.Label(cls_results)
            # Value is computed from cls_results_state by update_vul_type_label.
            vul_type_label = gr.Label(update_vul_type_label, label = "Possible fixed vulnerability type", inputs = [cls_results_state])


    def update_model_type_state(evt: gr.SelectData):
        # evt.value is the label of the tab that was selected.
        return evt.value
    model_type_tabs.select(update_model_type_state, outputs = [section_model_type])

    def update_support_peft(base_model):
        """Refresh the PEFT dropdown for the newly selected base model.

        The base-model dropdown accepts custom values, which may be missing
        from LOCAL_MODEL_PEFT_MAP or map to an empty list; then no PEFT
        choices are offered instead of raising KeyError/IndexError.
        """
        try:
            peft_choices = LOCAL_MODEL_PEFT_MAP[base_model]
        except KeyError:
            peft_choices = []
        return gr.Dropdown(peft_choices, value = peft_choices[0] if peft_choices else None)
    local_model_selector.change(update_support_peft, inputs=[local_model_selector], outputs = [local_peft_selector])

    # Both Run buttons share on_submit; the tab-specific trailing inputs are
    # received by on_submit as *model_config.
    local_submit_btn.click(fn = on_submit,
                           inputs = [diff_codebox, message_textbox, cwe_model_selector, section_model_type, local_model_selector, local_peft_selector],
                           outputs = [patch_category_label, section_results_state, cls_results_state])
    online_submit_btn.click(fn = on_submit,
                            inputs = [diff_codebox, message_textbox, cwe_model_selector, section_model_type, online_model_selector, online_api_url_textbox, online_api_key_textbox],
                            outputs = [patch_category_label, section_results_state, cls_results_state])
    
        
# Start the Gradio server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it.
if __name__ == "__main__":
    demo.launch()