#coding=utf-8

import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
from sklearn.metrics.pairwise import cosine_similarity
import random
from matplotlib.font_manager import FontProperties
from sklearn.cluster import KMeans
from sklearn import metrics
from collections import Counter
from scipy.cluster.hierarchy import ward, dendrogram
import bottom_function.normalization as normalization
import bottom_function.m_SQL as qb

# import json
# from flask import Flask
# from flask import request
# from flask_cors import CORS


class Culter:

    def __init__(self, start_time, end_time):
        self.start_time = start_time
        self.end_time = end_time

        csv_data = pd.DataFrame()
        self.chat_data = pd.DataFrame()
        tablename = "semantic_data_table"
        db = qb.Schema(host="localhost", user="560193", password="jay560193", mysqlName="semantic_data_schema",
                       port="3306")
        csv_data = db.getData(tableName=tablename, startTime=start_time, endTime=end_time)
        self.chat_data = csv_data[(csv_data['domain'] == 'chat')]

        # self.chat_data.drop_duplicates(subset=['query'], inplace=True)
        # self.chat_data.dropna(subset=['query'], inplace=True)

        self.out_data = ''  # 写入分析结果
        self.feature_names = []
        self.f_sse = []
        self.feature_matrix = np.matrix([])

    def build_feature_matrix(self, documents, feature_type, ngram_range, min_df, max_df):

        feature_type = feature_type.lower().strip()
        if feature_type == 'binary':
            vectorizer = CountVectorizer(binary=True,
                                         max_df=max_df, ngram_range=ngram_range)
        elif feature_type == 'frequency':
            vectorizer = CountVectorizer(binary=False, min_df=min_df,
                                         max_df=max_df, ngram_range=ngram_range)
        elif feature_type == 'tfidf':
            vectorizer = TfidfVectorizer(token_pattern=r"(?u)\b\w+\b", max_df=max_df)
        else:
            raise Exception("Wrong feature type entered. Possible values: 'binary', 'frequency', 'tfidf'")

        feature_matrix = vectorizer.fit_transform(documents).astype(float)

        return vectorizer, feature_matrix

    def feature_extraction_data(self):

        chat_one = self.chat_data['query'].tolist()

        norm_chat_one = normalization.normalize_corpus(chat_one, pos=False)

        # 提取 tf-idf 特征
        vectorizer, self.feature_matrix = self.build_feature_matrix(norm_chat_one, feature_type='tfidf', min_df=0.2,
                                                                    max_df=0.90,
                                                                    ngram_range=(1, 2))

        # 查看特征数量)
        self.out_data = '聚类分析结果:\n' + '**' * 30
        self.out_data = self.out_data + '\n特征数量:\n' + str(self.feature_matrix.shape)

        # 获取特征名字
        self.feature_names = vectorizer.get_feature_names()

        # 打印某些特征
        self.out_data = self.out_data + '\n部分特征:\n' + ', '.join(self.feature_names[:5])

    def get_cluster_data(self, clustering_obj, m_data, feature_names, num_clusters, topn_features):
        cluster_data = {}

        # 获取cluster的center
        ordered_centroids = clustering_obj.cluster_centers_.argsort()[:, ::-1]
        # 获取每个cluster的关键特征
        # 获取每个cluster的query
        for cluster_num in range(num_clusters):
            cluster_data[cluster_num] = {}
            cluster_data[cluster_num]['cluster_num'] = cluster_num
            key_features = [feature_names[index]
                            for index
                            in ordered_centroids[cluster_num, :topn_features]]
            cluster_data[cluster_num]['key_features'] = key_features

            c_datas = m_data[m_data['Cluster'] == cluster_num]['query'].values.tolist()
            cluster_data[cluster_num]['query'] = c_datas

        return cluster_data

    def print_cluster_data(self, cluster_data):
        self.out_data = self.out_data + '\n\n聚类详细数据:\n'

        for cluster_num, cluster_details in cluster_data.items():
            self.out_data = self.out_data + '\nCluster {} details:\n'.format(cluster_num)

            self.out_data = self.out_data + '-' * 20
            self.out_data = self.out_data + '\nKey features:\n'
            self.out_data = self.out_data + ', '.join(cluster_details['key_features'])

            self.out_data = self.out_data + '\ndata in this cluster:\n'
            self.out_data = self.out_data + ', '.join(cluster_details['query'])
            self.out_data = self.out_data + '\n' + '=' * 40

    def plot_clusters(self, feature_matrix, cluster_data, m_data, plot_size):
        def generate_random_color():  # generate random color for clusters
            color = '#%06x' % random.randint(0, 0xFFFFFF)
            return color

        # define markers for clusters
        markers = ['o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'H', 'D', 'd']
        # build cosine distance matrix
        cosine_distance = 1 - cosine_similarity(feature_matrix)
        # dimensionality reduction using MDS
        mds = MDS(n_components=2, dissimilarity="precomputed",
                  random_state=1)
        # get coordinates of clusters in new low-dimensional space
        plot_positions = mds.fit_transform(cosine_distance)
        x_pos, y_pos = plot_positions[:, 0], plot_positions[:, 1]
        # build cluster plotting data
        cluster_color_map = {}
        cluster_name_map = {}
        # print(cluster_data)
        for cluster_num, cluster_details in cluster_data.items():
            # assign cluster features to unique label
            cluster_color_map[cluster_num] = generate_random_color()
            cluster_name_map[cluster_num] = ', '.join(cluster_details['key_features'][:5]).strip()
        # map each unique cluster label with its coordinates and books
        cluster_plot_frame = pd.DataFrame({'x': x_pos,
                                           'y': y_pos,
                                           'label': m_data['Cluster'].values.tolist(),
                                           'query': m_data['query'].values.tolist()
                                           })
        grouped_plot_frame = cluster_plot_frame.groupby('label')
        # set plot figure size and axes

        plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
        matplotlib.rcParams['font.family'] = 'sans-serif'
        plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
        fig, ax = plt.subplots(figsize=plot_size)
        ax.margins(0.05)
        # plot each cluster using co-ordinates and  titles
        for cluster_num, cluster_frame in grouped_plot_frame:
            marker = markers[cluster_num] if cluster_num < len(markers) \
                else np.random.choice(markers, size=1)[0]
            ax.plot(cluster_frame['x'], cluster_frame['y'],
                    marker=marker, linestyle='', ms=12,
                    label=cluster_name_map[cluster_num],
                    color=cluster_color_map[cluster_num], mec='none')
            ax.set_aspect('auto')
            ax.tick_params(axis='x', which='both', bottom=False, top=False,
                           labelbottom='off')
            ax.tick_params(axis='y', which='both', left=False, top=False,
                           labelleft=False)
        fontP = FontProperties()
        fontP.set_size(23)
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.01), fancybox=True,
                  shadow=True, ncol=5, numpoints=1, prop=fontP)
        # add labels as the film titles
        for index in range(len(cluster_plot_frame)):
            ax.text(cluster_plot_frame.ix[index]['x'], cluster_plot_frame.ix[index]['y'],
                    cluster_plot_frame.ix[index]['query'], size=20)
            # show the plot
        plt.title(self.start_time + ' to ' + self.end_time + 'chat data cluster point set', fontsize=25)
        path = '/roobo/soft/phpmyadmin/cluster_point.jpg'
        plt.savefig(path)
        return 'http://120.79.171.145:8000/cluster_point.jpg'

    def k_means(self, feature_matrix):
        f_sse = []
        num_clusters = []
        for i in range(2, 21):
            km = KMeans(n_clusters=i, max_iter=10000)
            km.fit(feature_matrix)
            clusters = km.labels_
            num_matrix = feature_matrix.todense()
            sse = metrics.calinski_harabaz_score(num_matrix, clusters)
            num_clusters.append(i)
            f_sse.append(sse)

        pd_see = pd.Series(f_sse, index=num_clusters)
        pct_see = pd_see.pct_change()

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.plot(num_clusters, f_sse, 'o-', c='orangered', label='clustering quality')
        plt.legend(loc=2)
        plt.xticks(num_clusters)
        ax.set_xlabel("cluster number")

        ax.set_ylabel("coefficient")

        ax1 = ax.twinx()
        ax1.plot(pct_see.values, 'o-', c='blue', label='gradient change')
        ax1.set_ylabel("gradient")
        plt.legend(loc=1)

        plt.title(self.start_time + " to " + self.end_time + " the analysis of clusters with different numbers", fontsize=12)
        path = '/roobo/soft/phpmyadmin/choice_num.jpg'
        plt.savefig(path)

        # input_num = input('输入最优聚类数目:')
        # best_num = int(input_num)

        self.f_sse = f_sse
        return 'http://120.79.171.145:8000/choice_num.jpg'

    def k_means_cluster(self, best_num):

        self.out_data = self.out_data + '\n' + "=" * 20
        self.out_data = self.out_data + "\n\n聚类效果分析:\n"
        self.out_data = self.out_data + "\n聚类数目为:" + str(best_num)

        f_sse = self.f_sse
        sse = f_sse[best_num]
        km = KMeans(n_clusters=best_num, max_iter=10000)
        km.fit(self.feature_matrix)
        clusters = km.labels_
        self.chat_data['Cluster'] = clusters

        # 获取每个cluster的数量
        c = Counter(clusters)

        sort_c = sorted(c.items(), key=lambda c: c[0], reverse=False)
        c.clear()
        for key, value in sort_c:
            c[key] = value

        self.out_data = self.out_data + '\nCalinski-Harabasz分数:' + str(sse)
        self.out_data = self.out_data + '\n每个特征的数据量:\n'
        self.out_data = self.out_data + (str(c.items()))
        self.out_data = self.out_data + '\n' + "=" * 20
        cluster_data = self.get_cluster_data(clustering_obj=km,
                                             m_data=self.chat_data,
                                             feature_names=self.feature_names,
                                             num_clusters=best_num,
                                             topn_features=5)

        self.print_cluster_data(cluster_data)

        path = self.plot_clusters(feature_matrix=self.feature_matrix, cluster_data=cluster_data, m_data=self.chat_data,
                                  plot_size=(40, 25))
        return path

    def ward_hierarchical_clustering(self, feature_matrix):
        cosine_distance = 1 - cosine_similarity(feature_matrix)
        linkage_matrix = ward(cosine_distance)
        return linkage_matrix

    def plot_hierarchical_clusters(self, linkage_matrix, m_data, figure_size):
        # set size
        fig, ax = plt.subplots(figsize=figure_size)
        m_titles = m_data['query'].values.tolist()

        # plot dendrogram
        ax = dendrogram(linkage_matrix, orientation="left", labels=m_titles)
        plt.tick_params(axis='x',
                        which='both',
                        bottom=False,
                        top=False,
                        labelbottom=False)
        plt.tight_layout()
        plt.title(self.start_time + ' to ' + self.end_time + 'chat data ward hierachical clusters',fontsize=12)
        path = '/roobo/soft/phpmyadmin/hierachical_clusters.jpg'
        plt.savefig(path)
        return 'http://120.79.171.145:8000/hierachical_clusters.jpg'


