#!/usr/bin/env python3

import array
import bisect
from dataclasses import dataclass
from datetime import datetime, UTC
import logging
import os
import re
from typing import Union


# Module-level logger; handlers/levels are configured by the application
# (see the __main__ section below).
_LOGGER = logging.getLogger(__name__)

# Matches one line of the ".marks" file, e.g. "written: 123456, now_ns: 987654321".
# Used with fullmatch(), so the trailing "\n" requires every line — including
# the last — to end with a newline.
_MARKS_LINE_REGEXP = re.compile(r"^written:\s+(?P<written>\d+),\s+now_ns:\s+(?P<now_ns>\d+)\n$")

@dataclass
class TimedPoint:
    """A timed point of the input sound data.

    written_samples: How many samples had already passed (and were written to the file)
    when the one sample this TimedPoint is concerned about arrived.
    t_ns: Nanoseconds since some arbitrary fixed point in time early in the sampling
    when the one sample this TimedPoint is concerned about arrived.
    """
    # Count of samples already written when this point's sample arrived.
    written_samples: int
    # Arrival time in ns, relative to an arbitrary early reference point
    # (the first mark read; see GrabSound.__init__).
    t_ns: int

@dataclass
class TimedPointExtended(TimedPoint):
    """A timed point that additionally carries a backprojected start of sampling.

    backprojected_sample_start: Nanoseconds since the arbitrary fixed point in
    time early in the sampling when sampling really started, that is, the point
    in time when 0 samples would have been received.

    Through inaccuracies, this is expected to move over time, so we calculate it
    anew for each TimedPoint, using some smoothing.
    """
    # Smoothed backprojection of the sampling start (ns); see class docstring.
    backprojected_sample_start: int
    
