package detect

import (
	"embed"
	"encoding/binary"
	"ewdetect/utils"
	"math"
	"os"
	"strings"

	"ewdetect/config"

	"github.com/argusdusty/gofft"
	"github.com/rs/zerolog/log"
	"gonum.org/v1/plot"
	"gonum.org/v1/plot/plotter"
	"gonum.org/v1/plot/vg"
)

// f holds the two embedded reference waveform assets that Init decodes.
//
//go:embed assets/p_wave_sample.bin
//go:embed assets/s_wave_sample.bin
var f embed.FS

// PWaveSample is the P-wave reference signal. Init fills it from
// assets/p_wave_sample.bin; it is correlated against incoming data by DetectWaves.
var PWaveSample []int32

// SWaveSample is the S-wave reference signal. Init fills it from
// assets/s_wave_sample.bin; it is correlated against incoming data by DetectWaves.
var SWaveSample []int32

// ExampleFile is not read or written anywhere in this section of the package.
// NOTE(review): confirm where it is populated/used, or remove it if dead.
var ExampleFile []int32

// Init decodes the embedded P- and S-wave reference assets into PWaveSample
// and SWaveSample. Each asset is a flat sequence of little-endian int32 values.
// Trailing bytes that do not form a whole int32 are ignored, and a warning is
// logged for them.
//
// The package-level slices are replaced rather than appended to, so calling
// Init more than once does not duplicate the samples.
func Init() {
	// load reads one embedded asset and decodes it as little-endian int32s.
	// Decoding directly from the byte slice avoids the per-value reflection
	// of binary.Read.
	load := func(name string) []int32 {
		raw, err := f.ReadFile(name)
		utils.CheckError(err)

		if rem := len(raw) % 4; rem != 0 {
			log.Warn().
				Str("file", name).
				Int("trailingBytes", rem).
				Msg("Ignoring incomplete trailing sample")
		}

		samples := make([]int32, len(raw)/4)
		for i := range samples {
			samples[i] = int32(binary.LittleEndian.Uint32(raw[4*i:]))
		}
		return samples
	}

	PWaveSample = load("assets/p_wave_sample.bin")
	SWaveSample = load("assets/s_wave_sample.bin")

	log.Debug().Msg("Wave samples initialized")
}

// CrossCorrelation locates the reference signal b inside the analyzed window a.
//
// a is read as a ring buffer: window sample i is
// (*a)[(i+headPosition)%config.BufferSize], so window index 0 is the sample at
// headPosition. The correlation r[n] = Σ_i b[i]·window[n+i] is computed with
// gofft.FastConvolve. The FFT length is padded to a power of two of at least
// len(a)+len(b)-1, so the circular convolution does not wrap the end of the
// window onto its start. Index n of the result is exactly the lag at which b
// starts in the window.
//
// The DC level and the RMS (about that DC level) are computed over the first
// config.BufferSize values of r. The search then covers lags
// [offset, config.BufferSize):
//   - thresholdMode: the first lag whose raw r[n]/rms exceeds threshold is
//     returned. If none does, lag offset and r[offset] are returned.
//   - otherwise: the lag with the largest r[n] is returned (earliest on ties).
//
// A negative offset is treated as 0. When offset >= config.BufferSize there is
// nothing to search: tMax is offset and max is -Inf, so any max/rms ratio test
// fails.
//
// When config.Debug and drawDebugOutput are both set, the full correlation is
// plotted to debugOutputFilename.
//
// It returns (tMax, rms, max): the chosen lag, the RMS of the correlation and
// the correlation value at tMax.
func CrossCorrelation(a *[config.BufferSize]int32, b *[]int32, offset int, headPosition int, thresholdMode bool, threshold float64, drawDebugOutput bool, debugOutputFilename string) (int, float64, float64) {
	if len(*a) < len(*b) {
		log.Fatal().Msg("Analyzed signal must be longer than the reference signal")
	}

	// FastConvolve needs equal power-of-two lengths. Padding to at least
	// len(a)+len(b)-1 makes the circular result equal the linear correlation.
	needed := max(len(*a)+len(*b)-1, len(*a))
	length := 1
	for length < needed {
		length *= 2
	}
	x := make([]complex128, length)
	y := make([]complex128, length)

	for i := range len(*a) {
		x[i] = complex(float64((*a)[(i+headPosition)%config.BufferSize]), 0)
	}

	// Store b time-reversed with b[0] at index 0 and b[i] at length-i, so the
	// convolution yields x[n] = Σ_i b[i]·window[n+i]. Index n is then the lag
	// itself, not lag+1.
	for i, v := range *b {
		y[(length-i)%length] = complex(float64(v), 0)
	}
	err := gofft.FastConvolve(x, y)
	utils.CheckError(err)

	if config.Debug && drawDebugOutput {
		pts := make(plotter.XYs, len(x))
		for i := range x {
			pts[i].X = float64(i)
			pts[i].Y = real(x[i])
		}

		p := plot.New()
		p.Add(plotter.NewGrid())
		line, err := plotter.NewLine(pts)
		utils.CheckError(err)
		p.Add(line)

		err = p.Save(8*vg.Inch, 4*vg.Inch, debugOutputFilename)
		utils.CheckError(err)
		log.Debug().Str("filename", debugOutputFilename).Msg("Debug plot saved")
	}

	// DC level and RMS about it, over the first config.BufferSize lags
	// (length >= len(*a) == config.BufferSize, so this slice is in range).
	window := x[:config.BufferSize]
	var dc float64
	for _, v := range window {
		dc += real(v)
	}
	dc /= float64(config.BufferSize)

	var sumSquares float64
	for _, v := range window {
		d := real(v) - dc
		sumSquares += d * d
	}
	rms := math.Sqrt(sumSquares / float64(config.BufferSize))

	// Search lags [offset, BufferSize) without wrapping back before offset.
	// The original code read x[offset] unconditionally and could panic here.
	offset = max(offset, 0)
	tMax := offset
	peak := math.Inf(-1)
	switch {
	case offset >= config.BufferSize:
		// Empty search window: report "no peak".
	case thresholdMode:
		peak = real(x[offset])
		for n := offset; n < config.BufferSize; n++ {
			// Note: the raw value (not value-dc) is compared against threshold·rms.
			if v := real(x[n]); v/rms > threshold {
				peak, tMax = v, n
				break
			}
		}
	default:
		peak = real(x[offset])
		for n := offset + 1; n < config.BufferSize; n++ {
			if v := real(x[n]); v > peak {
				peak, tMax = v, n
			}
		}
	}

	log.Debug().
		Float64("rms", rms).
		Float64("max", peak).
		Int("tMax", tMax).
		Msg("Cross correlation completed")

	return tMax, rms, peak
}