# app = Flask(__name__)
# CORS(app, supports_credentials=True)
#
# data_cluster = Culter(start_time="2018-12-01 00:00:00", end_time="2018-12-02 00:00:00")
#
#
# @app.route('/SPDAS/chat_function_analysis/choice1', methods=['POST'])
# def choice():
#     param = ({"time": "2018-12-01 00:00:00/2018-12-02 00:00:00"})
#     return json.JSONEncoder().encode(param)
#
#
# @app.route('/SPDAS/chat_function_analysis/choice2', methods=['POST'])
# def choice_form():
#     # 需要从request对象读取表单内容:
#     data = request.get_data()
#     json_re = json.loads(data)
#
#     m_time = json_re['time']
#     str_time = str(m_time)
#     m_time = str_time.split('/')
#     starttime = m_time[0]
#     endtime = m_time[1]
#     data_cluster = Culter(start_time=starttime, end_time=endtime)
#     data_cluster.feature_extraction_data()
#     image_path = data_cluster.k_means(data_cluster.feature_matrix)
#     path = ({"num_image": image_path})
#     return json.JSONEncoder().encode(path)
#
#
# @app.route('/SPDAS/chat_function_analysis/chat1', methods=['POST'])
# def chat():
#     param = ({"best_num": "2"})
#     return json.JSONEncoder().encode(param)
#
#
# @app.route('/SPDAS/chat_function_analysis/chat2', methods=['POST'])
# def chat_form():
#     # 需要从request对象读取表单内容:
#     data = request.get_data()
#     json_re = json.loads(data)
#     bestnum = json_re['best_num']
#     image_path1 = data_cluster.k_means_cluster(best_num=bestnum)
#
#     linkage_matrix = data_cluster.ward_hierarchical_clustering(data_cluster.feature_matrix)
#
#     image_path2 = data_cluster.plot_hierarchical_clusters(linkage_matrix=linkage_matrix, m_data=data_cluster.chat_data,
#                                                           figure_size=(35, 40))
#     with open("/roobo/soft/phpmyadmin/chat_function_data.txt", 'w') as file:
#         file.writelines(data_cluster.out_data)
#
#     path = ({"cluster_point": image_path1, "ward_image": image_path2})
#     return json.JSONEncoder().encode(path)
#
#
# if __name__ == '__main__':
#     app.run(debug=True, host='10.7.19.129', port=5000)