import * as math from "mathjs";
import * as ml from "ml-matrix";
import { KalmanFilter as BaseKalmanFilter } from "./3rdparty/kalmanFilter";
import { SECOND_TO_MICRO_SECONDS } from "./3rdparty/tfjs-models/shared/calculators/constants";

/**
 * Builds the state-transition matrix F for N independent signals, each
 * tracked with C state components (the value plus its first C-1 derivatives),
 * advanced by a time step of `dt`.
 *
 * The per-signal block is the truncated Taylor expansion
 *   F0[i][j] = dt^(j-i) / (j-i)!   for j >= i, and 0 below the diagonal,
 * and the full matrix is the block-diagonal kron(I_N, F0), so signal k
 * occupies state indices k*C .. k*C + C - 1.
 *
 * @throws Error if C is not 1, 2 or 3.
 */
function createFk(C: number, N: number, dt: number): ml.Matrix {
  let block: math.Matrix;
  switch (C) {
    case 1: // value only
      block = math.matrix([[1]]);
      break;
    case 2: // value, velocity
      block = math.matrix([
        [1, dt],
        [0, 1],
      ]);
      break;
    case 3: // value, velocity, acceleration
      block = math.matrix([
        [1, dt, dt ** 2 / 2],
        [0, 1, dt],
        [0, 0, 1],
      ]);
      break;
    default:
      throw new Error("Invalid Kalman Filters");
  }

  // One copy of the per-signal block per observed signal, on the diagonal.
  const identityN = math.identity(N) as math.Matrix;
  const blockDiagonal = math.kron(identityN, block);
  return new ml.Matrix(blockDiagonal.toArray() as number[][]);
}

/**
 * Builds the process-noise covariance Q for N independent signals, each
 * tracked with C state components, over a time step of `dt`.
 *
 * Every per-signal block has the form Q0 = G · Gᵀ with
 *   G = [dt^(C-1)/(C-1)!, ..., dt, 1]ᵀ
 * (C=1: G=[1]; C=2: G=[dt, 1]; C=3: G=[dt²/2, dt, 1]). That is, the noise
 * enters through the highest tracked derivative. The block is scaled by
 * `varQ` and repeated on the diagonal via kron(I_N, ·).
 *
 * @throws Error if C is not 1, 2 or 3.
 */
function createQk(C: number, N: number, varQ: number, dt: number): ml.Matrix {
  let perSignal: math.Matrix;
  switch (C) {
    case 1:
      perSignal = math.matrix([[1]]);
      break;
    case 2:
      perSignal = math.matrix([
        [dt ** 2, dt],
        [dt, 1],
      ]);
      break;
    case 3:
      // Discrete acceleration model (used).
      // A continuous-acceleration alternative, kept for reference, would be:
      //   [dt ** 5 / 20, dt ** 4 / 8, dt ** 3 / 6],
      //   [dt ** 4 / 8,  dt ** 3 / 3, dt ** 2 / 2],
      //   [dt ** 3 / 6,  dt ** 2 / 2, dt],
      perSignal = math.matrix([
        [dt ** 4 / 4, dt ** 3 / 2, dt ** 2 / 2],
        [dt ** 3 / 2, dt ** 2, dt],
        [dt ** 2 / 2, dt, 1],
      ]);
      break;
    default:
      throw new Error("Invalid Kalman Filters");
  }

  const scaled = math.multiply(perSignal, varQ);
  const blockDiagonal = math.kron(math.identity(N) as math.Matrix, scaled);
  return new ml.Matrix(blockDiagonal.toArray() as number[][]);
}

/**
 * Builds the N×N observation-noise covariance R = varR · I.
 * Off-diagonal entries come from the identity, so the observed values are
 * modelled as uncorrelated with one another.
 *
 * @param N Number of observed values.
 * @param varR Noise variance shared by every observed value (defaults to 1).
 */
function createR(N: number, varR = 1): ml.Matrix {
  const identity = ml.Matrix.eye(N, N);
  return identity.mul(varR);
}

/**
 * Builds the N×(N·C) observation matrix H.
 * Row k picks state index k*C, i.e. the first component (the value itself)
 * of signal k's C-component slice; derivatives are never observed.
 */
function createH(C: number, N: number): ml.Matrix {
  const observationModel = ml.Matrix.zeros(N, N * C);
  let row = 0;
  while (row < N) {
    observationModel.set(row, row * C, 1);
    row += 1;
  }
  return observationModel;
}

/** Configuration shared by {@link KalmanFilter} and {@link KalmanFilter1D}. */
export interface KalmanFilterConfig {
  /**
   * Number of state components tracked per observed value
   * (the value itself plus its derivatives):
   * - 1: constant position (value only)
   * - 2: constant velocity (value + first derivative)
   * - 3: constant acceleration (value + first and second derivatives)
   */
  C: 1 | 2 | 3;

