import GlobalConfig from 'Config';
import { PromisePool } from '@supercharge/promise-pool';
import hrtime from 'browser-process-hrtime';
import { logger } from 'Utils/logger';
import { savePerformanceConfig } from '../../actions';
// Install the `browser-process-hrtime` implementation as `process.hrtime`; the
// AutoBatchSelector below times benchmark runs with `process.hrtime()`.
// This assignment replaces `process.hrtime` on the shared global `process` object.
// NOTE(review): presumably a polyfill for environments without a native
// process.hrtime (e.g. a browser/renderer bundle) — confirm.
process.hrtime = hrtime;

/**
 * Benchmarks the basecaller's current model over several batch sizes and
 * inference-worker counts, then saves the fastest successful combination with
 * `savePerformanceConfig`.
 *
 * Each run pushes a synthetic random signal through
 * `basecaller.createAndRunTensors` and is timed with `process.hrtime`.
 */
export default class AutoBatchSelector {
  /**
   * @param {object} basecaller - Basecaller under test. Must expose `model.algorithm`, `model.name`,
   *   `backend`, `library`, `loadWorkers`, `loadModel`, `createAndRunTensors` and `terminateWorkers`.
   * @param {number} tensorSize - Samples per chunk; a batch of `batchSize` chunks holds
   *   `tensorSize * batchSize` samples.
   */
  constructor(basecaller, tensorSize) {
    this.basecaller = basecaller;
    this.tensorSize = tensorSize;
    // Cached synthetic signal; every batch is a prefix slice of it.
    this.array = [];
    // One `{ batchSize, timeTaken, inference }` record per executed run (`timeTaken` is null on failure).
    this.batchMetadata = [];
    // Candidate batch sizes, keyed by the first word of `basecaller.model.algorithm`.
    this.batchSizeList = {
      Fast: [128, 160, 192, 256],
      Hac: [16, 32, 48],
    };
    // Set when a run fails; stops the remaining batch sizes for the current worker configuration.
    this.skipBatchSizes = false;
    // Worker configurations to benchmark; `inference` is also used as the run's concurrency.
    this.workerList = [{ inference: 1 }, { inference: 2 }];
    // Worker configuration of the run in progress (null until runBenchmark starts).
    this.currentWorkerConfig = null;
  }

  /**
   * Returns the candidate batch sizes for the model's algorithm, i.e. the
   * first word of `basecaller.model.algorithm`.
   *
   * @returns {number[]} Batch sizes in the order they are benchmarked.
   * @throws {Error} If no batch-size list is configured for the algorithm.
   */
  getBatchSizes() {
    const [algorithm] = this.basecaller.model.algorithm.split(' ');
    const batchSizes = this.batchSizeList[algorithm];
    if (!Array.isArray(batchSizes)) {
      throw new Error(`No benchmark batch sizes configured for algorithm "${algorithm}"`);
    }
    return batchSizes;
  }

  /**
   * Benchmarks every worker configuration against every batch size of the
   * model's algorithm, saves the fastest successful configuration and returns
   * the per-run results.
   *
   * Within one worker configuration, batch sizes are tried in list order; once
   * a run fails, the remaining batch sizes for that configuration are skipped.
   * If no run succeeds, nothing is saved.
   *
   * @param {number} [dataAmount=4000000] - Number of synthetic samples pushed through each run.
   * @returns {Promise<Array<{batchSize: number, timeTaken: number, samples: number,
   *   samplesPerS: number, inference: number}>>} One entry per executed run, `timeTaken` in
   *   seconds. Failed runs appear with `timeTaken: 0` and are never selected as the best.
   * @throws {Error} If the model's algorithm has no batch-size list.
   */
  async runBenchmark(dataAmount = 4000000) {
    this.batchMetadata = [];
    const batchSizes = this.getBatchSizes();

    for (const workers of this.workerList) {
      this.currentWorkerConfig = workers;
      this.skipBatchSizes = false;

      for (const batchSize of batchSizes) {
        // After a failed run, later batch sizes are not tried for this worker configuration.
        if (this.skipBatchSizes) {
          break;
        }

        console.log(`Running benchmark for batch size ${batchSize}...`);
        await this.basecaller.loadWorkers(this.basecaller.backend, this.basecaller.library, workers);
        try {
          await this.basecaller.loadModel(false);
          await this.getBatchMetadata(dataAmount, batchSize);
        } finally {
          // Release the workers even if loading the model or benchmarking throws, and wait
          // for teardown so the next loadWorkers call does not overlap with it.
          await this.basecaller.terminateWorkers(workers);
        }
        console.log('Benchmark complete.');
      }
    }

    const benchmarkedData = this.batchMetadata.map(({ batchSize, timeTaken, inference }) => {
      const timeTakenSeconds = timeTaken / 1000;
      return {
        batchSize,
        timeTaken: timeTakenSeconds,
        samples: dataAmount,
        samplesPerS: dataAmount / timeTakenSeconds,
        inference,
      };
    });

    // A failed run has a null time, which becomes 0 s (and Infinity samples/s) above,
    // so only runs with a positive time are eligible. Ties keep the earliest run.
    const bestConfig = benchmarkedData
      .filter((run) => run.timeTaken > 0)
      .reduce((best, run) => (best === null || run.samplesPerS > best.samplesPerS ? run : best), null);

    if (bestConfig === null) {
      console.warn('No benchmark run completed successfully; performance config was not saved.');
      return benchmarkedData;
    }

    const { inference, batchSize } = bestConfig;

    savePerformanceConfig({
      workers: {
        inference,
        decoder: 2,
        file: 1,
      },
      batchSize,
      name: this.basecaller.model.name,
    });

    return benchmarkedData;
  }

