import * as tf from '@tensorflow/tfjs';

import { logger } from 'Utils/logger';
import { serialization } from '@tensorflow/tfjs-core';

/**
 * Wrapper around TensorFlow.js that activates a backend, loads (and caches in
 * IndexedDB) a layers model, runs inference and post-processes per-read output
 * tensors.
 * @class
 */
export class Tensorflow {
  /**
   * Named tensor operations dispatched by `operation(type, ...)`.
   */
  operations = {
    squeeze: (tensor) => tensor.squeeze(),
    exp: (tensor) => tensor.exp(),
    argMax: (tensor, axis) => tensor.argMax(axis),
    concat: (arrOfTensors, axis) => tf.concat(arrOfTensors, axis),
  };

  /**
   * @constructor
   * @param {string} [backend='wasm'] - TF.js backend activated by `loadModel()` and `reset()`.
   */
  constructor(backend = 'wasm') {
    this.backend = backend;
  }

  /**
   * Require the module for the requested backend (when it is one of the known
   * names), activate it and wait until TF.js is ready.
   * @param {string} backend - 'cpu', 'webgl', 'wasm' or 'webgpu'; any other name is passed straight to `tf.setBackend`.
   * @throws the value returned by `logger.error` when `tf.setBackend` reports failure.
   */
  async setBackend(backend) {
    // Each backend module is only required when it is the one being selected (perf).
    if (backend === 'cpu') require('@tensorflow/tfjs-backend-cpu');
    if (backend === 'webgl') require('@tensorflow/tfjs-backend-webgl');
    if (backend === 'wasm') require('@tensorflow/tfjs-backend-wasm');
    if (backend === 'webgpu') require('@nanopore/tfjs-backend-webgpu');

    const backendSet = await tf.setBackend(backend);
    if (!backendSet) throw logger.error('mlerror', 'Failed to set backend');
    await tf.ready();
    logger.log('backendLoaded', backend);
  }

  /**
   * Activate `this.backend` and load a layers model, either from
   * `indexeddb://<name>` (when `availableOffline`) or from `path` resolved
   * against `process.env.PUBLIC_URL`. A model fetched over the network is then
   * saved to `indexeddb://<name>`. The resolved location is kept in
   * `this.modelPath` for `reset()`.
   * @param {string} path - model path relative to PUBLIC_URL (used when not offline).
   * @param {string} name - IndexedDB key the model is read from / saved to.
   * @param {boolean} availableOffline - load from IndexedDB instead of the network.
   * @throws the value returned by `logger.error('mlerror', ...)` on any failure.
   */
  async loadModel(path, name, availableOffline) {
    try {
      if (!path && !name) throw logger.error('noModelUrl');

      let modelPath;
      if (availableOffline) {
        modelPath = `indexeddb://${name}`;
      } else {
        // PUBLIC_URL is only read for network loads.
        const publicUrl = process.env.PUBLIC_URL;
        const strJoin = publicUrl.endsWith('/') ? '' : '/';
        modelPath = `${publicUrl}${strJoin}${path}`;
      }
      this.modelPath = modelPath;

      await this.setBackend(this.backend);
      this.model = await tf.loadLayersModel(modelPath);

      if (!availableOffline) await this.model.save(`indexeddb://${name}`);
      logger.log('modelLoaded', name);
    } catch (e) {
      // NOTE(review): the underlying error `e` is neither logged nor attached here.
      throw logger.error('mlerror', 'Failed to load model');
    }
  }

  /**
   * Run one throw-away prediction on an all-zeros input and wait for its data
   * to be downloaded (a warm-up pass). Every tensor the warm-up creates is
   * disposed, even if prediction fails.
   * @param {number[]} shape - input shape; `shape[0]` is used as the batch size.
   */
  async warmupModel(shape) {
    const input = tf.zeros(shape);
    let outputs = [];
    try {
      const result = await this.model.predict(input, { batchSize: shape[0] });
      // Handle both a single output tensor and an array of outputs.
      outputs = Array.isArray(result) ? result : [result];
      await Promise.all(outputs.map((output) => output.data()));
    } finally {
      input.dispose();
      outputs.forEach((output) => output.dispose());
    }
  }

  /**
   * Build a tensor from `slice`.
   * @param {*} slice - tensor values.
   * @param {number[]} shape - tensor shape.
   * @param {string} [type='float32'] - dtype.
   * @returns {tf.Tensor}
   * @throws the value returned by `logger.error('badTensorShape')` when `slice` is falsy or `shape` is not an object.
   */
  createTensor(slice, shape, type = 'float32') {
    if (!slice || typeof shape !== 'object') throw logger.error('badTensorShape');
    return tf.tensor(slice, shape, type);
  }

  /** Print the TF.js memory counters (bytes, GPU bytes, data buffers, tensors) to the console. */
  logMemory() {
    const { numBytes, numBytesInGPU, numDataBuffers, numTensors } = tf.memory();
    console.log(
      `#Bytes ${numBytes}`,
      `#GPUBytes ${numBytesInGPU}`,
      `#numDataBuffers ${numDataBuffers}`,
      `#Tensors ${numTensors}`,
    );
  }

  /**
   * Run `this.model.predict` on `tensors`.
   * @param {*} tensors - model input(s).
   * @param {number} batchSize - batch size forwarded to `predict`.
   * @returns {Promise<*>} resolves with whatever `predict` returns.
   */
  async startInference(tensors, batchSize) {
    return this.model.predict(tensors, { batchSize });
  }

  /**
   * Dispatch to an entry of `this.operations` by name.
   * Only the object's own keys count, so inherited names such as 'toString'
   * are reported as unknown.
   * @param {string} type - operation name.
   * @returns the operation's result, or `false` (after logging) when `type` is unknown.
   */
  operation(type, p1, p2, p3) {
    const isKnown = Object.hasOwn(this.operations, type) && typeof this.operations[type] === 'function';
    if (!isKnown) {
      logger.error('operationNotFound', type);
      return false;
    }
    return this.operations[type](p1, p2, p3);
  }

  /**
   * Split a batched tensor into per-read chunk data.
   *
   * Rows are consumed sequentially along axis 0: each entry of `batchMetaData`
   * takes the next `currentChunkAmount` rows, splits them into single-row
   * chunks and trims each chunk's overlap via `sliceData`. Intermediate tensors
   * are created inside an engine scope that is always closed, even on error.
   *
   * @param {object} args
   * @param {tf.Tensor} args.tensor - batched tensor; rows are taken along axis 0.
   * @param {object[]} args.batchMetaData - per-read `{ currentChunkAmount, isEndOfRead, lastOverlapAmount, ... }`.
   * @param {number} args.modelStride - divisor applied to both overlap amounts (presumably input samples -> output steps).
   * @param {number} args.overlapAmount - overlap before dividing by the stride.
   * @returns {Promise<object[]>} one `{ data, ...readMeta }` entry per read; `data` holds the downloaded chunk values.
   */
  async extractReadData({ tensor, batchMetaData, modelStride, overlapAmount }) {
    this.startScope();
    try {
      const reads = [];
      let sliceStart = 0;
      // Same for every read, so compute it once.
      const newOverlapAmount = Math.round(overlapAmount / modelStride);

      for (const readMeta of batchMetaData) {
        const { currentChunkAmount, isEndOfRead, lastOverlapAmount } = readMeta;

        const readData = tensor.slice(sliceStart, currentChunkAmount);
        sliceStart += currentChunkAmount;

        const readChunks = readData.split(currentChunkAmount, 0);
        const newLastOverlapAmount = Math.round(lastOverlapAmount / modelStride);
        const combinedTensorData = await Promise.all(
          this.sliceData(readChunks, newOverlapAmount, newLastOverlapAmount, isEndOfRead),
        );

        readData.dispose();
        for (const readChunk of readChunks) {
          readChunk.dispose();
        }

        reads.push({ data: combinedTensorData, ...readMeta });
      }

      return reads;
    } finally {
      // Always close the scope so a failure mid-batch does not leak tensors or leave a scope open.
      this.endScope();
    }
  }

  /**
   * Reset the TF.js engine, re-activate `this.backend` and reload the model
   * from `this.modelPath`.
   * NOTE(review): unlike `setBackend()`, a failed backend switch is not checked here.
   */
  async reset() {
    tf.engine().reset();
    await tf.setBackend(this.backend);
    await tf.ready();
    this.model = await tf.loadLayersModel(this.modelPath);
  }

  /**
   * Start downloading each chunk's data, trimming `stitchAmount` columns
   * (axis 1) from the end of each chunk.
   *
   * When `isEndOfRead` is set, the second-to-last chunk (or the only chunk of
   * a single-chunk read) is trimmed by `endStitchAmount` instead, and the last
   * chunk of a multi-chunk read is kept whole.
   *
   * @param {tf.Tensor[]} tensorArray - chunks, each indexed as `[rows, cols, ...]`.
   * @param {number} stitchAmount - columns trimmed from a regular chunk.
   * @param {number} endStitchAmount - columns trimmed near the end of a read.
   * @param {boolean} isEndOfRead - whether these chunks end the read.
   * @returns {Promise[]} one pending `data()` promise per chunk, in order.
   */
  sliceData(tensorArray, stitchAmount, endStitchAmount, isEndOfRead) {
    const count = tensorArray.length;
    return tensorArray.map((tensor, i) => {
      const [rows, cols] = tensor.shape;
      const isLastTensor = isEndOfRead && i === count - 1;
      const isShortRead = isLastTensor && count === 1;
      const isSecondToLastTensor = isEndOfRead && i === count - 2;

      if (isLastTensor && !isShortRead) return tensor.data();

      const trim = isShortRead || isSecondToLastTensor ? endStitchAmount : stitchAmount;
      return tensor.slice(0, [rows, cols - trim]).data();
    });
  }

  /**
   * Profile `callback` with `tf.profile`.
   * @returns {Promise<*>} resolves with the `tf.profile` result.
   */
  async startProfiling(callback) {
    return tf.profile(callback);
  }

