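"""Gradio app for browsing the facebook/community-alignment-dataset: view individual
conversations with their annotator demographics and metadata, inspect dataset
statistics, and run simple substring searches over selected fields."""
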
import json
import random
from typing import Any, Dict, Optional

import gradio as gr
from datasets import load_dataset


def load_community_alignment_dataset():
    """Load the Facebook Community Alignment Dataset from the Hugging Face Hub."""
    try:
        dataset = load_dataset("facebook/community-alignment-dataset")
        print(f"Dataset loaded successfully. Available splits: {list(dataset.keys())}")
        for split_name, split_data in dataset.items():
            print(f"Split '{split_name}': {len(split_data)} items")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None


# Loaded once at import time so every callback reuses the same DatasetDict.
dataset = load_community_alignment_dataset()


def format_conversation_turn(turn_data: Dict[str, Any], turn_number: int) -> str:
    """Format a conversation turn for display."""
    if not turn_data:
        return ""

    prompt = turn_data.get('prompt', '')
    responses = turn_data.get('responses', '')
    preferred = turn_data.get('preferred_response', '')

    formatted = f"**Turn {turn_number}**\n\n"
    formatted += f"**Prompt:** {prompt}\n\n"

    if responses:
        # Convert "# Response X:" headings in the raw field into bold Markdown.
        formatted += "**Responses:**\n"
        formatted += responses.replace('# Response ', '**Response ').replace(':\n', ':**\n')
        formatted += "\n"

    if preferred:
        formatted += f"**Preferred Response:** {preferred.upper()}\n"

    return formatted


def get_conversation_data(conversation_id: int) -> Optional[Dict[str, Any]]:
    """Find a conversation by its conversation_id, scanning every split."""
    if not dataset:
        return None

    try:
        # Linear scan: conversation IDs are not indexed, so check each item in turn.
        for split in dataset.keys():
            for item in dataset[split]:
                if item.get('conversation_id') == conversation_id:
                    return item
        return None
    except Exception as e:
        print(f"Error getting conversation data: {e}")
        return None


def format_annotator_info(item: Dict[str, Any]) -> str:
    """Format annotator information"""
    info = "**Annotator Information:**\n\n"

    demographics = [
        ('Age', 'annotator_age'),
        ('Gender', 'annotator_gender'),
        ('Education', 'annotator_education_level'),
        ('Political', 'annotator_political'),
        ('Ethnicity', 'annotator_ethnicity'),
        ('Country', 'annotator_country'),
    ]

    for label, key in demographics:
        value = item.get(key, 'N/A')
        if value and value != 'None':
            info += f"**{label}:** {value}\n"

    return info


def display_conversation(conversation_id: int) -> tuple:
    """Build the conversation, annotator, metadata, and raw JSON views for a conversation ID."""
    if not dataset:
        return "Dataset not loaded", "", "", ""

    if conversation_id is None:
        return "Please enter a conversation ID", "", "", ""

    # Gradio's Number component passes a float; conversation IDs are integers.
    item = get_conversation_data(int(conversation_id))
    if not item:
        return f"Conversation ID {conversation_id} not found", "", "", ""

    # Up to four turns, stored in flat "<ordinal>_turn_*" columns.
    conversation_text = ""
    for turn_number, ordinal in enumerate(['first', 'second', 'third', 'fourth'], start=1):
        if item.get(f'{ordinal}_turn_prompt'):
            turn = {
                'prompt': item.get(f'{ordinal}_turn_prompt'),
                'responses': item.get(f'{ordinal}_turn_responses'),
                'preferred_response': item.get(f'{ordinal}_turn_preferred_response'),
            }
            conversation_text += format_conversation_turn(turn, turn_number) + "\n"

    annotator_info = format_annotator_info(item)

    metadata = "**Metadata:**\n\n"
    metadata += f"**Conversation ID:** {item.get('conversation_id', 'N/A')}\n"
    metadata += f"**Language:** {item.get('assigned_lang', 'N/A')}\n"
    metadata += f"**Annotator ID:** {item.get('annotator_id', 'N/A')}\n"
    metadata += f"**In Balanced Subset:** {item.get('in_balanced_subset', 'N/A')}\n"
    metadata += f"**In Balanced Subset 10:** {item.get('in_balanced_subset_10', 'N/A')}\n"
    metadata += f"**Is Pregenerated First Prompt:** {item.get('is_pregenerated_first_prompt', 'N/A')}\n"

    # default=str guards against any field value that is not directly JSON-serializable.
    raw_json = json.dumps(item, indent=2, default=str)

    return conversation_text, annotator_info, metadata, raw_json


def get_random_conversation() -> int:
    """Get a random conversation ID"""
    if not dataset:
        return 0

    try:
        split = random.choice(list(dataset.keys()))
        split_data = dataset[split]

        random_index = random.randint(0, len(split_data) - 1)
        item = split_data[random_index]

        return item.get('conversation_id', 0)
    except Exception as e:
        print(f"Error getting random conversation: {e}")

    # Hard-coded fallback conversation ID used if random selection fails.
    return 1061830552573006


def get_dataset_stats() -> str:
    """Get dataset statistics"""
    if not dataset:
        return "Dataset not loaded"

    stats = "**Dataset Statistics:**\n\n"

    for split_name, split_data in dataset.items():
        stats += f"**{split_name}:** {len(split_data)} conversations\n"

    if 'train' in dataset and len(dataset['train']) > 0:
        sample_item = dataset['train'][0]
        stats += "\n**Sample Fields:**\n"
        for key in list(sample_item.keys())[:10]:
            stats += f"- {key}\n"

    return stats


def search_conversations(query: str, field: str) -> str:
    """Search conversations by a substring match on the selected field."""
    if not dataset or not query:
        return "Please provide a search query"

    results = []
    query_lower = query.lower()

    try:
        for split_name, split_data in dataset.items():
            # Only scan the first 100 items of each split to keep the search responsive.
            for i in range(min(100, len(split_data))):
                item = split_data[i]
                if field in item and item[field]:
                    value_str = str(item[field])
                    if query_lower in value_str.lower():
                        results.append({
                            'conversation_id': item.get('conversation_id'),
                            'split': split_name,
                            'field_value': value_str[:100] + "..." if len(value_str) > 100 else value_str,
                        })
    except Exception as e:
        return f"Error during search: {e}"

    if not results:
        return f"No results found for '{query}' in field '{field}'"

    # Show at most the first 10 matches.
    result_text = f"**Search Results for '{query}' in '{field}':**\n\n"
    for i, result in enumerate(results[:10]):
        result_text += f"{i+1}. **Conversation ID:** {result['conversation_id']} (Split: {result['split']})\n"
        result_text += f"   **Value:** {result['field_value']}\n\n"

    return result_text


def create_interface():
    """Build the Gradio Blocks interface."""
    with gr.Blocks(title="Facebook Community Alignment Dataset Viewer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Facebook Community Alignment Dataset Viewer")
        gr.Markdown("Explore conversations, responses, and annotations from the Facebook Community Alignment Dataset.")

        with gr.Tabs():

            with gr.Tab("Conversation Viewer"):
                with gr.Row():
                    with gr.Column(scale=1):
                        conversation_id_input = gr.Number(
                            label="Conversation ID",
                            value=get_random_conversation(),
                            interactive=True
                        )
                        random_btn = gr.Button("Random Conversation", variant="secondary")
                        load_btn = gr.Button("Load Conversation", variant="primary")

                    with gr.Column(scale=3):
                        conversation_display = gr.Markdown(label="Conversation")
                        annotator_display = gr.Markdown(label="Annotator Information")
                        metadata_display = gr.Markdown(label="Metadata")
                        raw_json_display = gr.Code(label="Raw JSON", language="json")

                random_btn.click(
                    fn=get_random_conversation,
                    outputs=conversation_id_input
                )

                load_btn.click(
                    fn=display_conversation,
                    inputs=conversation_id_input,
                    outputs=[conversation_display, annotator_display, metadata_display, raw_json_display]
                )

                conversation_id_input.submit(
                    fn=display_conversation,
                    inputs=conversation_id_input,
                    outputs=[conversation_display, annotator_display, metadata_display, raw_json_display]
                )

            with gr.Tab("Dataset Statistics"):
                stats_btn = gr.Button("Load Statistics", variant="primary")
                stats_display = gr.Markdown(label="Dataset Statistics")

                stats_btn.click(
                    fn=get_dataset_stats,
                    outputs=stats_display
                )

            with gr.Tab("Search Conversations"):
                with gr.Row():
                    with gr.Column(scale=1):
                        search_query = gr.Textbox(
                            label="Search Query",
                            placeholder="Enter search term...",
                            interactive=True
                        )
                        search_field = gr.Dropdown(
                            label="Search Field",
                            choices=[
                                "first_turn_prompt",
                                "second_turn_prompt",
                                "third_turn_prompt",
                                "annotator_country",
                                "annotator_age",
                                "annotator_gender",
                                "assigned_lang"
                            ],
                            value="first_turn_prompt",
                            interactive=True
                        )
                        search_btn = gr.Button("Search", variant="primary")

                    with gr.Column(scale=2):
                        search_results = gr.Markdown(label="Search Results")

                search_btn.click(
                    fn=search_conversations,
                    inputs=[search_query, search_field],
                    outputs=search_results
                )

                search_query.submit(
                    fn=search_conversations,
                    inputs=[search_query, search_field],
                    outputs=search_results
                )

            with gr.Tab("About"):
                gr.Markdown("""
## About the Facebook Community Alignment Dataset

This dataset contains conversations with multiple response options and human annotations indicating which responses are preferred by different demographic groups.

### Dataset Structure:
- **Conversations**: Multi-turn dialogues with prompts and multiple response options
- **Annotations**: Human preferences for different response options
- **Demographics**: Annotator information including age, gender, education, political views, ethnicity, and country

### Key Features:
- Multi-turn conversations (up to 4 turns)
- 4 response options per turn (A, B, C, D)
- Human preference annotations
- Diverse annotator demographics
- Balanced subsets for analysis

### Use Cases:
- Studying response preferences across demographics
- Training models to generate community-aligned responses
- Analyzing conversation dynamics
- Understanding cultural and demographic differences in communication preferences

### Citation:
If you use this dataset, please cite the original Facebook research paper.
""")

        # Render an initial conversation when the app first loads.
        demo.load(
            fn=display_conversation,
            inputs=conversation_id_input,
            outputs=[conversation_display, annotator_display, metadata_display, raw_json_display]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )