import { MultiChannelSignalRecorder, Recordable, SignalRecorder } from '@egzotech/exo-session/features/common';
import { SampleBasedTimer } from 'libs/exo-session-manager/core/common/SampleBasedTimer';

/**
 * Ids of the single-signal recorders: angle signals (`<joint>-angle`) and force signals (`<location>-force`).
 */
export type DefaultSignalRecorderId =
  | `${'knee' | 'ankle' | 'main'}-angle`
  | `${'knee' | 'toes' | 'heel' | 'torque' | 'extension'}-force`;

/**
 * Recorded data keyed by recorder id. Every entry holds the sample values and the time point of each sample;
 * EMG data is stored separately under `emg`, with one such entry per channel number.
 */
export type Recordings<T extends string = DefaultSignalRecorderId> = Partial<
  Record<T, { samples: Float32Array; timePoints: Uint32Array }>
> & {
  emg?: Record<number, { samples: Float32Array; timePoints: Uint32Array }>;
};

/**
 * Same layout as `Recordings`, but with plain `number[]` arrays instead of typed arrays
 * (presumably for JSON serialization — TODO confirm against callers).
 */
export type SerializedRecordings<T extends string = DefaultSignalRecorderId> = Partial<
  Record<T, { samples: number[]; timePoints: number[] }>
> & {
  emg?: Record<number, { samples: number[]; timePoints: number[] }>;
};

/**
 * A recorder paired with its id: the `'emg'` id carries a multi-channel recorder, every other id a
 * single-signal recorder.
 */
export type RecorderController<T extends string = DefaultSignalRecorderId> =
  | { id: 'emg'; recorder: MultiChannelSignalRecorder }
  | { id: T; recorder: SignalRecorder };

/**
 * Periodically samples a set of recordables and stores their values in per-source recorders.
 *
 * Every `SignalRecorderController.INTERVAL_TIME` ms while recording:
 * - a single-channel recordable has the last value of its snapshot recorded by its own `SignalRecorder`;
 * - the multi-channel recordable with id `'emg'` (only when `channels` were passed to the constructor) has every
 *   channel of its snapshot recorded, each with that channel's own timestamp.
 * Other multi-channel recordables get no recorder and are not recorded.
 *
 * Timestamps come from `timeSource` when one is given (its readings are multiplied by 1000 — presumably seconds
 * to milliseconds, TODO confirm); otherwise they are `Date.now()` milliseconds since the last `start()`.
 */
export class SignalRecorderController {
  private readonly recorders: RecorderController[] = [];
  /** Handle of the running polling interval; `null` while not polling. */
  private intervalId: ReturnType<typeof setInterval> | null = null;
  /** Results gathered by `stop()`; `addBaselineRecording()` extends the EMG part in place. */
  recordings: Recordings = {};

  /** `Date.now()` at the last `start()`; 0 means "not started". */
  private _startTime = 0;

  /** Polling period in milliseconds. */
  static INTERVAL_TIME = 50;

  /** `true` from `start()` until the next `stop()` or `reset()` (pausing does not clear it). */
  get started() {
    return this._startTime !== 0;
  }

  /**
   * Current recording time in milliseconds: taken from `timeSource` when set, otherwise wall-clock time since
   * the last `start()` (not meaningful before `start()` in that case).
   */
  get timestamp() {
    if (this.timeSource) {
      return this.timeSource.duration * 1000;
    }
    return Date.now() - this._startTime;
  }

  /** Timestamp for one channel; falls back to `timestamp` when no time source is set. */
  private getTimestampForChannel(channel: number) {
    return this.timeSource ? this.timeSource.getTimestampForChannel(channel) * 1000 : this.timestamp;
  }

  /**
   * @param recordables Sources to sample on every tick.
   * @param channels EMG channel numbers; an `'emg'` multi-channel recordable is only recorded when this is set.
   *   The i-th channel result of the EMG recorder is stored under `channels[i]` by `stop()`.
   * @param timeSource Optional clock used for timestamps instead of wall-clock time.
   */
  constructor(
    private readonly recordables: Recordable<'single' | 'multi'>[],
    private readonly channels?: number[],
    private readonly timeSource?: SampleBasedTimer,
  ) {
    recordables.forEach(recordable => {
      if (recordable.recordableType === 'multi') {
        if (recordable.recordableId === 'emg' && channels) {
          this.recorders.push({ id: recordable.recordableId, recorder: new MultiChannelSignalRecorder(channels) });
        }
      } else {
        this.recorders.push({ id: recordable.recordableId as DefaultSignalRecorderId, recorder: new SignalRecorder() });
      }
    });
  }