  /**
   * Replaces the cached signal with `tensorSize * batchSize` values from `Math.random()`.
   *
   * @param {number} tensorSize - Samples per chunk.
   * @param {number} batchSize - Number of chunks.
   * @returns {Promise<number[]>} The newly generated signal (also stored in `this.array`).
   */
  async generateSignal(tensorSize, batchSize) {
    const arraySize = tensorSize * batchSize;
    const array = Array.from({ length: arraySize }, () => Math.random());
    this.array = array;
    return array;
  }

  /**
   * Returns the first `tensorSize * batchSize` samples of the cached signal.
   *
   * The cache is (re)generated whenever it is too short for the request. It is
   * sized for the largest configured batch size, or for `batchSize` if that is
   * larger, so later calls can reuse it.
   *
   * @param {number} batchSize - Number of chunks to return.
   * @returns {Promise<number[]>} A fresh array of `tensorSize * batchSize` samples.
   */
  async getSliceFromSignal(batchSize) {
    const requiredLength = this.tensorSize * batchSize;

    if (this.array.length < requiredLength) {
      const largestBatchSize = Math.max(...this.getBatchSizes(), batchSize);
      await this.generateSignal(this.tensorSize, largestBatchSize);
    }

    return this.array.slice(0, requiredLength);
  }

  /**
   * Times one run of `dataAmount` samples at `batchSize` and appends its record
   * (`timeTaken` in ms, or null on failure) to `this.batchMetadata`.
   *
   * @param {number} dataAmount - Number of samples to process.
   * @param {number} batchSize - Chunks per batch.
   */
  async getBatchMetadata(dataAmount, batchSize) {
    const timeTaken = await this.timeTensorRunWithDataAmount(dataAmount, batchSize);
    this.batchMetadata.push({
      batchSize,
      timeTaken,
      inference: this.currentWorkerConfig.inference,
    });
  }

  /**
   * Splits `dataAmount` synthetic samples into batches of `batchSize` chunks.
   * The final batch is shrunk to the number of chunks needed to cover the
   * remaining samples.
   *
   * @param {number} dataAmount - Total number of samples to cover.
   * @param {number} batchSize - Chunks per full batch.
   * @returns {Promise<object[]>} Batch descriptors accepted by `basecaller.createAndRunTensors`.
   * @throws {Error} If `tensorSize * batchSize` is not positive (the loop would never advance).
   */
  async createBenchmarkBatches(dataAmount, batchSize) {
    const samplesPerBatch = this.tensorSize * batchSize;
    if (!(samplesPerBatch > 0)) {
      throw new Error(`tensorSize (${this.tensorSize}) * batchSize (${batchSize}) must be positive`);
    }

    const batches = [];
    for (let offset = 0; offset < dataAmount; offset += samplesPerBatch) {
      const dataLeft = dataAmount - offset;
      // The last batch only needs enough chunks to cover the remaining samples.
      const currentBatchSize = dataLeft < samplesPerBatch ? Math.ceil(dataLeft / this.tensorSize) : batchSize;
      const samples = await this.getSliceFromSignal(currentBatchSize);

      batches.push({
        samples,
        tensorSize: this.tensorSize,
        batchSize: currentBatchSize,
        batchIdx: 0,
        batchMetaData: [
          {
            readId: '1',
            lastOverlapAmount: GlobalConfig.mlconfig.overlapAmount,
            totalChunkAmount: currentBatchSize,
            numberSamplesInRead: currentBatchSize * this.tensorSize,
            currentChunkAmount: currentBatchSize,
            isEndOfRead: true,
            readIdx: 0,
          },
        ],
        modelStride: 5,
        overlapAmount: GlobalConfig.mlconfig.overlapAmount,
      });
    }
    return batches;
  }

  /**
   * Runs `dataAmount` synthetic samples through the basecaller at `batchSize`,
   * with `currentWorkerConfig.inference` batches in flight, and times it.
   *
   * @param {number} dataAmount - Number of samples to process.
   * @param {number} batchSize - Chunks per batch.
   * @returns {Promise<number|null>} Elapsed milliseconds, or null if inference failed.
   */
  async timeTensorRunWithDataAmount(dataAmount, batchSize) {
    // Build the batches before starting the timer, so that one-off preparation
    // (notably the first-call signal generation) is not counted as inference time.
    const batches = await this.createBenchmarkBatches(dataAmount, batchSize);

    return this.timeFunction(async () => {
      try {
        await PromisePool.withConcurrency(this.currentWorkerConfig.inference)
          .for(batches)
          .handleError((e) => {
            // Rethrowing from the handler stops the pool and rejects `process`.
            throw e;
          })
          .process(async (batch) => {
            await this.basecaller.createAndRunTensors(batch);
          });
      } catch (e) {
        logger.dump('mlerror', e);
        throw e;
      }
    });
  }

  /**
   * Awaits `callback` and measures its wall-clock duration with `process.hrtime`.
   *
   * This is best-effort: if `callback` throws, the error is not propagated.
   * Instead `skipBatchSizes` is set, so runBenchmark skips the remaining batch
   * sizes for the current worker configuration, and null is returned.
   *
   * @param {() => Promise<unknown>} callback - Work to time.
   * @returns {Promise<number|null>} Elapsed milliseconds, or null on failure.
   */
  async timeFunction(callback) {
    const start = process.hrtime();
    try {
      await callback();
    } catch (e) {
      this.skipBatchSizes = true;
      return null;
    }
    const [seconds, nanoseconds] = process.hrtime(start);
    // Convert to nanoseconds first, then to milliseconds.
    return (seconds * 1000000000 + nanoseconds) / 1000000;
  }
}
