import { useCallback, useEffect, useRef, useState } from "react";
import Webcam from "react-webcam";
import * as tf from "@tensorflow/tfjs";
import { Measurements, MeasurementsResults } from "./Measurements";
import { Capturing, CaptureImages } from "./Capturing";
import usageTracker from "./UsageTracking/UsageTracker";
import { UsageTrackerEvent } from "./UsageTracking/UsageTracker";
import { css } from '@emotion/css';

// Absolutely positioned layer pinned to both horizontal edges (left/right 0,
// auto margins) and stacked behind other content via a negative z-index;
// it hosts the camera preview and the hidden capture canvas.
const containerClass = css({
  position: 'absolute',
  marginLeft: 'auto',
  marginRight: 'auto',
  left: '0',
  right: '0',
  zIndex: '-1'
});

// Camera preview filling the viewport height; object-fit: cover crops the
// video to the box instead of letterboxing it.
const webCamClass = css({
  height: '100vh',
  width: '100%',
  objectFit: 'cover'
});

/** Emotion class names used by the scanning screen. */
const styles = {
  container: containerClass,
  webCam: webCamClass,
};
/** Side length, in pixels, of the square image fed to the paper-corner model. */
const inputImageSize = 224;

/** Phases of the scanning screen, in the order the component moves through them. */
const enum Stage {
  /** Before the webcam stream is available (and for ~1 s after it arrives). */
  Loading = 0,
  /** Live preview: frames are run through the model and <Capturing> is shown. */
  Capturing = 1,
  /** Captured views were handed off and the measurement request is pending. */
  Reconstructing = 2,
}

/**
 * getUserMedia video constraints for the scanner preview: the camera facing
 * away from the user ("environment"), roughly 4:3 (1.3333), 1200 px wide.
 * Bare values act as "ideal" hints, so the browser may pick something else.
 *
 * Typed as MediaTrackConstraints so misspelled constraint keys are rejected
 * by the compiler instead of being silently ignored by the browser.
 */
const contraints: MediaTrackConstraints = {
  facingMode: "environment",
  aspectRatio: 1.3333,
  width: 1200,
};

/**
 * Encodes the current contents of a canvas as an image Blob.
 *
 * @param canvas  canvas to encode
 * @param type    MIME type handed to `HTMLCanvasElement.toBlob`; JPEG by default
 * @param quality optional encoder quality (0..1) for lossy formats; omitted by default
 * @returns a promise for the encoded Blob. It rejects with an Error when the
 *          browser yields no blob (toBlob passes `null`, e.g. for a zero-sized canvas).
 */
const canvasToBlob = (
  canvas: HTMLCanvasElement,
  type = "image/jpeg",
  quality?: number
): Promise<Blob> =>
  new Promise((resolve, reject) => {
    canvas.toBlob(
      (blob) => {
        if (blob) {
          resolve(blob);
        } else {
          reject(new Error(`Canvas could not be encoded as ${type}`));
        }
      },
      type,
      quality
    );
  });

/**
 * Uploads the first two captured views to the scanner service and returns its
 * measurements.
 *
 * Never rejects. Every failure is tracked as `reconstructionFailure` and
 * returned as `{ ok: false, error }`:
 * - an image promise rejecting,
 * - a network error,
 * - a non-2xx HTTP status,
 * - an unparsable body.
 * Success is tracked as `reconstructionSuccess`.
 *
 * @param images   captured views; elements 0 and 1 are sent as `view1` / `view2`
 * @param endpoint URL the multipart form is POSTed to; defaults to the scanner
 *                 service URL this component has always used
 */
const getMeasurements = async (
  images: CaptureImages,
  endpoint = "https://moby-paper-scanner-stg-p5cn5bgu7a-lz.a.run.app/scan"
): Promise<MeasurementsResults> => {
  try {
    const views = await Promise.all(images);
    const formdata = new FormData();
    formdata.append("view1", views[0], "view1.jpg");
    formdata.append("view2", views[1], "view2.jpg");

    const res = await fetch(endpoint, { method: "POST", body: formdata });
    // fetch only rejects on network failure. Without this check, an HTTP error
    // payload would be reported (and tracked) as a successful measurement.
    if (!res.ok) {
      throw new Error(`Measurement request failed with HTTP ${res.status}`);
    }
    // NOTE(review): the body is trusted to match Measurements without validation.
    const data: Measurements = await res.json();
    usageTracker.track(UsageTrackerEvent.reconstructionSuccess);
    return { ok: true, value: data };
  } catch (error) {
    usageTracker.track(UsageTrackerEvent.reconstructionFailure);
    return {
      ok: false,
      error: error instanceof Error ? error : new Error(String(error)),
    };
  }
};

