import * as tf from "@tensorflow/tfjs";
import "@tensorflow/tfjs-backend-webgl";

// Layout of the detector output and of the float buffers shared with the
// WASM module. The corner/cam/vec3 sizes must match what the WASM side expects.
const MODEL_INPUT_SIZE = 640;
const NUM_BOX_VALUES = 4;
const NUM_CLASSES = 22;
const NUM_CORNER_VALUES = 28;
const NUM_CAM_VALUES = 6;
const NUM_VEC3_VALUES = 3;

/**
 * Bridges TF.js object detection and the Digicamp WASM module.
 *
 * - Runs the detector on camera frames.
 * - Passes the best detection to the module's `_calcCameraPose`.
 * - Installs `updateCorners` / `updateCam` hooks on the module that copy data
 *   out of the WASM heap and forward it to the JS callbacks.
 */
export default class ARHandler {
  /**
   * @param {Function} cornersUpdateCallback - Called with a copied array of
   *   28 corner floats via the module's `updateCorners` hook, or with no
   *   arguments when a frame yields no detection.
   * @param {Function} camUpdateCallback - Called as `(cam, net, success)` via
   *   the module's `updateCam` hook, where `cam` is a copied array of 6 floats.
   */
  constructor(cornersUpdateCallback, camUpdateCallback) {
    this.cornersUpdateCallback = cornersUpdateCallback;
    this.camUpdateCallback = camUpdateCallback;
  }

  /**
   * Loads the detection model and creates the WASM module.
   *
   * The two loads are independent, so they run concurrently. Once the module
   * resolves, the JS hooks it calls back into are installed on it.
   *
   * @returns {Promise<void>}
   */
  async init() {
    const [loadedModel, wasmModule] = await Promise.all([
      this.loadModel(),
      window.createDigicampModule().then((wasm) => {
        wasm.getHeight = () =>
          document.getElementById("camera-feed").offsetHeight;
        wasm.getWidth = () =>
          document.getElementById("camera-feed").offsetWidth;

        // Both hooks receive a pointer that JS is responsible for freeing.
        // Copy the data out and free the pointer *before* running user code,
        // so a throwing callback cannot leak the buffer.
        wasm.updateCorners = (cornersPtr) => {
          const corners = new Float32Array(
            wasm.HEAPF32.buffer,
            cornersPtr,
            NUM_CORNER_VALUES
          ).slice(0);
          wasm._free(cornersPtr);
          this.cornersUpdateCallback(corners);
        };
        wasm.updateCam = (camPtr, net, success) => {
          const cam = new Float32Array(
            wasm.HEAPF32.buffer,
            camPtr,
            NUM_CAM_VALUES
          ).slice(0);
          wasm._free(camPtr);
          this.camUpdateCallback(cam, net, success);
        };
        return wasm;
      }),
    ]);
    this.loadedModel = loadedModel;
    this.module = wasmModule;
  }

  /**
   * Loads the TF.js graph model.
   *
   * @param {string} [modelUrl="model/model.json"] - URL of the model manifest.
   * @returns {Promise<tf.GraphModel>}
   */
  async loadModel(modelUrl = "model/model.json") {
    return tf.loadGraphModel(modelUrl);
  }

  /**
   * Runs the detector on one frame.
   *
   * If an object is found, the selected class id and corners are passed to
   * the WASM `_calcCameraPose` together with `position` and `lookAt`.
   * Otherwise `cornersUpdateCallback` is invoked with no arguments.
   *
   * @param {*} image - Any pixel source accepted by `tf.browser.fromPixels`.
   * @param {ArrayLike<number>} position - Copied into a 3-float WASM buffer;
   *   expected to hold exactly 3 numbers.
   * @param {ArrayLike<number>} lookAt - Copied into a 3-float WASM buffer;
   *   expected to hold exactly 3 numbers.
   * @returns {Promise<void>}
   */
  async predictImage(image, position, lookAt) {
    // Every tensor created while the scope is open is released by endScope().
    // The try/finally guarantees this even when inference or the pose call
    // throws.
    // NOTE(review): the engine scope is global, so overlapping predictImage
    // calls would interleave scopes. Callers presumably await each frame.
    tf.engine().startScope();
    try {
      const input = tf.browser
        .fromPixels(image)
        .resizeNearestNeighbor([MODEL_INPUT_SIZE, MODEL_INPUT_SIZE])
        .toFloat()
        .div(255.0)
        .expandDims();

      const start = performance.now();
      // NOTE(review): assumes the graph has a single output tensor.
      const predictions = this.loadedModel.predict(input);
      const pred = performance.now();
      console.log("Prediction time: ", pred - start, "ms");

      const nmsResult = await this.nonMaximumSuppression(predictions);
      if (nmsResult) {
        const wasm = this.module;
        const floatBytes = Float32Array.BYTES_PER_ELEMENT;
        const cornersPointer = wasm._malloc(NUM_CORNER_VALUES * floatBytes);
        const positionPointer = wasm._malloc(NUM_VEC3_VALUES * floatBytes);
        const lookAtPointer = wasm._malloc(NUM_VEC3_VALUES * floatBytes);
        try {
          // HEAPF32 offsets are in floats, pointers are in bytes.
          wasm.HEAPF32.set(nmsResult.corners, cornersPointer / floatBytes);
          wasm.HEAPF32.set(position, positionPointer / floatBytes);
          wasm.HEAPF32.set(lookAt, lookAtPointer / floatBytes);

          wasm._calcCameraPose(
            nmsResult.class,
            cornersPointer,
            positionPointer,
            lookAtPointer
          );
        } finally {
          wasm._free(cornersPointer);
          wasm._free(lookAtPointer);
          wasm._free(positionPointer);
        }
      } else {
        this.cornersUpdateCallback();
      }
      console.log("Postprocess time: ", performance.now() - pred, "ms");
      console.log("--------------------");
    } finally {
      tf.engine().endScope();
    }
  }