  /** Starts every recorder, resets the wall-clock origin and begins polling. */
  start() {
    this.recorders.forEach(v => v.recorder.start());
    this._startTime = Date.now();
    this.record();
  }

  /** Pauses every recorder and stops polling. Safe to call when not polling. */
  pause() {
    this.recorders.forEach(v => v.recorder.pause());
    this.clearRecordInterval();
  }

  /** Resumes every recorder and restarts polling. */
  resume() {
    this.recorders.forEach(v => v.recorder.resume());
    this.record();
  }

  /**
   * Stops polling, marks the controller as not started and stops every recorder, discarding what they
   * return. `recordings` is left untouched.
   */
  reset() {
    this.pause();
    this._startTime = 0;
    this.recorders.forEach(v => v.recorder.stop());
  }

  /** Clears the polling interval, if one is running. */
  private clearRecordInterval() {
    if (this.intervalId !== null) {
      clearInterval(this.intervalId);
      this.intervalId = null;
    }
  }

  /** (Re)starts the polling loop. */
  private record() {
    // Clear any running loop first: otherwise start()/resume() while already polling would overwrite intervalId
    // and orphan the previous interval, which pause()/stop() could then never clear.
    this.clearRecordInterval();
    this.intervalId = setInterval(() => this.recordSnapshots(), SignalRecorderController.INTERVAL_TIME);
  }

  /** One polling tick: reads a snapshot from every recordable and feeds it to the matching recorder. */
  private recordSnapshots() {
    this.recordables.forEach(recordable => {
      if (recordable.recordableType === 'multi' && recordable.recordableId === 'emg') {
        const emgRecorder = this.recorders.find(({ id }) => id === recordable.recordableId)?.recorder as
          | MultiChannelSignalRecorder
          | undefined;
        const snapshot = (recordable as Recordable<'multi'>).getSnapshot();

        // Timestamp for each channel can be different, so every channel is recorded separately.
        for (const key in snapshot) {
          const channel = Number(key);
          emgRecorder?.record(
            { [channel]: snapshot[channel] } as Record<number, number>,
            this.getTimestampForChannel(channel),
          );
        }
      } else {
        const recorder = this.recorders.find(({ id }) => id === recordable.recordableId)?.recorder as
          | SignalRecorder
          | undefined;
        const lastSnapshotValue = (recordable as Recordable<'single'>).getSnapshot()?.at(-1);
        if (recorder && typeof lastSnapshotValue === 'number') {
          recorder.record(lastSnapshotValue, this.timestamp);
        }
      }
    });
  }

  /**
   * Delegates to the chosen recorder's `retrieveRange({ min, max })` (EMG by default); `min`/`max` are presumably
   * timestamps in the same unit as recorded time points — TODO confirm.
   *
   * @throws Error if no recorder with `recorderId` exists.
   */
  retrieveRange({ min, max, recorderId }: { min: number; max: number; recorderId?: 'emg' }): ({
    samples: Float32Array;
    timePoints: Uint32Array;
  } | null)[];
  retrieveRange({ min, max, recorderId }: { min: number; max: number; recorderId?: DefaultSignalRecorderId }): {
    samples: Float32Array;
    timePoints: Uint32Array;
  } | null;
  retrieveRange({
    min,
    max,
    recorderId = 'emg',
  }: {
    min: number;
    max: number;
    recorderId?: DefaultSignalRecorderId | 'emg';
  }) {
    const recorder = this.recorders.find(rec => rec.id === recorderId)?.recorder;

    if (!recorder) {
      // Name the recorder that was actually requested; the message for the default 'emg' id is unchanged.
      throw new Error(`${recorderId === 'emg' ? 'EMG' : recorderId} recorder is not present`);
    }

    return recorder.retrieveRange({ min, max });
  }

