From 1ab88bb69866d35efe1f1efb5b3d27fc2630217f Mon Sep 17 00:00:00 2001 From: David Lenfesty Date: Sat, 27 May 2023 10:36:20 -0600 Subject: [PATCH] gw: more peak detector unit tests --- gateware/litex_main.py | 6 + gateware/sampler/peak_detector.py | 179 +++++++++++++++++++++++++++++- 2 files changed, 183 insertions(+), 2 deletions(-) diff --git a/gateware/litex_main.py b/gateware/litex_main.py index e39b5bb..25379b2 100755 --- a/gateware/litex_main.py +++ b/gateware/litex_main.py @@ -184,6 +184,12 @@ def main(): results.append(run_test("CircularBuffer", circular_buffer.testbench)) results.append(run_test("SamplerController", controller.test_bus_access)) results.append(run_test("PeakDetector", peak_detector.test_simple_waveform)) + results.append(run_test("PeakDetector", peak_detector.test_scrunched_simple_waveform)) + results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform)) + results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform_too_much)) + results.append(run_test("PeakDetector", peak_detector.test_decay_compensates_bias)) + results.append(run_test("PeakDetector", peak_detector.test_biased_simple_waveform)) + results.append(run_test("PeakDetector", peak_detector.test_noise_spike)) passed = sum((1 for result in results if result.result == TestResult.PASS)) failed = sum((1 for result in results if result.result == TestResult.FAIL)) diff --git a/gateware/sampler/peak_detector.py b/gateware/sampler/peak_detector.py index 72241f5..cbbf1e8 100644 --- a/gateware/sampler/peak_detector.py +++ b/gateware/sampler/peak_detector.py @@ -106,7 +106,7 @@ import numpy as np import matplotlib.pyplot as plt -def create_waveform(dc_bias: int = 0, scale: float = 1) -> Tuple[np.ndarray[float], np.ndarray[int]]: +def create_waveform(*, dc_bias: int = 0, scale: float = 1) -> Tuple[np.ndarray[float], np.ndarray[int]]: """ Create a simple 40kHz sine wave in integer values that can be used by peak detector """ @@ -149,7 +149,13 @@ def set_settings(dut: PeakDetector, thresh_value: int, thresh_time: int, decay_v yield +# NOTE: These tests aren't really exhaustive. They get good coverage and generally outline results, +# but do require some amount of manual validation and checking of waveforms if things change +# majorly. At least they are a small set of things that I don't have to re-create later. + + def test_simple_waveform(): + """Test an ideal waveform with simple settings guaranteed to trigger""" (_, signal) = create_waveform() dut = PeakDetector(10) @@ -174,4 +180,173 @@ def test_simple_waveform(): assert False, "No trigger, test has failed..." - run_simulation(dut, test_fn(), vcd_name="peak_detector.vcd") + run_simulation(dut, test_fn()) + + +def test_scrunched_simple_waveform(): + """Test a smaller waveform with smaller peaks""" + (_, signal) = create_waveform(scale=0.4) + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 200, 10, 0, 0) + + # Enable + (yield dut.enable.eq(1)) + yield + + # Load data in until we trigger + for i, val in enumerate(signal): + if (yield dut.triggered) == 1: + # Test passed! + return + + # Load in data, set valid + (yield dut.data.eq(int(val))) + (yield dut.data_valid.eq(1)) + yield # Tick + + assert False, "No trigger, test has failed..." + + run_simulation(dut, test_fn()) + + +def test_decay_simple_waveform(): + """Test that simple case of decay works """ + (_, signal) = create_waveform() + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 800, 10, 5, 10) + + # Enable + (yield dut.enable.eq(1)) + yield + + # Load data in until we trigger + for i, val in enumerate(signal): + if (yield dut.triggered) == 1: + # Test passed! + return + + # Load in data, set valid + (yield dut.data.eq(int(val))) + (yield dut.data_valid.eq(1)) + yield # Tick + + assert False, "No trigger, test has failed..." + + run_simulation(dut, test_fn()) + + +def test_decay_simple_waveform_too_much(): + """Test that we can overuse decay and discard valid waveforms""" + (_, signal) = create_waveform() + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 800, 10, 40, 0) + + # Enable + (yield dut.enable.eq(1)) + yield + + # Load data in, ensuring we don't trigger + for i, val in enumerate(signal): + assert (yield dut.triggered) == 0, "Must not trigger!" + + # Load in data, set valid + (yield dut.data.eq(int(val))) + (yield dut.data_valid.eq(1)) + yield # Tick + + run_simulation(dut, test_fn()) + + +def test_decay_compensates_bias(): + signal = [800] * int(20e3) + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 800 - 512, 20, 1, 0) + + # Enable + (yield dut.enable.eq(1)) + yield + + for i, val in enumerate(signal): + assert (yield dut.triggered) == 0, "Must not trigger!" + assert i < 500, "Decay must not take too long to work!" + + # min val won't necessarily match max_val because of the crossover, but as long + # as it's close enough, we're good + if (yield dut._max_val) == 800 and (yield dut._min_val) >= 798: + # Test pass + return + + # Load data in + (yield dut.data.eq(val)) + (yield dut.data_valid.eq(1)) + yield # Tick + + run_simulation(dut, test_fn()) + + +def test_biased_simple_waveform(): + """Test a biased waveform. This does also need to test decay, else we get invalid results + + TODO this test is slightly broken. I think it's passing on the initial bias, not later stuff. + It should get fixed, but I think it mostly covers area already covered in other tests so I'm + fine with this for now. + """ + (_, signal) = create_waveform(dc_bias=200) + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 200, 20, 20, 1) + + # Enable + (yield dut.enable.eq(1)) + yield + + # Load data in until we trigger + for i, val in enumerate(signal): + if (yield dut.triggered) == 1: + # Test passed! + return + + # Load in data, set valid + (yield dut.data.eq(int(val))) + (yield dut.data_valid.eq(1)) + yield # Tick + + assert False, "No trigger, test has failed..." + + run_simulation(dut, test_fn()) + + +def test_noise_spike(): + """Test that appropriate filtering and decay can filter out a spike in noise""" + signal = [512, 512, 512, 1024, 1024] + [512] * 1000 + dut = PeakDetector(10) + + def test_fn(): + yield from set_settings(dut, 300, 20, 20, 0) + + # Enable + (yield dut.enable.eq(1)) + yield # Tick + + # Load data in until we trigger + for val in signal: + assert (yield dut.triggered) == 0, "Can't trigger!" + + # Load in data, set valid + (yield dut.data.eq(int(val))) + (yield dut.data_valid.eq(1)) + yield # Tick + + # Test success + return + + run_simulation(dut, test_fn())