import cv2
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
import numpy as np
import math

# visualization libraries
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import style

def draw_eyes_on_image(rgb_image, detection_result):
  """Draw per-eye vertical/horizontal lines on the image and return it
  together with the left and right eye-openness ratios (in percent)."""
  
  # Landmark indices per eye, taken from canonical_face_model_uv_visualization:
  # https://github.com/google/mediapipe/blob/a908d668c730da128dfa8d9f6bd25d519d006692/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualization.png
  left_eyes_bottom_list = [33, 7, 163, 144, 145, 153, 154, 155, 133]
  left_eyes_top_list = [246, 161, 160, 159, 158, 157, 173]
  right_eyes_bottom_list = [362, 382, 381, 380, 374, 373, 390, 249, 263]
  right_eyes_top_list = [398, 384, 385, 386, 387, 388, 466]
  
  face_landmarks_list = detection_result.face_landmarks
  annotated_image = np.copy(rgb_image)

  # Defaults, so the return values are defined even when no face is detected.
  leftRatio, rightRatio = 0, 0
  
  # Frames are expected to be resized to 640x360 upstream; read the actual size here.
  height, width, channels = rgb_image.shape
  
  # Loop through the detected faces to visualize. If more than two faces are
  # detected, we ask the user to move closer to the camera.
  for idx in range(len(face_landmarks_list)):
    face_landmarks = face_landmarks_list[idx]

    # Convert normalized landmark coordinates to pixel space (z stays normalized).
    narray = np.array([[int(landmark.x * width), int(landmark.y * height), landmark.z]
                       for landmark in face_landmarks])
    
    # Vertical line
    #
    #
    # Pick the eyelid landmarks with the largest vertical gap (the middle of each eye)
    leftUp = narray[159]     # upper eyelid, left eye
    leftDown = narray[145]   # lower eyelid, left eye
    rightUp = narray[386]    # upper eyelid, right eye
    rightDown = narray[374]  # lower eyelid, right eye
    
    # compute left eye distance (vertical)
    leftUp_x = int(leftUp[0])
    leftUp_y = int(leftUp[1])
    leftDown_x = int(leftDown[0])
    leftDown_y = int(leftDown[1])
    leftVerDis = math.dist([leftUp_x, leftUp_y],[leftDown_x, leftDown_y])
    
    # compute right eye distance (vertical)
    rightUp_x = int(rightUp[0])
    rightUp_y = int(rightUp[1])
    rightDown_x = int(rightDown[0])
    rightDown_y = int(rightDown[1])
    rightVerDis = math.dist([rightUp_x, rightUp_y],[rightDown_x, rightDown_y])
    
    # print(f'leftVerDis: {leftVerDis} and rightVerDis: {rightVerDis}')
    
    # draw a vertical line across the left eye, from top to bottom
    annotated_image = cv2.line(annotated_image, (leftUp_x, leftUp_y), (leftDown_x, leftDown_y), (0, 200, 0), 1)
    
    # draw a vertical line across the right eye, from top to bottom
    annotated_image = cv2.line(annotated_image, (rightUp_x, rightUp_y), (rightDown_x, rightDown_y), (0, 200, 0), 1)
    #
    #
    # Horizontal line
    #
    #
    # Pick the eye-corner landmarks (the widest horizontal span of each eye)
    leftLeft = narray[33]     # outer corner, left eye
    leftRight = narray[133]   # inner corner, left eye
    rightLeft = narray[362]   # inner corner, right eye
    rightRight = narray[263]  # outer corner, right eye
    
    # compute left eye distance (horizontal)
    leftLeft_x = int(leftLeft[0])
    leftLeft_y = int(leftLeft[1])
    leftRight_x = int(leftRight[0])
    leftRight_y = int(leftRight[1])
    leftHorDis = math.dist([leftLeft_x, leftLeft_y],[leftRight_x, leftRight_y])
    
    # compute right eye distance (horizontal)
    rightLeft_x = int(rightLeft[0])
    rightLeft_y = int(rightLeft[1])
    rightRight_x = int(rightRight[0])
    rightRight_y = int(rightRight[1])
    rightHorDis = math.dist([rightLeft_x, rightLeft_y],[rightRight_x, rightRight_y])
    
    # print(f'leftHorDis: {leftHorDis} and rightHorDis: {rightHorDis}')
    
    # draw a horizontal line across the left eye, from corner to corner
    annotated_image = cv2.line(annotated_image, (leftLeft_x, leftLeft_y), (leftRight_x, leftRight_y), (0, 200, 0), 1)
    
    # draw a horizontal line across the right eye, from corner to corner
    annotated_image = cv2.line(annotated_image, (rightLeft_x, rightLeft_y), (rightRight_x, rightRight_y), (0, 200, 0), 1)
    #
    #
    #
    #
    # print(f'leftRatio: {leftVerDis/leftHorDis} and rightRatio: {rightVerDis/rightHorDis}')
    
    # Eye-openness ratio as a percentage (vertical distance / horizontal distance).
    leftRatio = leftVerDis / leftHorDis * 100
    rightRatio = rightVerDis / rightHorDis * 100
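    # Note: these behave like a scaled eye aspect ratio (EAR); a blink shows up
    # as a sharp, brief dip in both values.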

    
    # Debug visualization: uncomment to draw the individual eye-contour points.
    # for idx_list in (left_eyes_bottom_list, left_eyes_top_list,
    #                  right_eyes_bottom_list, right_eyes_top_list):
    #   for i in idx_list:
    #     p = narray[i]
    #     annotated_image = cv2.circle(annotated_image, (int(p[0]), int(p[1])),
    #                                  radius=1, color=(0, 0, 255), thickness=1)
      
    
  return annotated_image, leftRatio, rightRatio
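
# A minimal sketch (an assumption, not part of the original pipeline) of how the
# ratios returned above could drive blink detection. The threshold of 20 is a
# guess; tune it per camera and subject.
def is_blinking(left_ratio, right_ratio, threshold=20):
  """Return True when both eye-openness ratios fall below the threshold."""
  return left_ratio < threshold and right_ratio < threshold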

def draw_landmarks_on_image(rgb_image, detection_result):
  face_landmarks_list = detection_result.face_landmarks
  annotated_image = np.copy(rgb_image)

  # Loop through the detected faces to visualize. If more than two faces are
  # detected, we ask the user to move closer to the camera.
  for idx in range(len(face_landmarks_list)):
    face_landmarks = face_landmarks_list[idx]

    # Draw the face landmarks.
    face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
    face_landmarks_proto.landmark.extend([
      landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in face_landmarks
    ])
    
    solutions.drawing_utils.draw_landmarks(
        image=annotated_image,
        landmark_list=face_landmarks_proto,
        connections=mp.solutions.face_mesh.FACEMESH_TESSELATION,
        landmark_drawing_spec=None,
        connection_drawing_spec=mp.solutions.drawing_styles
        .get_default_face_mesh_tesselation_style())
    solutions.drawing_utils.draw_landmarks(
        image=annotated_image,
        landmark_list=face_landmarks_proto,
        connections=mp.solutions.face_mesh.FACEMESH_CONTOURS,
        landmark_drawing_spec=None,
        connection_drawing_spec=mp.solutions.drawing_styles
        .get_default_face_mesh_contours_style())
    solutions.drawing_utils.draw_landmarks(
        image=annotated_image,
        landmark_list=face_landmarks_proto,
        connections=mp.solutions.face_mesh.FACEMESH_IRISES,
        landmark_drawing_spec=None,
        connection_drawing_spec=mp.solutions.drawing_styles
        .get_default_face_mesh_iris_connections_style())

  return annotated_image
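
# A minimal end-to-end sketch of how these helpers can be driven by the
# MediaPipe Tasks FaceLandmarker. The model and image filenames below are
# assumptions; use the model bundle that matches your MediaPipe version.
def run_face_landmarker_example(model_path='face_landmarker_v2_with_blendshapes.task',
                                image_path='face.png'):
  """Detect landmarks on a single image and visualize with the helpers above."""
  from mediapipe.tasks import python as mp_tasks
  from mediapipe.tasks.python import vision as mp_vision

  options = mp_vision.FaceLandmarkerOptions(
      base_options=mp_tasks.BaseOptions(model_asset_path=model_path),
      output_face_blendshapes=True,
      num_faces=1)
  detector = mp_vision.FaceLandmarker.create_from_options(options)

  image = mp.Image.create_from_file(image_path)
  detection_result = detector.detect(image)

  # numpy_view() is read-only; both helpers copy the frame before drawing.
  annotated, left_ratio, right_ratio = draw_eyes_on_image(image.numpy_view(), detection_result)
  annotated = draw_landmarks_on_image(annotated, detection_result)
  cv2.imshow('face', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
  cv2.waitKey(0)
  return left_ratio, right_ratio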

def plot_face_blendshapes_bar_graph(face_blendshapes):
  """Plot a horizontal bar chart of the blendshape scores for one face."""
  # Extract the face blendshapes category names and scores.
  face_blendshapes_names = [face_blendshapes_category.category_name for face_blendshapes_category in face_blendshapes]
  face_blendshapes_scores = [face_blendshapes_category.score for face_blendshapes_category in face_blendshapes]
  # The blendshapes are ordered in decreasing score value.
  face_blendshapes_ranks = range(len(face_blendshapes_names))

  fig, ax = plt.subplots(figsize=(12, 12))
  bar = ax.barh(face_blendshapes_ranks, face_blendshapes_scores, label=[str(x) for x in face_blendshapes_ranks])
  ax.set_yticks(face_blendshapes_ranks, face_blendshapes_names)
  ax.invert_yaxis()

  # Label each bar with values
  for score, patch in zip(face_blendshapes_scores, bar.patches):
    plt.text(patch.get_x() + patch.get_width(), patch.get_y(), f"{score:.4f}", va="top")

  ax.set_xlabel('Score')
  ax.set_title("Face Blendshapes")
  plt.tight_layout()
  plt.show()
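
# Example call, assuming detection was run with output_face_blendshapes=True:
#   plot_face_blendshapes_bar_graph(detection_result.face_blendshapes[0])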