import time
import statistics
import math
import u6

# ================= CONFIG =================

# Analog input channel numbers exercised by the DAC and AD584 modes.
AIN_CHANNELS = [0, 1, 2, 3]

# Default length of the slow acquisition window in slow_read(), in seconds.
SLOW_DURATION = 60.0
# Number of back-to-back readings taken by fast_read().
FAST_SAMPLES = 200
# Pause between consecutive readings in slow_read(), in seconds.
INTERVAL = 0.02

# Largest population standard deviation of a 25-sample burst for which
# ad584_mode() treats the input as stable (same unit as the readings).
STABLE_STD_THRESHOLD = 0.0025
# Extra wait after stability is detected, before acquisition starts, in seconds.
SETTLE_TIME = 1.0

# Span used by estimate_bits() as the numerator of log2(FULL_SCALE / stdev).
FULL_SCALE = 20.0  # assumed ±10V range (i.e. a 20 V span) — confirm against the configured input range

# =========================================

# Device handle shared by every function below. It is created at import time,
# so importing this module requires the u6 driver to construct a U6 object
# (presumably this opens the attached device — confirm against the u6 docs).
dev = u6.U6()


# ---------- HARDWARE INIT -----------------

def init_high_resolution():
    """Best-effort request for high-resolution conversions on the device.

    Calls ``dev.configU6(resolution=8)``. The call stays best-effort: a
    failure never aborts the diagnostic, but it is reported on the console
    instead of being silently discarded, so the operator can tell whether the
    setting was actually applied.

    Returns:
        None.
    """
    # NOTE(review): confirm that the installed u6 driver's configU6() accepts
    # a ``resolution`` keyword; if it does not, this call always fails and the
    # per-read ``resolutionIndex`` passed by read_ain() is what takes effect.
    try:
        dev.configU6(resolution=8)
    except Exception as exc:
        print(f"Warning: high-resolution config not applied ({exc!r}); continuing")


def get_serial():
    """Return the device's "SerialNumber" entry, or "unknown".

    The value is looked up in the mapping returned by ``dev.configU6()``.
    "unknown" is returned both when the key is absent and when the query
    raises any ``Exception``.
    """
    fallback = "unknown"
    try:
        return dev.configU6().get("SerialNumber", fallback)
    except Exception:
        # Identification is cosmetic; never let it stop the diagnostic.
        return fallback


def read_ain(ch):
    """Take one reading from analog input ``ch``.

    Delegates to ``dev.getAIN`` with ``resolutionIndex=8`` and returns its
    result unchanged (presumably a voltage in volts — confirm against the
    u6 driver documentation).
    """
    reading = dev.getAIN(ch, resolutionIndex=8)
    return reading


# ---------- STATS ------------------------

def stats(vals):
    """Summarise a batch of readings.

    Args:
        vals: Iterable of numeric samples. Any iterable is accepted; it is
            materialised once so it can be traversed several times.

    Returns:
        dict with the keys ``"mean"``, ``"stdev"`` (population standard
        deviation, 0.0 for a single sample), ``"min"``, ``"max"`` and
        ``"p2p"`` (max minus min).

    Raises:
        ValueError: if ``vals`` contains no samples.
    """
    data = list(vals)
    if not data:
        raise ValueError("stats() requires at least one sample")

    # Compute the extremes once and reuse them for the peak-to-peak value.
    lo = min(data)
    hi = max(data)

    return {
        # fsum avoids the rounding error a naive running sum accumulates,
        # which matters when judging noise at the microvolt level.
        "mean": math.fsum(data) / len(data),
        "stdev": statistics.pstdev(data) if len(data) > 1 else 0.0,
        "min": lo,
        "max": hi,
        "p2p": hi - lo,
    }


def estimate_bits(stdev):
    """Estimate resolution in bits as log2(FULL_SCALE / stdev).

    A zero or negative ``stdev`` yields ``float("inf")``, since no noise
    was observed to bound the resolution.
    """
    noiseless = stdev <= 0
    return float("inf") if noiseless else math.log(FULL_SCALE / stdev, 2)