  /** Process-noise variance; scales the process-noise covariance Q. */
  varQ: number;

  /** Observation-noise variance; placed on the diagonal of the covariance R. */
  varR: number;
}

/**
 * Kalman filter for N independent signals observed together on each call.
 *
 * Each signal is modelled with `config.C` state components (the value and its
 * first C-1 derivatives), and only the value itself is observed. Signal k
 * occupies state indices k*C .. k*C + C - 1.
 */
export class KalmanFilter {
  // Underlying filter; created lazily from the first observation.
  private filter?: BaseKalmanFilter;

  // Number of values in every observation.
  private readonly N: number;

  private readonly config: KalmanFilterConfig;

  // Observation matrix H: extracts each signal's value from the state vector.
  private readonly H: ml.Matrix;

  // Observation-noise covariance R.
  private readonly R: ml.Matrix;

  // Timestamp of the most recent apply() call that passed validation.
  private lastTimestamp?: number;

  /**
   * @param config Motion model order and noise variances.
   * @param N Number of values every observation passed to {@link apply} must have.
   */
  constructor(config: KalmanFilterConfig, N: number) {
    this.config = config;
    this.N = N;
    this.H = createH(config.C, N);
    this.R = createR(N, config.varR);
  }

  /** Forgets all state; the next apply() re-initialises from its observation. */
  reset(): void {
    this.filter = undefined;
    this.lastTimestamp = undefined;
  }

  /**
   * Runs one predict/update step and returns the filtered values.
   *
   * @param observation Exactly N observed values.
   * @param timestamp Time of the observation. Differences are divided by
   *   SECOND_TO_MICRO_SECONDS, so presumably microseconds — confirm with callers.
   * @returns The N filtered values (H · state after the update).
   * @throws Error("Invalid observation") if `observation.length !== N`.
   */
  apply(observation: number[], timestamp: number): number[] {
    // Validate on every call (not only the first), and before any state changes.
    // This keeps a rejected observation from shifting lastTimestamp and keeps
    // mismatched dimensions out of the underlying filter.
    if (observation.length !== this.N) {
      throw new Error("Invalid observation");
    }

    // dt is 0 on the first call after construction or reset().
    // NOTE(review): out-of-order timestamps give a negative dt and are not rejected.
    const dt =
      (timestamp - (this.lastTimestamp ?? timestamp)) / SECOND_TO_MICRO_SECONDS;
    this.lastTimestamp = timestamp;

    if (this.filter === undefined) {
      this.filter = this.createFilter(observation);
    }
    const filter = this.filter;

    const { C, varQ } = this.config;
    const Fk = createFk(C, this.N, dt);
    const Qk = createQk(C, this.N, varQ, dt);
    // Forwarded unchanged to BaseKalmanFilter.predict; its meaning is defined there.
    const alpha = 1;

    filter.predict(Fk, Qk, alpha);
    filter.update(observation, this.H, this.R);

    return this.H.mmul(filter.state).to1DArray();
  }

  /**
   * Creates the underlying filter. Each signal's value is taken from
   * `observation`, and every derivative starts at 0.
   */
  private createFilter(observation: number[]): BaseKalmanFilter {
    const { C } = this.config;
    const size = this.N * C;
    const initialState = new Array<number>(size).fill(0);
    for (let i = 0; i < this.N; i += 1) {
      initialState[i * C] = observation[i];
    }
    // Zero initial covariance: the initial state is treated as exact.
    return new BaseKalmanFilter(initialState, ml.Matrix.zeros(size, size));
  }
}

/** Single-signal convenience wrapper around {@link KalmanFilter} (N = 1). */
export class KalmanFilter1D {
  private readonly filter: KalmanFilter;

  constructor(config: KalmanFilterConfig) {
    this.filter = new KalmanFilter(config, 1);
  }

  /** Discards the filter state; the next apply() starts fresh. */
  reset() {
    this.filter.reset();
  }

  /**
   * Filters one scalar sample.
   *
   * @param observation Raw value.
   * @param timestamp Forwarded unchanged to KalmanFilter.apply.
   * @param valueScale Factor applied to `observation` before filtering.
   * @returns The filtered value in scaled units (i.e. comparable to
   *   observation * valueScale, not to observation).
   *
   * NOTE(review): the result is never divided back by valueScale. A caller
   * that expects the original units, as with a OneEuroFilter-style
   * apply(value, timestamp, valueScale) API, would see scaled output.
   * Confirm against callers.
   */
  apply(observation: number, timestamp: number, valueScale: number): number {
    const scaled = observation * valueScale;
    return this.filter.apply([scaled], timestamp)[0];
  }
}
