Compare commits

...

7 Commits

Author SHA1 Message Date
0e3328aac1 gw: add some end to end unit tests for sampler controller 2023-05-27 12:08:34 -06:00
e624d82742 gw: minor fixes for correct controller operation 2023-05-27 12:08:12 -06:00
627550840c gw: add skipping test suites 2023-05-27 11:03:18 -06:00
942dba8ea3 gw: fix CircularBuffer test
with_wb = True overrides the rd_addr, which broke the test. Just had to
disable that param and the test was correct again
2023-05-27 11:02:31 -06:00
1ab88bb698 gw: more peak detector unit tests 2023-05-27 10:36:48 -06:00
0185d81d46 gw: minor tweaks to peak detector to improve behaviour 2023-05-27 10:36:46 -06:00
1d8c9ca224 gw: Start testing peak detector, and fix a bug! 2023-05-27 10:36:42 -06:00
5 changed files with 459 additions and 52 deletions

View File

@ -30,7 +30,7 @@ from liteeth.phy.ecp5rgmii import LiteEthPHYRGMII
from sampler import Sampler
from litex.soc.integration.soc import SoCRegion
from test import run_test, TestResult
from test import run_test, TestResult, skip_suite
# CRG ----------------------------------------------------------------------------------------------
@ -178,10 +178,20 @@ def main():
if args.test:
from sampler import circular_buffer
from sampler import controller
from sampler import peak_detector
results = []
results.append(run_test("CircularBuffer", circular_buffer.testbench))
results.append(run_test("SamplerController", controller.test_bus_access))
results.append(run_test("SamplerController", controller.test_simple_waveform))
results.append(run_test("SamplerController", controller.test_simple_waveform_capture_offset))
results.append(run_test("PeakDetector", peak_detector.test_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_scrunched_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform_too_much))
results.append(run_test("PeakDetector", peak_detector.test_decay_compensates_bias))
results.append(run_test("PeakDetector", peak_detector.test_biased_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_noise_spike))
passed = sum((1 for result in results if result.result == TestResult.PASS))
failed = sum((1 for result in results if result.result == TestResult.FAIL))

View File

@ -21,7 +21,7 @@ class CircularBuffer(Module):
ptr_width = ceil(log2(depth))
# External Signals
self.len = Signal(ptr_width) # Amount of valid data in the buffer
self.len = Signal(ptr_width + 1) # Amount of valid data in the buffer
self.clear = Signal() # Strobe to clear memory
self.rd_addr = Signal(ptr_width)
self.rd_data = Signal(width)
@ -111,7 +111,7 @@ class CircularBuffer(Module):
def testbench():
dut = CircularBuffer(9, 24)
dut = CircularBuffer(9, 24, with_wb=False)
def test_fn():
assert (yield dut.len) == 0
assert (yield dut.wr_ready) == 1

View File