def print_stats(label, s, n=None):
    """Print a labelled summary of one statistics dict.

    Args:
        label: Heading printed on the first line.
        s: Mapping with the keys "mean", "stdev", "min", "max" and "p2p".
        n: Optional sample count; printed as a final line when given.
    """
    # Derive the bit estimate first so a failure there prints nothing.
    bits = estimate_bits(s["stdev"])

    print(f"{label}")

    # (caption padded to five characters, key into ``s``)
    rows = (
        ("mean ", "mean"),
        ("stdev", "stdev"),
        ("min  ", "min"),
        ("max  ", "max"),
        ("p2p  ", "p2p"),
    )
    for caption, key in rows:
        print(f"  {caption}: {s[key]:.6f} V")

    print(f"  ~bits: {bits:.2f}")

    if n is None:
        return
    print(f"  samples: {n}")


def print_comparison(target, measured):
    """Print the reference value, the measured value and their signed difference.

    Args:
        target: Reference value.
        measured: Value obtained from the device.
    """
    error = measured - target
    report = (
        "\n--- RESULT SUMMARY ---",
        f"Reference : {target:.6f} V",
        f"Measured  : {measured:.6f} V",
        f"Error     : {error:+.6f} V",
        "----------------------\n",
    )
    for line in report:
        print(line)


# ---------- ACQUISITION ------------------

def fast_read(ch):
    """Collect FAST_SAMPLES consecutive readings from channel ``ch``.

    Readings are taken back to back with no delay between them and are
    returned as a list in acquisition order.
    """
    samples = []
    for _ in range(FAST_SAMPLES):
        samples.append(read_ain(ch))
    return samples


def slow_read(ch, duration=SLOW_DURATION, tick=2.0):
    """Sample channel ``ch`` repeatedly for ``duration`` seconds.

    A reading is taken, then the loop sleeps for ``INTERVAL`` seconds, until
    the window has elapsed. A ``#`` is appended to a "Progress: " line roughly
    every ``tick`` seconds.

    Elapsed time is measured with ``time.monotonic()`` so that adjustments to
    the system clock (NTP, manual changes) cannot shorten or stretch the
    acquisition window.

    Args:
        ch: Analog input channel passed to ``read_ain``.
        duration: Length of the acquisition window in seconds.
        tick: Seconds between progress marks.

    Returns:
        list of readings in acquisition order; empty when ``duration`` is
        zero or negative.
    """
    out = []
    start = time.monotonic()
    t_end = start + duration
    next_tick = start + tick

    print("Progress: ", end="", flush=True)

    while time.monotonic() < t_end:
        out.append(read_ain(ch))

        # At most one mark per sample; flushed so it appears immediately.
        if time.monotonic() >= next_tick:
            print("#", end="", flush=True)
            next_tick += tick

        time.sleep(INTERVAL)

    # Terminate the progress line.
    print()
    return out


# ---------- AD584 SETUP ------------------

def get_ad584_steps():
    """Ask the operator for the four AD584 reference values.

    The built-in defaults are listed, then one line is read from the console.
    An empty line selects the defaults. Otherwise the line must be a
    comma-separated list of exactly four finite numbers; anything else prints
    "Invalid input → using defaults" and falls back to the defaults.

    Returns:
        list of four ``(label, value)`` tuples, labelled "2.5V", "5.0V",
        "7.5V" and "10.0V" in that order.
    """
    DEFAULTS = [2.49887, 5.00152, 7.49990, 10.00217]
    LABELS = ("2.5V", "5.0V", "7.5V", "10.0V")

    print("\nAD584 reference steps:")
    # Print from DEFAULTS so the listing can never drift from the values used.
    for default in DEFAULTS:
        print(f"  {default:.5f} V")

    raw = input("\nENTER = defaults, or comma list: ").strip()

    vals = DEFAULTS
    if raw:
        try:
            parsed = [float(x) for x in raw.split(",")]
        except ValueError:
            parsed = None

        # nan/inf parse as floats but can never be matched by a reading, so
        # they are rejected here rather than stalling the stability wait.
        usable = (
            parsed is not None
            and len(parsed) == len(LABELS)
            and all(math.isfinite(v) for v in parsed)
        )
        if usable:
            vals = parsed
        else:
            print("Invalid input → using defaults")

    return list(zip(LABELS, vals))


# ---------- DAC MODE ---------------------