/**
 * Prepares an image tensor for the model: maps values via x / 127.5 - 1
 * (so [0, 255] becomes [-1, 1]) and prepends a batch axis of size 1.
 */
const preprocessImage = (image: tf.Tensor): tf.Tensor => {
  const scaled = image.div(127.5).sub(1);
  return scaled.expandDims(0);
};

/**
 * Maps points predicted on the padded, square model input back into the
 * coordinate space of a target surface.
 *
 * The formula treats each point's x/y as fractions of `inputImageSize`.
 * It removes the padding offset, rescales from the un-padded input size
 * (`new_size`) to the target size (`old_size`), and passes the third value
 * of each point through unchanged.
 *
 * @param points   rows of [x, y, extra]
 * @param new_size [width, height] of the image before padding
 * @param old_size [width, height] of the surface the result is expressed in
 * @param x_pad    padding, in pixels, added before the image on the x axis
 * @param y_pad    padding, in pixels, added before the image on the y axis
 */
const normalizePoints = (
  points: number[][],
  new_size: number[],
  old_size: number[],
  x_pad: number,
  y_pad: number
) => {
  const rescale = (value: number, pad: number, axis: 0 | 1) =>
    (inputImageSize / new_size[axis]) *
    (value - pad / inputImageSize) *
    old_size[axis];

  const mapped: number[][] = [];
  for (const [x, y, extra] of points) {
    mapped.push([rescale(x, x_pad, 0), rescale(y, y_pad, 1), extra]);
  }
  return mapped;
};

/**
 * Runs the paper-corner model on one downscaled frame and returns the
 * predicted points in `outputSize` coordinates.
 *
 * Steps:
 * 1. Centre-pad the frame to `inputImageSize`×`inputImageSize`.
 * 2. Normalise it with preprocessImage and run the model.
 * 3. Map the points back with normalizePoints.
 *
 * Intermediate tensors are released by `tf.tidy`.
 *
 * @param model       loaded graph model; its output is read as a rank-3 tensor
 *                    and batch 0 is used (assumed [batch, points, values])
 * @param videoCanvas frame to analyse. The caller sizes it so the longer side
 *                    equals `inputImageSize`; a larger canvas would produce
 *                    negative padding.
 * @param outputSize  size of the surface the returned points are expressed in
 * @returns one `[x, y, extra]` row per predicted point; `extra` is passed
 *          through unchanged (presumably a confidence score — unverified)
 */
const predict = (
  model: tf.GraphModel,
  videoCanvas: HTMLCanvasElement,
  outputSize: { width: number; height: number }
) => {
  return tf.tidy(() => {
    const frame = tf.browser.fromPixels(videoCanvas);
    const [frameHeight, frameWidth] = frame.shape;
    // Split the padding evenly so the frame sits in the centre of the square input.
    const xPad = Math.floor((inputImageSize - frameWidth) / 2);
    const yPad = Math.floor((inputImageSize - frameHeight) / 2);
    const padded = frame.pad([
      [yPad, inputImageSize - frameHeight - yPad],
      [xPad, inputImageSize - frameWidth - xPad],
      [0, 0],
    ]);
    const prediction = model.predict(preprocessImage(padded)) as tf.Tensor3D;
    const points = prediction.arraySync()[0];
    return normalizePoints(
      points,
      [videoCanvas.width, videoCanvas.height],
      [outputSize.width, outputSize.height],
      xPad,
      yPad
    );
  });
};

/** Props for the Scanning component. */
interface ScanningProps {
  /**
   * Receives the outcome of the measurement request once capturing completes.
   * Scanning always passes a value, although the parameter is typed optional.
   */
  resultsCallback: (results?: MeasurementsResults) => void;
}

/**
 * Paper-scanning screen.
 *
 * - Shows the environment-facing camera behind the rest of the UI.
 * - While in the Capturing stage, runs the paper-corner model on every
 *   animation frame and stores the detected corners in a ref for <Capturing>.
 * - When <Capturing> reports its captured views, uploads them and passes the
 *   outcome to `props.resultsCallback`.
 */