  /**
   * Selects the single best detection from a raw model output.
   *
   * `imagePrediction` is expected to be shaped `[1, channels, anchors]`
   * (e.g. `[1, 54, 8400]`). Each anchor's channels are laid out as:
   * 4 box values, then 22 class scores, then 28 corner values.
   *
   * @param {tf.Tensor} imagePrediction - Raw model output.
   * @param {number} [confThres=0.6] - Minimum best-class score an anchor
   *   needs to become a candidate.
   * @param {number} [iouThres=0.45] - IoU threshold passed to NMS.
   * @param {number} [maxDet=1] - Maximum number of NMS outputs. Any outcome
   *   other than exactly one result is treated as "no detection".
   * @returns {Promise<{corners: TypedArray, class: number} | null>} The
   *   selected anchor's corner values and argmax class id, or `null` when the
   *   batch size is not 1, nothing passes `confThres`, or NMS does not yield
   *   exactly one result.
   */
  async nonMaximumSuppression(
    imagePrediction,
    confThres = 0.6,
    iouThres = 0.45,
    maxDet = 1
  ) {
    const [batchSize, numChannels, numAnchors] = imagePrediction.shape;
    if (batchSize !== 1) {
      console.error(
        `Expected batchSize to be 1, but got ${batchSize} instead.`
      );
      return null;
    }

    // One row per anchor: [numAnchors, numChannels].
    const rows = tf.transpose(
      imagePrediction.reshape([numChannels, numAnchors]),
      [1, 0]
    );
    const boxes = rows.slice([0, 0], [-1, NUM_BOX_VALUES]); // [numAnchors, 4]
    const scores = rows.slice([0, NUM_BOX_VALUES], [-1, NUM_CLASSES]); // [numAnchors, 22]
    const maxScores = tf.max(scores, 1); // [numAnchors]

    const candidateCoords = await tf.whereAsync(
      tf.greaterEqual(maxScores, confThres)
    );
    if (candidateCoords.size === 0) {
      return null;
    }
    // whereAsync yields coordinates shaped [k, 1]. Flatten them to [k]
    // instead of squeezing the gathered results, because squeeze() would
    // collapse the k === 1 case to a lower rank.
    const candidateIndices = candidateCoords.reshape([-1]); // [k]
    const candidateScores = tf.gather(maxScores, candidateIndices); // [k]
    const candidateBoxes = tf.gather(boxes, candidateIndices); // [k, 4]

    // NOTE(review): tf.image.nonMaxSuppression expects [y1, x1, y2, x2] boxes,
    // and the model's box encoding is not visible here. With maxDet = 1 only
    // the top-scoring candidate is kept, so IoU (and thus the encoding) has
    // no effect. Confirm the encoding before raising maxDet.
    const nmsIndex = await tf.image.nonMaxSuppressionAsync(
      candidateBoxes,
      candidateScores,
      maxDet,
      iouThres
    );

    if (nmsIndex.size !== 1) {
      console.warn("Got !=1 nms result:", nmsIndex.size);
      return null;
    }

    // Map the NMS result (an index into the candidates) back to an anchor row,
    // then read only that row's scores and corners.
    const anchorIndex = tf.gather(candidateIndices, nmsIndex); // [1]
    const selectedRow = tf.gather(rows, anchorIndex); // [1, numChannels]
    const selectedScores = selectedRow.slice(
      [0, NUM_BOX_VALUES],
      [-1, NUM_CLASSES]
    ); // [1, 22]
    const selectedCorners = selectedRow.slice(
      [0, NUM_BOX_VALUES + NUM_CLASSES],
      [-1, NUM_CORNER_VALUES]
    ); // [1, 28]
    const predictedClass = tf.argMax(selectedScores, 1).dataSync(); // [1]
    console.log(selectedScores.shape, predictedClass);

    return {
      corners: selectedCorners.dataSync(),
      class: predictedClass[0],
    };
  }
}