  /**
   * In some rare cases samples length is not the same as time points length;
   * this method trims the arrays to their common length for those cases.
   *
   * Every recording (and every EMG channel) is trimmed in place with `subarray`, so no data is copied.
   *
   * @param recordings Recordings to fix; defaults to this controller's `recordings`. Mutated in place.
   * @returns The same `recordings` object.
   */
  adjustSamples(recordings = this.recordings) {
    for (const key in recordings) {
      if (key !== 'emg') {
        SignalRecorderController.trimToCommonLength(recordings[key as DefaultSignalRecorderId]);
      }
    }

    if (recordings.emg) {
      for (const channelData of Object.values(recordings.emg)) {
        SignalRecorderController.trimToCommonLength(channelData);
      }
    }

    return recordings;
  }

  /** Shortens the longer of `samples`/`timePoints` so both have the same length; ignores incomplete entries. */
  private static trimToCommonLength(data: { samples: Float32Array; timePoints: Uint32Array } | undefined) {
    if (!data?.samples || !data.timePoints) {
      return;
    }
    const commonLength = Math.min(data.samples.length, data.timePoints.length);
    if (data.samples.length > commonLength) {
      data.samples = data.samples.subarray(0, commonLength);
    }
    if (data.timePoints.length > commonLength) {
      data.timePoints = data.timePoints.subarray(0, commonLength);
    }
  }

  /**
   * Stops polling and every recorder, stores their results in `recordings` and returns them trimmed by
   * `adjustSamples`.
   *
   * A single-signal result is stored only when its recorder returned one, so a value left from an earlier
   * session is kept otherwise. The i-th EMG channel result is stored under `channels[i]`; empty results are skipped.
   */
  stop() {
    this.pause();
    this._startTime = 0;
    this.recorders.forEach(controller => {
      if (controller.id === 'emg') {
        this.recordings.emg = controller.recorder
          .stop()
          .reduce<NonNullable<Recordings['emg']>>((byChannel, channelRecording, i) => {
            if (channelRecording && this.channels) {
              byChannel[this.channels[i]] = channelRecording;
            }
            return byChannel;
          }, {});
      } else {
        const record = controller.recorder.stop();
        if (record) {
          this.recordings[controller.id] = record;
        }
      }
    });
    return this.adjustSamples(this.recordings);
  }

  /**
   * Prepends a baseline EMG recording to every channel of `recordings.emg`.
   *
   * For each channel the baseline samples/time points come first, followed by the recorded ones, whose time
   * points are shifted by `timeOffset` (stored in a `Uint32Array`, so results wrap modulo 2^32).
   * When `recordings.emg` is absent nothing is merged, but the adjusted recordings are still returned.
   *
   * @param baselineRecording Baseline EMG data keyed by channel number.
   * @param timeOffset Amount added to every recorded time point.
   * @returns `this.recordings` after `adjustSamples`.
   * @throws Error if a recorded channel has no counterpart in `baselineRecording`; no channel is modified then.
   */
  addBaselineRecording(baselineRecording: Recordings['emg'], timeOffset: number) {
    const emg = this.recordings.emg;
    if (emg) {
      // Resolve (and validate) every channel before mutating any, so a missing baseline channel cannot leave
      // the recordings half-merged.
      const merges = Object.entries(emg).map(([channel, recordingChannel]) => {
        const baselineChannel = baselineRecording?.[Number(channel)];
        if (!baselineChannel) {
          throw new Error(`Channel: ${channel} is not present in baseline signal recording`);
        }
        return { recordingChannel, baselineChannel };
      });

      for (const { recordingChannel, baselineChannel } of merges) {
        const samples = new Float32Array(baselineChannel.samples.length + recordingChannel.samples.length);
        samples.set(baselineChannel.samples, 0);
        samples.set(recordingChannel.samples, baselineChannel.samples.length);

        const timePoints = new Uint32Array(baselineChannel.timePoints.length + recordingChannel.timePoints.length);
        timePoints.set(baselineChannel.timePoints, 0);
        // Uint32Array#map already returns a Uint32Array, so it can be copied in directly.
        timePoints.set(
          recordingChannel.timePoints.map(point => point + timeOffset),
          baselineChannel.timePoints.length,
        );

        recordingChannel.samples = samples;
        recordingChannel.timePoints = timePoints;
      }
    }
    return this.adjustSamples(this.recordings);
  }
}