class GrabSound:

    def __init__(self, file_basename: str, check_timestamp_quality: bool=True) -> None:
        self.timed_points: list[TimedPoint] = []
        first_time = None
        samples = []
        times = []
        times_samples_avg = []
        with open(f"{file_basename}.marks", "r") as m:
            
            for line in m:
                if mo := _MARKS_LINE_REGEXP.fullmatch(line):
                    written_bytes = int(mo.group("written"))
                    written_samples = written_bytes // 2
                    now_ns = int(mo.group("now_ns"))
                    if first_time is None:
                        first_time = now_ns
                        self.first_time = first_time
                    now_ns -= first_time
                    self.timed_points.append(TimedPoint(written_samples=written_samples, t_ns=now_ns))
                else:
                    raise RuntimeError(f'Cannot parse: "{line}"')

        # Samples per second calculation:
        total_seconds = ((self.timed_points[-1].t_ns - self.timed_points[0].t_ns) / 1e9)
        # Debateable: Do the samples from our first batch of samples count, or not? I say: No.
        total_samples = self.timed_points[-1].written_samples - self.timed_points[0].written_samples
        samples_per_second = total_samples / total_seconds
        # Calculate the nominal value. These values are always multiples of 100.
        self.samples_per_second = round(samples_per_second / 100) * 100
        self.samples_per_second_measured = samples_per_second
        _LOGGER.debug("Time values read for most samples (%d total) at %.4f ≈ %d samples per second",
                      self.timed_points[-1].written_samples, self.samples_per_second_measured, self.samples_per_second)

        # Timestamp of start of sampling, calculated back from what we measured and the nominal sample per seconds value.
        raw_start_of_sampling = list(map(
            lambda tp: tp.t_ns - (1000000000 * tp.written_samples + self.samples_per_second // 2) // self.samples_per_second,
            self.timed_points
        ))
        assert(len(raw_start_of_sampling) == len(self.timed_points))

        # List of TimedPointExtended:
        self.timed_points_extended: list[TimedPointExtended] = []

        # Smooth these somewhat:

        half_window_size = 50
        window_size = 2 * half_window_size
        outlier_range = 2000000 # 2 ms
        for i in range(half_window_size, len(self.timed_points) - half_window_size):
            raw_average = sum(raw_start_of_sampling[i-half_window_size:i+half_window_size]) // window_size

            values_used_count = 0
            values_summed = 0
            for j in range(-half_window_size, half_window_size):
                value = raw_start_of_sampling[i+j]
                if raw_average - outlier_range < value < raw_average + outlier_range:
                    values_used_count += 1
                    values_summed += value

            if (half_window_size < values_used_count) or \
               ((not check_timestamp_quality) and half_window_size < values_used_count * 2):
                start_time = (values_summed + values_used_count // 2) // values_used_count
                
                written_samples = self.timed_points[i].written_samples
                smoothed_time = start_time + \
                    (written_samples * 1000000000 + self.samples_per_second // 2) // self.samples_per_second

                # sanity: We don't want stepping back in time:
                if len(self.timed_points_extended) == 0 or self.timed_points_extended[-1].t_ns < smoothed_time:
                    self.timed_points_extended.append(
                        TimedPointExtended(written_samples=written_samples,
                                           t_ns=smoothed_time,
                                           backprojected_sample_start = start_time)
                    )
                else:
                    _LOGGER.warning("Smoothed sample time at sample index %d went back in time from %f s to %f s",
                                  written_samples, self.timed_point_extended[-1].t_ns * 1e-9, smoothed_time * 1e-9)
            else:
                _LOGGER.debug("Didn't use index %d as only %d values survived smoothing", i, values_used_count)
                pass

        if (len(self.timed_points) - window_size) * 0.98 < len(self.timed_points_extended):
            pass
        else:
            mess = f"Of {len(self.timed_points)} timed points, " \
                f"only {len(self.timed_points_extended)} smootned values resulted. Bad data quality detected."
            if check_timestamp_quality:
                _LOGGER.error(mess)
                raise RuntimeError(mess)
            else:
                _LOGGER.warning(mess)

        with open(f"{file_basename}.raw", "rb", buffering=0) as f:
            num_of_samples = os.stat(f.fileno()).st_size // 2
            self.samples = array.array("h")
            try:
                self.samples.fromfile(f, num_of_samples)
            except:
                _LOGGER.error("Expected %d samples, could read only %d", num_of_samples, len(self.samples))
                raise
            assert(num_of_samples == len(self.samples))
            self.num_of_samples = len(self.samples)

        _LOGGER.debug("Read %d samples, spanning about %.2f minutes",
                      self.num_of_samples, self.num_of_samples / (self.samples_per_second * 60)
        )

    def start_time_raw(self, sample_index: int) -> TimedPointExtended:
        if 0 <= sample_index < self.num_of_samples:
            i = bisect.bisect_left(self.timed_points_extended, sample_index, key=lambda tpe: tpe.written_samples)
            if i < len(self.timed_points_extended):
                return self.timed_points_extended[i]
            else:
                return self.timed_points_extended[-1]
        else:
            raise ValueError(f"Illegal sample index {sample_index}, only 0 to {self.num_of_samples-1} allowed")

    def time_raw(self, sample_index: int) -> int:
        """
        Given a sample index, return the (approximate) ns
        since the basepoint near sample start when that sample happened.
        """
        near_start_ns = self.start_time_raw(sample_index).backprojected_sample_start
        return near_start_ns + (sample_index * 1000000000 + self.samples_per_second // 2) // self.samples_per_second

    def time_ns(self, sample_index: int) -> int:
        """
        Given a sample index, return the (approximate) ns
        since the epoch when that sample happened.
        """
        return self.time_raw(sample_index) + self.first_time

    def datetime(self, sample_index: int) -> datetime:
        """
        Given the sample index, return the (approximate) UTC timestampe when that sample happened.
        """
        return datetime.fromtimestamp(1e-9 * self.time_ns(sample_index), UTC)

    def start_time(self) -> datetime:
        """Return the UTC timestamp just before the first sample."""
        return self.datetime(0)

    def end_time(self) -> datetime:
        """Return the UTC timestamp just when the last sample arrived."""
        return self.datetime(self.num_of_samples - 1)

    def sample_index_from_ns(self, t: Union[float,int]) -> int:
        """Given a timestamp in ns after the epoch that corresponds to a time when we have been sampling,
        return the sample index of the sample that came in closest in time to that timestamp.
        """
        if self.time_ns(0) <= t <= self.time_ns(self.num_of_samples - 1):
            t_intern = round(t - self.first_time)
            i = bisect.bisect_left(self.timed_points_extended, t_intern, key=lambda p: p.t_ns)
            if i == len(self.timed_points_extended):
                i -= 1
            near_tpe = self.timed_points_extended[i]
            index = round((t_intern-near_tpe.backprojected_sample_start)*1e-9*self.samples_per_second)
            if 0 <= index < self.num_of_samples:
                return index
            else:
                raise ValueError(f"Resulting index should: 0 <= {index} < {self.num_of_samples}")
        else:
            raise ValueError(f"Need {self.time_ns(0)} <= {t} (i.e., your value) "
                             f"<= {self.time_ns(self.num_of_samples - 1)}")

    def sample_index(self, t: datetime) -> int:
        """Given a timestamp that corresponds to a time when we have been sampling,
        return the sample index of the sample that came in closest in time to that timestamp.
        """
        return self.sample_index_from_ns(t.timestamp() * 1e9)

    def export_as_wav(self, filename:str, t_from: datetime, t_to: datetime) -> None:
        """Export some part of the samples in WAV file format."""
        import scipy.io
        rate = round(1e9/self.slope)
        i_from = samples_index(t_from)
        i_to = sample_index(t_to)
        scipy.io.wavfile.write(filename, rate, self.samples[i_from:i_to])
        

if __name__ == "__main__":
    # Demo / smoke test: load one capture and print its timing summary.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s"
    )
    gs1 = GrabSound("RWM/2026-03-02_181109")
    print(f"gs1.samples length: {len(gs1.samples)}")
    print(f"Start:      {gs1.start_time()}")
    print(f"End:        {gs1.end_time()}")
    print(f"Nominal speed: {gs1.samples_per_second}")
    print(f"True speed:    {gs1.samples_per_second_measured:.3f}")
    # Fix: the nominal duration used a hard-coded 48000 instead of the
    # nominal rate actually detected from the marks file.
    print(f"Duration 1: {len(gs1.samples)/gs1.samples_per_second*1e9} ns (nominal)")
    print(f"Duration 2: {(gs1.end_time() - gs1.start_time()).total_seconds() * 1e9} ns (via timediff)")
    print(f"Duration 3: {gs1.time_raw(len(gs1.samples)-1)-gs1.time_raw(0)} ns (via ns)")
    print(f"Duration 4: {gs1.time_ns(len(gs1.samples)-1)-gs1.time_ns(0)} ns (via ns)")