// DetectWaves correlates the analyzed window against both reference signals.
//
// The P wave is searched from lag 0 in threshold mode using
// config.PWaveSensitivity. The S wave is then searched in peak mode, starting
// config.MinWaveDelayInSamples after the P arrival.
//
// It returns, in order: P arrival, S arrival, P correlation RMS,
// S correlation RMS, P correlation value at its arrival, and S correlation
// value at its arrival.
func DetectWaves(a *[config.BufferSize]int32, headPosition int) (int, int, float64, float64, float64, float64) {
	// P wave: threshold search over the whole window.
	pArrival, pRMS, pPeak := CrossCorrelation(a, &PWaveSample, 0, headPosition, true, config.PWaveSensitivity, false, "")

	// S wave: strongest correlation from the minimum delay after the P arrival.
	sSearchStart := pArrival + config.MinWaveDelayInSamples
	sArrival, sRMS, sPeak := CrossCorrelation(a, &SWaveSample, sSearchStart, headPosition, false, 0, false, "")

	log.Debug().
		Int("pWaveArrival", pArrival).
		Int("sWaveArrival", sArrival).
		Float64("pWaveRMS", pRMS).
		Float64("sWaveRMS", sRMS).
		Msg("Wave detection completed")

	return pArrival, sArrival, pRMS, sRMS, pPeak, sPeak
}

// ThresholdDetectWaves runs DetectWaves on the analyzed window and reports
// whether both waves were detected. Detection requires the P-wave
// correlation-to-RMS ratio to exceed config.PWaveSensitivity and the S-wave
// ratio to exceed config.SWaveSensitivity. The P and S arrival indices from
// DetectWaves are returned regardless of the verdict.
//
// On a detection with config.Debug and outputDebugGraphs both set, it plots the
// two correlations under debug/cross-correlation/ and the raw buffer under
// debug/traces/. detectorName is used as a path prefix for those files. Every
// directory component before its last "/" is created (including missing
// parents) under both trees first.
func ThresholdDetectWaves(a *[config.BufferSize]int32, headPosition int, detectorName string, outputDebugGraphs bool) (bool, int, int) {
	pWaveArrival, sWaveArrival, pWaveRMS, sWaveRMS, pWaveMax, sWaveMax := DetectWaves(a, headPosition)

	detected := pWaveMax/pWaveRMS > config.PWaveSensitivity && sWaveMax/sWaveRMS > config.SWaveSensitivity
	if !detected {
		return false, pWaveArrival, sWaveArrival
	}

	if config.Debug && outputDebugGraphs {
		// Directory part of detectorName: everything before its last "/".
		var subDir string
		if i := strings.LastIndex(detectorName, "/"); i >= 0 {
			subDir = detectorName[:i]
		}
		correlationDir := "debug/cross-correlation/" + subDir
		traceDir := "debug/traces/" + subDir

		// MkdirAll creates missing parents and returns nil for existing
		// directories. Each tree is ensured independently; previously the trace
		// directory was only created when the correlation directory was missing.
		utils.CheckError(os.MkdirAll(correlationDir, os.ModePerm))
		utils.CheckError(os.MkdirAll(traceDir, os.ModePerm))

		CrossCorrelation(a, &PWaveSample, 0, headPosition, true, config.PWaveSensitivity, true, "debug/cross-correlation/"+detectorName+"pwave-correlation.png")
		CrossCorrelation(a, &SWaveSample, pWaveArrival+config.MinWaveDelayInSamples, headPosition, false, -1, true, "debug/cross-correlation/"+detectorName+"swave-correlation.png")

		// Raw buffer in storage order (not rotated by headPosition).
		traceOutputFileName := "debug/traces/" + detectorName + "trace.png"
		pts := make(plotter.XYs, len(a))
		for i := range len(a) {
			pts[i].X = float64(i)
			pts[i].Y = float64(a[i])
		}
		p := plot.New()
		p.Add(plotter.NewGrid())
		line, err := plotter.NewLine(pts)
		utils.CheckError(err)
		p.Add(line)
		err = p.Save(8*vg.Inch, 4*vg.Inch, traceOutputFileName)
		utils.CheckError(err)
		log.Debug().Str("filename", traceOutputFileName).Msg("Debug plot saved")
		log.Debug().Str("detector", detectorName).Msg("Debug graphs generated")
	}

	return true, pWaveArrival, sWaveArrival
}