  /** No-op counterpart to `startProfiling`; always returns true. */
  endProfiling() {
    return true;
  }

  /** Open a TF.js engine scope (see `tf.engine().startScope`). */
  startScope() {
    tf.engine().startScope();
  }

  /** Close the most recently opened TF.js engine scope. */
  endScope() {
    tf.engine().endScope();
  }
}

/**
 * Custom layer (serialized as 'Flip') that reverses its input along `dims`.
 */
class FlipLayer extends tf.layers.Layer {
  /**
   * @param {object} [config] - layer config; `config.dims` is the axis/axes passed to `reverse`.
   */
  constructor(config = {}) {
    super(config);
    this.dims = config.dims;
  }

  static get className() {
    return 'Flip';
  }

  /** Base layer config plus `dims`. */
  getConfig() {
    const config = super.getConfig();
    return { ...config, dims: this.dims };
  }

  /**
   * Reverse the input along `this.dims`.
   * @param {tf.Tensor|tf.Tensor[]} inputs - a tensor, or an array whose first element is used.
   */
  call(inputs) {
    return tf.tidy(() => {
      // Accept either a bare tensor or a one-element array of tensors.
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      return input.reverse(this.dims);
    });
  }
}

/**
 * Custom layer (serialized as 'Scale') that multiplies its input by a constant.
 */
class ScaleLayer extends tf.layers.Layer {
  /**
   * @param {object} [config] - layer config; `config.scale` is the multiplier.
   */
  constructor(config = {}) {
    super(config);
    this.scale = config.scale;
  }

  static get className() {
    return 'Scale';
  }

  /** Base layer config plus `scale`. */
  getConfig() {
    const config = super.getConfig();
    return { ...config, scale: this.scale };
  }

  /**
   * Multiply the input element-wise by `this.scale`.
   * @param {tf.Tensor|tf.Tensor[]} inputs - a tensor, or an array whose first element is used.
   */
  call(inputs) {
    return tf.tidy(() => {
      // Accept either a bare tensor or a one-element array of tensors.
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      return input.mul(tf.scalar(this.scale));
    });
  }
}
/**
 * Custom layer (serialized as 'ZeroPadding1D') that zero-pads axis 1 of its input.
 */
class ZeroPadding1DLayer extends tf.layers.Layer {
  /**
   * @param {object} [config] - layer config; `config.padding` is a `[left, right]`
   *   pair applied along axis 1, or a single number for symmetric padding.
   */
  constructor(config = {}) {
    super(config);
    const { padding } = config;
    this.padding = typeof padding === 'number' ? [padding, padding] : padding;
  }

  static get className() {
    return 'ZeroPadding1D';
  }

  /**
   * Axis 1 grows by the total padding. An unknown (null) axis-1 length stays
   * unknown instead of being coerced to 0. The batch axis is reported as null.
   * @param {Array<number|null>} shape - input shape.
   * @returns {Array<number|null>} output shape.
   */
  computeOutputShape(shape) {
    const totalPadding = this.padding.reduce((total, curr) => total + curr, 0);
    const paddedLength = shape[1] == null ? null : shape[1] + totalPadding;
    return [null, paddedLength, shape[2]];
  }

  /** Base layer config plus `padding` (as a `[left, right]` pair). */
  getConfig() {
    const config = super.getConfig();
    return { ...config, padding: this.padding };
  }

  /**
   * Zero-pad axis 1 of the input by `this.padding`; axes 0 and 2 are left unpadded.
   * NOTE(review): the 3-entry paddings list assumes a rank-3 input.
   * @param {tf.Tensor|tf.Tensor[]} inputs - a tensor, or an array whose first element is used.
   */
  call(inputs) {
    return tf.tidy(() => {
      // Accept either a bare tensor or a one-element array of tensors.
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      return input.pad([[0, 0], this.padding, [0, 0]]);
    });
  }
}

/**
 * Custom layer (serialized as 'Clip') that clamps its input to `[min, max]`.
 */
class ClipLayer extends tf.layers.Layer {
  /**
   * @param {object} [config] - layer config; `config.min` / `config.max` are the clamp bounds.
   */
  constructor(config = {}) {
    super(config);
    this.min = config.min;
    this.max = config.max;
  }

  static get className() {
    return 'Clip';
  }

  /** Base layer config plus `min` and `max`. */
  getConfig() {
    const config = super.getConfig();
    return { ...config, min: this.min, max: this.max };
  }

  /**
   * Clamp every element of the input to `[this.min, this.max]`.
   * @param {tf.Tensor|tf.Tensor[]} inputs - a tensor, or an array whose first element is used.
   */
  call(inputs) {
    return tf.tidy(() => {
      // Accept either a bare tensor or a one-element array of tensors.
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      return input.clipByValue(this.min, this.max);
    });
  }
}

// Register each custom layer class with TF.js serialization (keyed by its
// static `className`), in the same order as before.
const customLayerClasses = [FlipLayer, ScaleLayer, ZeroPadding1DLayer, ClipLayer];
for (const layerClass of customLayerClasses) {
  serialization.registerClass(layerClass);
}