def dac_mode(dac):
    """Sweep one DAC output and tabulate what every analog input reads.

    The output is stepped from 0.0 to 5.0 in 0.5 increments. After each
    write the function waits 50 ms, reads every channel in AIN_CHANNELS and
    prints one table row with the readings and their max-minus-min spread.

    Args:
        dac: DAC selector; 0 uses register 5000, any other value uses 5002
            (presumably the Modbus addresses of DAC0/DAC1 — confirm against
            the LabJack documentation).
    """
    print(f"\n=== DAC{dac} MODE ===\n")

    if dac == 0:
        reg = 5000
    else:
        reg = 5002

    # Table header (column titles are fixed for four inputs).
    print("Set V |   AIN0     AIN1     AIN2     AIN3   | Spread")
    print("------|--------------------------------------|--------")

    for step in range(11):
        v = step * 0.5
        dev.writeRegister(reg, v)
        # Give the output a moment before sampling the inputs.
        time.sleep(0.05)

        readings = []
        for ch in AIN_CHANNELS:
            readings.append(read_ain(ch))
        spread = max(readings) - min(readings)

        cells = " ".join(f"{x:8.5f}" for x in readings)
        print(f"{v:5.1f} | {cells} | {spread:7.5f}")


# ---------- AD584 MODE -------------------

def _wait_for_stable_signal(ch, target):
    """Block until channel ``ch`` reads a steady value close to ``target``.

    Repeatedly takes a 25-sample burst and returns once its mean is within
    0.8 of ``target`` and its population standard deviation is below
    STABLE_STD_THRESHOLD. Bursts are retried every 0.2 s with no timeout.
    """
    while True:
        burst = [read_ain(ch) for _ in range(25)]
        mean = sum(burst) / len(burst)
        spread = statistics.pstdev(burst)

        if abs(mean - target) < 0.8 and spread < STABLE_STD_THRESHOLD:
            return

        time.sleep(0.2)


def ad584_mode(ch, steps):
    """Verify one analog input against a sequence of reference values.

    For each ``(label, target)`` step the operator is prompted to set the
    reference. The function waits for a stable signal near ``target``, pauses
    for SETTLE_TIME seconds, then acquires a fast burst followed by a slow
    window. It prints statistics for both and compares the fast-window mean
    with the target. A summary table of all steps is printed at the end.

    Args:
        ch: Analog input channel under test.
        steps: Iterable of ``(label, target)`` pairs.
    """
    print(f"\n=== AD584 MODE (AIN {ch}) ===")
    print("High-resolution verification mode\n")

    results = []

    for label, target in steps:

        print(f"\n→ Set {label}")

        _wait_for_stable_signal(ch, target)

        print("Signal stable → acquiring data")
        time.sleep(SETTLE_TIME)

        fast = fast_read(ch)
        slow = slow_read(ch)

        fs = stats(fast)
        ss = stats(slow)

        print("\nFAST WINDOW")
        print_stats("FAST", fs, len(fast))

        # Derive the label from the configured duration so it stays accurate
        # if SLOW_DURATION is changed.
        print(f"\nSLOW WINDOW (~{SLOW_DURATION:g}s)")
        print_stats("SLOW", ss, len(slow))

        # The fast-window mean is the figure compared against the reference.
        print_comparison(target, fs["mean"])

        results.append((target, fs["mean"]))

    print("\n=== FINAL SUMMARY ===")
    print("Vset | Measured | Error")

    for v, m in results:
        print(f"{v:4.1f} | {m:.6f} | {m - v:+.6f}")


# ---------- MAIN -------------------------

def main():
    """Interactive entry point for the diagnostic.

    Prints a banner with the device serial number, requests high-resolution
    mode, then asks for a mode: "0" or "1" runs the DAC sweep on that DAC,
    "E" runs the AD584 verification on each channel in AIN_CHANNELS (the
    operator may skip a channel with "S" or stop with "SA"). Any other
    answer prints "Invalid mode".
    """
    print(f"LabJack U6 HIGH-RES DIAGNOSTIC v2  |  SN: {get_serial()}\n")

    init_high_resolution()

    choice = input("mode 0/1/E: ").strip().upper()

    if choice in ("0", "1"):
        dac_mode(int(choice))
        return

    if choice != "E":
        print("Invalid mode")
        return

    steps = get_ad584_steps()

    for ch in AIN_CHANNELS:
        reply = input(f"\nAIN {ch} [ENTER / S skip / SA stop]: ").strip().upper()

        if reply == "SA":
            # Stop: abandon this and all remaining channels.
            break
        if reply != "S":
            ad584_mode(ch, steps)


# Start the interactive diagnostic only when run as a script, not on import.
if __name__ == "__main__":
    main()
