import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')

import plot_utils
from constants import *


class MatplotlibDataPlotter:
    """Bar-plot the best-scoring profile per region for single domains and domain pairs.

    Two figures are created once in ``__init__``. Each ``plot_*`` call clears
    and redraws its figure, then returns it.

    Expected columns (as read by the methods below):

    * ``single_df`` / ``pair_df``: ``cds_region_id``, ``biosyn_class``,
      ``biosyn_class_index``, ``profile_name`` and one or more
      ``cosine_similarity_<split_name>`` score columns.
    * ``num_domains_in_region_df``: ``cds_region_id``, ``num_domains``.
    """

    # Bar-plot layout parameters forwarded to plot_utils.draw_barplots.
    _TOP_N = 5
    _BIN_WIDTH = 1
    _HUE_GROUP_OFFSET = 0.5
    _BAR_WIDTH = 0.9

    def __init__(self, single_df, pair_df, num_domains_in_region_df):
        self.single_df = single_df
        self.pair_df = pair_df

        self.num_domains_in_region_df = num_domains_in_region_df

        # Allocated once and reused (cleared) on every plot call.
        self.single_domains_fig = plt.figure(figsize=(5, 10))
        self.pair_domains_fig = plt.figure(figsize=(5, 10))

    def plot_single_domains(self, num_domains, split_name="stratified"):
        """Redraw and return the single-domain figure.

        Args:
            num_domains: keep only regions with at least this many domains.
            split_name: suffix selecting the ``cosine_similarity_<split_name>``
                score column used to pick each region's best hit.

        Returns:
            ``self.single_domains_fig`` after redrawing.
        """
        return self._plot_domains(
            self.single_df, self.single_domains_fig, num_domains, split_name
        )

    def plot_pair_domains(self, num_domains, split_name="stratified"):
        """Redraw and return the domain-pair figure.

        Args:
            num_domains: keep only regions with at least this many domains.
            split_name: suffix selecting the ``cosine_similarity_<split_name>``
                score column used to pick each region's best hit.

        Returns:
            ``self.pair_domains_fig`` after redrawing.
        """
        return self._plot_domains(
            self.pair_df, self.pair_domains_fig, num_domains, split_name
        )

    def _plot_domains(self, df, fig, num_domains, split_name, show_legend=False):
        """Shared pipeline: filter regions, pick best hits, redraw *fig*, return it."""
        df_subset = self._select_regions(df, num_domains)
        hue2count = self._count_regions_per_class(df_subset)
        targets_list, label_list = self._select_top_hits(df_subset, split_name)

        fig.clf()
        ax = fig.gca()
        plot_utils.draw_barplots(
            targets_list,
            label_list=label_list,
            top_n=self._TOP_N,
            bin_width=self._BIN_WIDTH,
            hue_group_offset=self._HUE_GROUP_OFFSET,
            hue_order=BIOSYN_CLASS_NAMES,
            hue2count=hue2count,
            width=self._BAR_WIDTH,
            ax=ax,
            show_legend=show_legend,
            palette=COLOR_PALETTE,
        )
        fig.tight_layout()
        return fig

    def _select_regions(self, df, num_domains):
        """Return the rows of *df* whose region has at least *num_domains* domains."""
        region_counts = self.num_domains_in_region_df
        selected_region_ids = region_counts.loc[
            region_counts.num_domains >= num_domains, 'cds_region_id'
        ].values
        return df.loc[df.cds_region_id.isin(selected_region_ids)]

    @staticmethod
    def _count_regions_per_class(df_subset):
        """Map each ``biosyn_class`` to the number of distinct regions it appears in."""
        return (
            df_subset.groupby('biosyn_class')['cds_region_id'].nunique().to_dict()
        )

    @staticmethod
    def _select_top_hits(df_subset, split_name):
        """Return (class indices, profile names) of the best-scoring row per region.

        Regions are returned in sorted ``cds_region_id`` order (groupby order).
        NOTE(review): ``idxmax`` yields index labels, so this assumes
        *df_subset* has a unique index; duplicate labels would make ``.loc``
        return extra rows.
        """
        column_name = f'cosine_similarity_{split_name}'
        # A NaN score can never be the best hit. Dropping those rows also stops
        # a region whose scores are all NaN from breaking idxmax / .loc.
        scored = df_subset.dropna(subset=[column_name])
        best_index = scored.groupby('cds_region_id')[column_name].idxmax().values
        targets_list = scored.loc[best_index, 'biosyn_class_index'].values
        label_list = scored.loc[best_index, 'profile_name'].values
        return targets_list, label_list