export const Scanning: React.FC<ScanningProps> = (props) => {
  const [stage, setStage] = useState(Stage.Loading);
  const [model, setModel] = useState<tf.GraphModel>();
  // Latest predicted corners. Written every frame without re-rendering and read
  // by <Capturing>. Each corner gets its own array so mutating one can't
  // affect the others.
  const paperCornersRef = useRef<number[][]>(
    Array.from({ length: 4 }, () => [0, 0, 0])
  );
  const webcamRef = useRef<Webcam>(null);
  // Hidden canvas used to grab full-resolution frames for upload.
  const shootRef = useRef<HTMLCanvasElement>(null);
  const frameRef = useRef(-1);
  // Pending Loading -> Capturing transition scheduled from onUserMedia.
  const stageTimerRef = useRef<ReturnType<typeof setTimeout> | undefined>(
    undefined
  );

  // Load the model once. A load that resolves after unmount (or after a
  // StrictMode effect re-run) is disposed instead of stored. The stored model
  // is disposed on unmount. Load failures are logged instead of becoming an
  // unhandled rejection.
  useEffect(() => {
    let cancelled = false;
    let loaded: tf.GraphModel | undefined;
    tf.loadGraphModel("/resources/model.json")
      .then((graphModel) => {
        if (cancelled) {
          graphModel.dispose();
          return;
        }
        loaded = graphModel;
        setModel(graphModel);
      })
      .catch((error) => {
        console.error("Failed to load paper detection model", error);
      });
    return () => {
      cancelled = true;
      loaded?.dispose();
    };
  }, []);

  // Drop a pending stage transition if the component unmounts first.
  useEffect(() => () => clearTimeout(stageTimerRef.current), []);

  // Per-frame detection loop. While capturing with a loaded model, it:
  // 1. grabs a frame downscaled so its longer side is inputImageSize,
  // 2. predicts corners mapped to the preview's on-screen size,
  // 3. reschedules itself.
  const onFrame = useCallback(() => {
    if (!model || stage !== Stage.Capturing) {
      return;
    }
    const videoElement = webcamRef.current?.video;
    if (!videoElement) {
      frameRef.current = requestAnimationFrame(onFrame);
      return;
    }

    const scale =
      inputImageSize /
      Math.max(videoElement.videoHeight, videoElement.videoWidth);
    const predictionHeight = scale * videoElement.videoHeight;
    const predictionWidth = scale * videoElement.videoWidth;
    const videoCanvas = webcamRef.current?.getCanvas({
      height: predictionHeight,
      width: predictionWidth,
    });

    if (!videoCanvas) {
      frameRef.current = requestAnimationFrame(onFrame);
      return;
    }
    // On-screen preview size, keeping the video's aspect ratio.
    const newWidth = videoElement.clientWidth;
    const newHeight =
      (videoElement.clientWidth * videoElement.videoHeight) /
      videoElement.videoWidth;
    paperCornersRef.current = predict(model, videoCanvas, {
      width: newWidth,
      height: newHeight,
    });
    frameRef.current = requestAnimationFrame(onFrame);
  }, [model, stage]);
  useEffect(() => {
    frameRef.current = requestAnimationFrame(onFrame);
    return () => cancelAnimationFrame(frameRef.current);
  }, [onFrame]);

  const videoElement = webcamRef.current?.video;
  const aspectRatio = videoElement
    ? videoElement.videoHeight / videoElement.videoWidth
    : 0;

  // Switch to Capturing 1 s after the stream arrives (presumably to let the
  // camera settle). Only advance from Loading, so a re-acquired stream cannot
  // pull the UI back out of Reconstructing.
  const handleUserMedia = () => {
    clearTimeout(stageTimerRef.current);
    stageTimerRef.current = setTimeout(
      () =>
        setStage((current) =>
          current === Stage.Loading ? Stage.Capturing : current
        ),
      1000
    );
  };

  // Copies the current full-resolution video frame onto the hidden canvas and
  // encodes it. The video is looked up at call time, not captured at render.
  const screenGrab = () => {
    const video = webcamRef.current?.video;
    const canvas = shootRef.current;
    const ctx = canvas?.getContext("2d");
    if (!video || !canvas || !ctx) {
      return Promise.reject("Video unavailable");
    }
    canvas.height = video.videoHeight;
    canvas.width = video.videoWidth;
    ctx.drawImage(video, 0, 0);
    return canvasToBlob(canvas);
  };

  return (
    <div>
      <div className={styles.container}>
        <Webcam
          className={styles.webCam}
          onUserMedia={handleUserMedia}
          ref={webcamRef}
          videoConstraints={contraints}
        />
        <canvas ref={shootRef} style={{ display: "none" }} />
      </div>
      {stage === Stage.Loading && <p>Loading</p>}
      {stage === Stage.Capturing && (
        <Capturing
          paperCornersRef={paperCornersRef}
          onCompletion={async (images: CaptureImages) => {
            setStage(Stage.Reconstructing);
            usageTracker.track(UsageTrackerEvent.reconstructionStarted);
            const measurements = await getMeasurements(images);
            props.resultsCallback(measurements);
          }}
          width={videoElement?.clientWidth}
          height={videoElement ? videoElement.clientWidth * aspectRatio : 0}
          screenGrab={screenGrab}
        />
      )}
      {stage === Stage.Reconstructing && <p>Reconstructing</p>}
    </div>
  );
};