@ -31,7 +31,7 @@ class SamplerController(Module):
Bit 0 - Begin capture. Resets all FIFOs and starts the peak detector
0x01: Status Register (RO)
Bit 0 - Capture complete. Set by peak detection block and cleared by software or when
Bit 0 - Capture complete. Set by peak detection block and cleared when capture is began
0x02: trigger_run_len (RW)
Number of samples to acquire after triggering sample.
@ -67,6 +67,9 @@ class SamplerController(Module):
# Connect each buffer to each sampler
for buffer, sampler in zip(self.buffers, self.samplers):
self.submodules += buffer
self.submodules += sampler
self.comb += [
# Connect only top 9 bits to memory
buffer.wr_data.eq(sampler.data[1:]),
@ -76,8 +79,9 @@ class SamplerController(Module):
# Each sampler gets some chunk of memory at least large enough to fit
# all of it's data, so use that as a consistent offset
sample_mem_addr_width = ceil(log2(buffer_len))
# all of it's data, so use that as a consistent offset. Use a minimum
# address of 0x800 to avoid conflicts with control registers
sample_mem_addr_width = max(ceil(log2(buffer_len)), ceil(log2(0x800)))
# 1 control block + number of channels used = control bits
control_block_addr_width = ceil(log2(num_channels + 1))
@ -100,11 +104,9 @@ class SamplerController(Module):
adr = (i + 1) << sample_mem_addr_width
print(f"Sampler {i} available at 0x{adr:08x}")
self.decoder = Decoder(self.bus, slaves)
# TODO how to submodule
self.submodules.decoder = self.decoder
self.submodules.decoder = Decoder(self.bus, slaves)
self.peak_detector = PeakDetector(10)
self.submodules.peak_detector = PeakDetector(10)
self.comb += [
# Simply enable whenever we start capturing
self.peak_detector.enable.eq(sample_enable),
@ -232,6 +234,8 @@ def read_wishbone(bus, address,):
(yield bus.stb.eq(0))
(yield bus.cyc.eq(0))
yield # Tick
break
else:
# Tick until we receive an ACK
@ -247,24 +251,23 @@ class MockSampler(Module):
Index of data to use from provided data
"""
def __init__(self, data: List[int]):
memory = Memory(width=10, depth=len(data), init=data)
self.specials.memory = Memory(width=10, depth=len(data), init=data)
self.index = Signal(ceil(log2(len(data))))
self.data = Signal(10)
self.valid = Signal()
read_port = memory.get_port(async_read=True)
read_port = self.memory.get_port(async_read=True)
self.comb += [
read_port.adr.eq(self.index),
self.data.eq(read_port.dat_r),
]
class TestSoC(Module):
def __init__(self, data):
sampler = MockSampler(data)
self.submodules.sampler = sampler
def __init__(self, data: List[int], *, buffer_len: int = 1024, num_samplers: int = 1):
# TODO multiple mock samplers to test that functionality
self.controller = SamplerController([MockSampler(data)], 1024)
self.samplers = [MockSampler(data) for _ in range(num_samplers)]
self.controller = SamplerController(self.samplers, buffer_len)
self.submodules.controller = self.controller
self.bus = self.controller.bus
@ -279,3 +282,139 @@ def test_bus_access():
# TODO test writing to RO register fails
run_simulation(dut, test_fn(), vcd_name="test_bus_access.vcd")
def test_simple_waveform():
"""End-to-end test of a simple waveform"""
from .peak_detector import create_waveform
_, data = create_waveform()
data = [int(d) for d in data]
dut = TestSoC(data, buffer_len=32)
def test_fn():
# Set settings
yield from write_wishbone(dut.bus, 2, 0) # trigger_run_len = 0
yield from write_wishbone(dut.bus, 3, 800) # thresh_value = 800
yield from write_wishbone(dut.bus, 4, 10) # thresh_time = 10
yield from write_wishbone(dut.bus, 5, 1) # decay_value = 1
yield from write_wishbone(dut.bus, 5, 0) # decay_period = 0
# Start controller
yield from write_wishbone(dut.bus, 0, 1)
triggered_yet = False
triggered_num = 0
for i in range(1000):
(yield dut.samplers[0].index.eq(i))
(yield dut.samplers[0].valid.eq(1))
yield
(yield dut.samplers[0].valid.eq(0))
yield
# Total of 6 clocks per sample clock
yield
yield
yield
yield
if not triggered_yet and (yield dut.controller.peak_detector.triggered) == 1:
# Triggered, now we need to run some number of cycles
triggered_yet = True
if triggered_yet:
triggered_num += 1
if triggered_num > 32:
# We should now have collected all our samples
yield from read_wishbone(dut.bus, 1)
assert (yield dut.bus.dat_r) == 1, "Trigger did not propogate to WB!"
# Check that length is correct
yield from read_wishbone(dut.bus, 0x100)
len = (yield dut.bus.dat_r)
assert len == 32, f"Len ({len}) not correct!"
# Read data in
data = []
for i in range(32):
yield from read_wishbone(dut.bus, 0x800 + i)
sample = (yield dut.bus.dat_r)
data.append(sample)
# Test pass
return
assert False, "We should have triggered"
run_simulation(dut, test_fn())
def test_simple_waveform_capture_offset():
"""Test a simple waveform captured at an offset"""
from .peak_detector import create_waveform
_, data = create_waveform()
data = [int(d) for d in data]
dut = TestSoC(data, buffer_len=32)
def test_fn():
# Set settings
yield from write_wishbone(dut.bus, 2, 16) # trigger_run_len = 16
yield from write_wishbone(dut.bus, 3, 800) # thresh_value = 800
yield from write_wishbone(dut.bus, 4, 10) # thresh_time = 10
yield from write_wishbone(dut.bus, 5, 1) # decay_value = 1
yield from write_wishbone(dut.bus, 5, 0) # decay_period = 0
# Start controller
yield from write_wishbone(dut.bus, 0, 1)
triggered_yet = False
triggered_num = 0
for i in range(1000):
(yield dut.samplers[0].index.eq(i))
(yield dut.samplers[0].valid.eq(1))
yield
(yield dut.samplers[0].valid.eq(0))
yield
# Total of 6 clocks per sample clock
yield
yield
yield
yield
if not triggered_yet and (yield dut.controller.peak_detector.triggered) == 1:
# Triggered, now we need to run some number of cycles
triggered_yet = True
if triggered_yet:
triggered_num += 1
if triggered_num > 16:
# We should now have collected all our samples
yield from read_wishbone(dut.bus, 1)
assert (yield dut.bus.dat_r) == 1, "Trigger did not propogate to WB!"
# Check that length is correct
yield from read_wishbone(dut.bus, 0x100)
len = (yield dut.bus.dat_r)
assert len == 32, f"Len ({len}) not correct!"
# Read data in
data = []
for i in range(32):
yield from read_wishbone(dut.bus, 0x800 + i)
sample = (yield dut.bus.dat_r)
data.append(sample)
# Manually validated from test above to be offset into the
# data
assert data[0] == 138
assert data[1] == 132
# Test pass
return
assert False, "We should have triggered"
run_simulation(dut, test_fn())

View File

@ -36,12 +36,12 @@ class PeakDetector(Module):
"""
def __init__(self, data_width: int):
# Create all state signals
min_val = Signal(data_width)
max_val = Signal(data_width)
diff = Signal(data_width)
triggered_time = Signal(32)
decay_counter = Signal(32)
# Create all state signals (underscored in self to be accessible in tests)
self._min_val = Signal(data_width)
self._max_val = Signal(data_width)
self._diff = Signal(data_width)
self._triggered_time = Signal(32)
self._decay_counter = Signal(32)
# Control signals
self.data = Signal(data_width)
@ -57,42 +57,296 @@ class PeakDetector(Module):
self.sync += If(~self.enable,
# Reset halfway. ADCs are 0-2V, and everything should be centered at 1V, so this is approximating the initial value
min_val.eq(int(2**data_width /2)),
max_val.eq(int(2**data_width /2)),
self._min_val.eq(int(2**data_width /2)),
self._max_val.eq(int(2**data_width /2)),
self.triggered.eq(0),
decay_counter.eq(0),
triggered_time.eq(0),
self._decay_counter.eq(0),
self._triggered_time.eq(0),
)
# Constantly updating diff to simplify some statements
self.comb += diff.eq(max_val - min_val)
# Constantly updating self._diff to simplify some statements
self.comb += self._diff.eq(self._max_val - self._min_val)
self.sync += If(self.enable & self.data_valid,
# Decay should run irrespective of if we have triggered,
# and before everything else so it can be overwritten
self._decay_counter.eq(self._decay_counter + 1),
# Decay threshold has been reached, apply decay to peaks
If(self._decay_counter >= self.decay_period,
self._decay_counter.eq(0),
# Only apply decay if the values would not overlap, and we use the decay
If((self._diff >= (self.decay_value << 1)) & (self.decay_value > 0),
self._max_val.eq(self._max_val - self.decay_value),
self._min_val.eq(self._min_val + self.decay_value))),
# Update maximum value
If(self.data > max_val, max_val.eq(self.data)),
If(self.data > self._max_val, self._max_val.eq(self.data)),
# Update minimum value
If(self.data < min_val, min_val.eq(self.data)),
If(diff > self.thresh_value,
If(self.data < self._min_val, self._min_val.eq(self.data)),
If(self._diff > self.thresh_value,
# We have met the threshold for triggering, start counting
triggered_time.eq(triggered_time + 1),
decay_counter.eq(0),
self._triggered_time.eq(self._triggered_time + 1),
# We have triggered, so we can set the output. After this point,
# nothing we do matters until enable is de-asserted and we reset
# triggered.
If(triggered_time + 1 >= self.thresh_time, self.triggered.eq(1)))
If(self._triggered_time + 1 >= self.thresh_time, self.triggered.eq(1)))
.Else(
# We have not met the threshold, reset timer and handle decay
triggered_time.eq(0),
decay_counter.eq(decay_counter + 1),
# We have not met the threshold, reset timer
self._triggered_time.eq(0),
),
# Decay threshold has been reached, apply decay to peaks
If(decay_counter >= self.decay_period,
decay_counter.eq(0),
)
# Only apply decay if the values would not overlap
If(diff >= (self.decay_value << 1),
max_val.eq(max_val - self.decay_value),
min_val.eq(min_val + self.decay_value)))
)
)
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
def create_waveform(*, dc_bias: int = 0, scale: float = 1) -> Tuple[np.ndarray[float], np.ndarray[int]]:
"""
Create a simple 40kHz sine wave in integer values that can be used by peak detector
"""
assert scale <= 1.0, "Scale factor must be some ratio of full range"
# Constants
f_s = 10e6 # Sample rate (Hz)
f = 40e3 # Signal Frequency (Hz)
t = 0.002 # Sample period (s)
n = int(f_s * 0.002) # Number of samples
# Create time from 0ms to 2ms
x = np.linspace(0, t, n)
# Create signal!
y = np.sin(x * 2*np.pi*f)
# Scale according to user inputs
y = y * scale
# Convert to positive integer at provided bias
signal = np.ndarray((len(y)), dtype=np.uint16)
# "unsafe" casting because numpy doesn't know we are staying under 10 bit values
np.copyto(signal, y * 512 + 512 + dc_bias, casting='unsafe')
#plt.plot(x, signal)
#plt.show()
return x, signal
def set_settings(dut: PeakDetector, thresh_value: int, thresh_time: int, decay_value: int, decay_period: int) -> None:
"""Set peak detector settings simply"""
(yield dut.thresh_value.eq(thresh_value))
(yield dut.thresh_time.eq(thresh_time))
(yield dut.decay_value.eq(decay_value))
(yield dut.decay_period.eq(decay_period))
# Load in values with a new clock
yield
# NOTE: These tests aren't really exhaustive. They get good coverage and generally outline results,
# but do require some amount of manual validation and checking of waveforms if things change
# majorly. At least they are a small set of things that I don't have to re-create later.
def test_simple_waveform():
"""Test an ideal waveform with simple settings guaranteed to trigger"""
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
# First set settings to be simple, we want to trigger pretty much immediately
yield from set_settings(dut, 800, 10, 0, 0)
# Enable device
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_scrunched_simple_waveform():
"""Test a smaller waveform with smaller peaks"""
(_, signal) = create_waveform(scale=0.4)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 200, 10, 0, 0)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_decay_simple_waveform():
"""Test that simple case of decay works """
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800, 10, 5, 10)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_decay_simple_waveform_too_much():
"""Test that we can overuse decay and discard valid waveforms"""
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800, 10, 40, 0)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in, ensuring we don't trigger
for i, val in enumerate(signal):
assert (yield dut.triggered) == 0, "Must not trigger!"
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
run_simulation(dut, test_fn())
def test_decay_compensates_bias():
signal = [800] * int(20e3)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800 - 512, 20, 1, 0)
# Enable
(yield dut.enable.eq(1))
yield
for i, val in enumerate(signal):
assert (yield dut.triggered) == 0, "Must not trigger!"
assert i < 500, "Decay must not take too long to work!"
# min val won't necessarily match max_val because of the crossover, but as long
# as it's close enough, we're good
if (yield dut._max_val) == 800 and (yield dut._min_val) >= 798:
# Test pass
return
# Load data in
(yield dut.data.eq(val))
(yield dut.data_valid.eq(1))
yield # Tick
run_simulation(dut, test_fn())
def test_biased_simple_waveform():
"""Test a biased waveform. This does also need to test decay, else we get invalid results
TODO this test is slightly broken. I think it's passing on the initial bias, not later stuff.
It should get fixed, but I think it mostly covers area already covered in other tests so I'm
fine with this for now.
"""
(_, signal) = create_waveform(dc_bias=200)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 200, 20, 20, 1)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_noise_spike():
"""Test that appropriate filtering and decay can filter out a spike in noise"""
signal = [512, 512, 512, 1024, 1024] + [512] * 1000
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 300, 20, 20, 0)
# Enable
(yield dut.enable.eq(1))
yield # Tick
# Load data in until we trigger
for val in signal:
assert (yield dut.triggered) == 0, "Can't trigger!"
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
# Test success
return
run_simulation(dut, test_fn())

View File

@ -2,7 +2,7 @@
Helper functions for a basic test suite
"""
from typing import Callable
from typing import Callable, List
from enum import StrEnum
from dataclasses import dataclass
from traceback import print_exc
@ -25,15 +25,21 @@ class TestInfo:
return f"[{self.suite_name}.{self.test_name}] {self.result}"
skipped_suites: List[str] = []
def skip_suite(suite_name: str):
"""Skips running tests from a specific suite"""
skipped_suites.append(suite_name)
def run_test(suite_name: str, test_fn: Callable, do_skip = False) -> TestResult:
test_name = test_fn.__name__
print(f"[{suite_name}.{test_name}] Running...")
if do_skip:
if do_skip or suite_name in skipped_suites:
res = TestInfo(suite_name, test_name, TestResult.SKIP)
else:
print(f"[{suite_name}.{test_name}] Running...")
try:
test_fn()
res = TestInfo(suite_name, test_name, TestResult.PASS)
@ -41,7 +47,5 @@ def run_test(suite_name: str, test_fn: Callable, do_skip = False) -> TestResult:
res = TestInfo(suite_name, test_name, TestResult.FAIL)
print_exc()
print(res)
return res