Compare commits

...

7 Commits

Author SHA1 Message Date
0e3328aac1 gw: add some end to end unit tests for sampler controller 2023-05-27 12:08:34 -06:00
e624d82742 gw: minor fixes for correct controller operation 2023-05-27 12:08:12 -06:00
627550840c gw: add skipping test suites 2023-05-27 11:03:18 -06:00
942dba8ea3 gw: fix CircularBuffer test
with_wb = True overrides the rd_addr, which broke the test. Just had to
disable that param and the test was correct again
2023-05-27 11:02:31 -06:00
1ab88bb698 gw: more peak detector unit tests 2023-05-27 10:36:48 -06:00
0185d81d46 gw: minor tweaks to peak detector to improve behaviour 2023-05-27 10:36:46 -06:00
1d8c9ca224 gw: Start testing peak detector, and fix a bug! 2023-05-27 10:36:42 -06:00
5 changed files with 459 additions and 52 deletions

View File

@ -30,7 +30,7 @@ from liteeth.phy.ecp5rgmii import LiteEthPHYRGMII
from sampler import Sampler from sampler import Sampler
from litex.soc.integration.soc import SoCRegion from litex.soc.integration.soc import SoCRegion
from test import run_test, TestResult from test import run_test, TestResult, skip_suite
# CRG ---------------------------------------------------------------------------------------------- # CRG ----------------------------------------------------------------------------------------------
@ -178,10 +178,20 @@ def main():
if args.test: if args.test:
from sampler import circular_buffer from sampler import circular_buffer
from sampler import controller from sampler import controller
from sampler import peak_detector
results = [] results = []
results.append(run_test("CircularBuffer", circular_buffer.testbench)) results.append(run_test("CircularBuffer", circular_buffer.testbench))
results.append(run_test("SamplerController", controller.test_bus_access)) results.append(run_test("SamplerController", controller.test_bus_access))
results.append(run_test("SamplerController", controller.test_simple_waveform))
results.append(run_test("SamplerController", controller.test_simple_waveform_capture_offset))
results.append(run_test("PeakDetector", peak_detector.test_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_scrunched_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_decay_simple_waveform_too_much))
results.append(run_test("PeakDetector", peak_detector.test_decay_compensates_bias))
results.append(run_test("PeakDetector", peak_detector.test_biased_simple_waveform))
results.append(run_test("PeakDetector", peak_detector.test_noise_spike))
passed = sum((1 for result in results if result.result == TestResult.PASS)) passed = sum((1 for result in results if result.result == TestResult.PASS))
failed = sum((1 for result in results if result.result == TestResult.FAIL)) failed = sum((1 for result in results if result.result == TestResult.FAIL))

View File

@ -21,7 +21,7 @@ class CircularBuffer(Module):
ptr_width = ceil(log2(depth)) ptr_width = ceil(log2(depth))
# External Signals # External Signals
self.len = Signal(ptr_width) # Amount of valid data in the buffer self.len = Signal(ptr_width + 1) # Amount of valid data in the buffer
self.clear = Signal() # Strobe to clear memory self.clear = Signal() # Strobe to clear memory
self.rd_addr = Signal(ptr_width) self.rd_addr = Signal(ptr_width)
self.rd_data = Signal(width) self.rd_data = Signal(width)
@ -111,7 +111,7 @@ class CircularBuffer(Module):
def testbench(): def testbench():
dut = CircularBuffer(9, 24) dut = CircularBuffer(9, 24, with_wb=False)
def test_fn(): def test_fn():
assert (yield dut.len) == 0 assert (yield dut.len) == 0
assert (yield dut.wr_ready) == 1 assert (yield dut.wr_ready) == 1

View File

@ -31,7 +31,7 @@ class SamplerController(Module):
Bit 0 - Begin capture. Resets all FIFOs and starts the peak detector Bit 0 - Begin capture. Resets all FIFOs and starts the peak detector
0x01: Status Register (RO) 0x01: Status Register (RO)
Bit 0 - Capture complete. Set by peak detection block and cleared by software or when Bit 0 - Capture complete. Set by peak detection block and cleared when capture is began
0x02: trigger_run_len (RW) 0x02: trigger_run_len (RW)
Number of samples to acquire after triggering sample. Number of samples to acquire after triggering sample.
@ -67,6 +67,9 @@ class SamplerController(Module):
# Connect each buffer to each sampler # Connect each buffer to each sampler
for buffer, sampler in zip(self.buffers, self.samplers): for buffer, sampler in zip(self.buffers, self.samplers):
self.submodules += buffer
self.submodules += sampler
self.comb += [ self.comb += [
# Connect only top 9 bits to memory # Connect only top 9 bits to memory
buffer.wr_data.eq(sampler.data[1:]), buffer.wr_data.eq(sampler.data[1:]),
@ -76,8 +79,9 @@ class SamplerController(Module):
# Each sampler gets some chunk of memory at least large enough to fit # Each sampler gets some chunk of memory at least large enough to fit
# all of it's data, so use that as a consistent offset # all of it's data, so use that as a consistent offset. Use a minimum
sample_mem_addr_width = ceil(log2(buffer_len)) # address of 0x800 to avoid conflicts with control registers
sample_mem_addr_width = max(ceil(log2(buffer_len)), ceil(log2(0x800)))
# 1 control block + number of channels used = control bits # 1 control block + number of channels used = control bits
control_block_addr_width = ceil(log2(num_channels + 1)) control_block_addr_width = ceil(log2(num_channels + 1))
@ -100,11 +104,9 @@ class SamplerController(Module):
adr = (i + 1) << sample_mem_addr_width adr = (i + 1) << sample_mem_addr_width
print(f"Sampler {i} available at 0x{adr:08x}") print(f"Sampler {i} available at 0x{adr:08x}")
self.decoder = Decoder(self.bus, slaves) self.submodules.decoder = Decoder(self.bus, slaves)
# TODO how to submodule
self.submodules.decoder = self.decoder
self.peak_detector = PeakDetector(10) self.submodules.peak_detector = PeakDetector(10)
self.comb += [ self.comb += [
# Simply enable whenever we start capturing # Simply enable whenever we start capturing
self.peak_detector.enable.eq(sample_enable), self.peak_detector.enable.eq(sample_enable),
@ -232,6 +234,8 @@ def read_wishbone(bus, address,):
(yield bus.stb.eq(0)) (yield bus.stb.eq(0))
(yield bus.cyc.eq(0)) (yield bus.cyc.eq(0))
yield # Tick
break break
else: else:
# Tick until we receive an ACK # Tick until we receive an ACK
@ -247,24 +251,23 @@ class MockSampler(Module):
Index of data to use from provided data Index of data to use from provided data
""" """
def __init__(self, data: List[int]): def __init__(self, data: List[int]):
memory = Memory(width=10, depth=len(data), init=data) self.specials.memory = Memory(width=10, depth=len(data), init=data)
self.index = Signal(ceil(log2(len(data)))) self.index = Signal(ceil(log2(len(data))))
self.data = Signal(10) self.data = Signal(10)
self.valid = Signal() self.valid = Signal()
read_port = memory.get_port(async_read=True) read_port = self.memory.get_port(async_read=True)
self.comb += [ self.comb += [
read_port.adr.eq(self.index), read_port.adr.eq(self.index),
self.data.eq(read_port.dat_r), self.data.eq(read_port.dat_r),
] ]
class TestSoC(Module): class TestSoC(Module):
def __init__(self, data): def __init__(self, data: List[int], *, buffer_len: int = 1024, num_samplers: int = 1):
sampler = MockSampler(data)
self.submodules.sampler = sampler
# TODO multiple mock samplers to test that functionality # TODO multiple mock samplers to test that functionality
self.controller = SamplerController([MockSampler(data)], 1024) self.samplers = [MockSampler(data) for _ in range(num_samplers)]
self.controller = SamplerController(self.samplers, buffer_len)
self.submodules.controller = self.controller self.submodules.controller = self.controller
self.bus = self.controller.bus self.bus = self.controller.bus
@ -279,3 +282,139 @@ def test_bus_access():
# TODO test writing to RO register fails # TODO test writing to RO register fails
run_simulation(dut, test_fn(), vcd_name="test_bus_access.vcd") run_simulation(dut, test_fn(), vcd_name="test_bus_access.vcd")
def test_simple_waveform():
"""End-to-end test of a simple waveform"""
from .peak_detector import create_waveform
_, data = create_waveform()
data = [int(d) for d in data]
dut = TestSoC(data, buffer_len=32)
def test_fn():
# Set settings
yield from write_wishbone(dut.bus, 2, 0) # trigger_run_len = 0
yield from write_wishbone(dut.bus, 3, 800) # thresh_value = 800
yield from write_wishbone(dut.bus, 4, 10) # thresh_time = 10
yield from write_wishbone(dut.bus, 5, 1) # decay_value = 1
yield from write_wishbone(dut.bus, 5, 0) # decay_period = 0
# Start controller
yield from write_wishbone(dut.bus, 0, 1)
triggered_yet = False
triggered_num = 0
for i in range(1000):
(yield dut.samplers[0].index.eq(i))
(yield dut.samplers[0].valid.eq(1))
yield
(yield dut.samplers[0].valid.eq(0))
yield
# Total of 6 clocks per sample clock
yield
yield
yield
yield
if not triggered_yet and (yield dut.controller.peak_detector.triggered) == 1:
# Triggered, now we need to run some number of cycles
triggered_yet = True
if triggered_yet:
triggered_num += 1
if triggered_num > 32:
# We should now have collected all our samples
yield from read_wishbone(dut.bus, 1)
assert (yield dut.bus.dat_r) == 1, "Trigger did not propogate to WB!"
# Check that length is correct
yield from read_wishbone(dut.bus, 0x100)
len = (yield dut.bus.dat_r)
assert len == 32, f"Len ({len}) not correct!"
# Read data in
data = []
for i in range(32):
yield from read_wishbone(dut.bus, 0x800 + i)
sample = (yield dut.bus.dat_r)
data.append(sample)
# Test pass
return
assert False, "We should have triggered"
run_simulation(dut, test_fn())
def test_simple_waveform_capture_offset():
"""Test a simple waveform captured at an offset"""
from .peak_detector import create_waveform
_, data = create_waveform()
data = [int(d) for d in data]
dut = TestSoC(data, buffer_len=32)
def test_fn():
# Set settings
yield from write_wishbone(dut.bus, 2, 16) # trigger_run_len = 16
yield from write_wishbone(dut.bus, 3, 800) # thresh_value = 800
yield from write_wishbone(dut.bus, 4, 10) # thresh_time = 10
yield from write_wishbone(dut.bus, 5, 1) # decay_value = 1
yield from write_wishbone(dut.bus, 5, 0) # decay_period = 0
# Start controller
yield from write_wishbone(dut.bus, 0, 1)
triggered_yet = False
triggered_num = 0
for i in range(1000):
(yield dut.samplers[0].index.eq(i))
(yield dut.samplers[0].valid.eq(1))
yield
(yield dut.samplers[0].valid.eq(0))
yield
# Total of 6 clocks per sample clock
yield
yield
yield
yield
if not triggered_yet and (yield dut.controller.peak_detector.triggered) == 1:
# Triggered, now we need to run some number of cycles
triggered_yet = True
if triggered_yet:
triggered_num += 1
if triggered_num > 16:
# We should now have collected all our samples
yield from read_wishbone(dut.bus, 1)
assert (yield dut.bus.dat_r) == 1, "Trigger did not propogate to WB!"
# Check that length is correct
yield from read_wishbone(dut.bus, 0x100)
len = (yield dut.bus.dat_r)
assert len == 32, f"Len ({len}) not correct!"
# Read data in
data = []
for i in range(32):
yield from read_wishbone(dut.bus, 0x800 + i)
sample = (yield dut.bus.dat_r)
data.append(sample)
# Manually validated from test above to be offset into the
# data
assert data[0] == 138
assert data[1] == 132
# Test pass
return
assert False, "We should have triggered"
run_simulation(dut, test_fn())

View File

@ -36,12 +36,12 @@ class PeakDetector(Module):
""" """
def __init__(self, data_width: int): def __init__(self, data_width: int):
# Create all state signals # Create all state signals (underscored in self to be accessible in tests)
min_val = Signal(data_width) self._min_val = Signal(data_width)
max_val = Signal(data_width) self._max_val = Signal(data_width)
diff = Signal(data_width) self._diff = Signal(data_width)
triggered_time = Signal(32) self._triggered_time = Signal(32)
decay_counter = Signal(32) self._decay_counter = Signal(32)
# Control signals # Control signals
self.data = Signal(data_width) self.data = Signal(data_width)
@ -57,42 +57,296 @@ class PeakDetector(Module):
self.sync += If(~self.enable, self.sync += If(~self.enable,
# Reset halfway. ADCs are 0-2V, and everything should be centered at 1V, so this is approximating the initial value # Reset halfway. ADCs are 0-2V, and everything should be centered at 1V, so this is approximating the initial value
min_val.eq(int(2**data_width /2)), self._min_val.eq(int(2**data_width /2)),
max_val.eq(int(2**data_width /2)), self._max_val.eq(int(2**data_width /2)),
self.triggered.eq(0), self.triggered.eq(0),
decay_counter.eq(0), self._decay_counter.eq(0),
triggered_time.eq(0), self._triggered_time.eq(0),
) )
# Constantly updating diff to simplify some statements # Constantly updating self._diff to simplify some statements
self.comb += diff.eq(max_val - min_val) self.comb += self._diff.eq(self._max_val - self._min_val)
self.sync += If(self.enable & self.data_valid, self.sync += If(self.enable & self.data_valid,
# Decay should run irrespective of if we have triggered,
# and before everything else so it can be overwritten
self._decay_counter.eq(self._decay_counter + 1),
# Decay threshold has been reached, apply decay to peaks
If(self._decay_counter >= self.decay_period,
self._decay_counter.eq(0),
# Only apply decay if the values would not overlap, and we use the decay
If((self._diff >= (self.decay_value << 1)) & (self.decay_value > 0),
self._max_val.eq(self._max_val - self.decay_value),
self._min_val.eq(self._min_val + self.decay_value))),
# Update maximum value # Update maximum value
If(self.data > max_val, max_val.eq(self.data)), If(self.data > self._max_val, self._max_val.eq(self.data)),
# Update minimum value # Update minimum value
If(self.data < min_val, min_val.eq(self.data)), If(self.data < self._min_val, self._min_val.eq(self.data)),
If(diff > self.thresh_value, If(self._diff > self.thresh_value,
# We have met the threshold for triggering, start counting # We have met the threshold for triggering, start counting
triggered_time.eq(triggered_time + 1), self._triggered_time.eq(self._triggered_time + 1),
decay_counter.eq(0),
# We have triggered, so we can set the output. After this point, # We have triggered, so we can set the output. After this point,
# nothing we do matters until enable is de-asserted and we reset # nothing we do matters until enable is de-asserted and we reset
# triggered. # triggered.
If(triggered_time + 1 >= self.thresh_time, self.triggered.eq(1))) If(self._triggered_time + 1 >= self.thresh_time, self.triggered.eq(1)))
.Else( .Else(
# We have not met the threshold, reset timer and handle decay # We have not met the threshold, reset timer
triggered_time.eq(0), self._triggered_time.eq(0),
decay_counter.eq(decay_counter + 1), ),
# Decay threshold has been reached, apply decay to peaks
If(decay_counter >= self.decay_period,
decay_counter.eq(0),
# Only apply decay if the values would not overlap
If(diff >= (self.decay_value << 1),
max_val.eq(max_val - self.decay_value),
min_val.eq(min_val + self.decay_value)))
)
) )
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
def create_waveform(*, dc_bias: int = 0, scale: float = 1) -> Tuple[np.ndarray[float], np.ndarray[int]]:
"""
Create a simple 40kHz sine wave in integer values that can be used by peak detector
"""
assert scale <= 1.0, "Scale factor must be some ratio of full range"
# Constants
f_s = 10e6 # Sample rate (Hz)
f = 40e3 # Signal Frequency (Hz)
t = 0.002 # Sample period (s)
n = int(f_s * 0.002) # Number of samples
# Create time from 0ms to 2ms
x = np.linspace(0, t, n)
# Create signal!
y = np.sin(x * 2*np.pi*f)
# Scale according to user inputs
y = y * scale
# Convert to positive integer at provided bias
signal = np.ndarray((len(y)), dtype=np.uint16)
# "unsafe" casting because numpy doesn't know we are staying under 10 bit values
np.copyto(signal, y * 512 + 512 + dc_bias, casting='unsafe')
#plt.plot(x, signal)
#plt.show()
return x, signal
def set_settings(dut: PeakDetector, thresh_value: int, thresh_time: int, decay_value: int, decay_period: int) -> None:
"""Set peak detector settings simply"""
(yield dut.thresh_value.eq(thresh_value))
(yield dut.thresh_time.eq(thresh_time))
(yield dut.decay_value.eq(decay_value))
(yield dut.decay_period.eq(decay_period))
# Load in values with a new clock
yield
# NOTE: These tests aren't really exhaustive. They get good coverage and generally outline results,
# but do require some amount of manual validation and checking of waveforms if things change
# majorly. At least they are a small set of things that I don't have to re-create later.
def test_simple_waveform():
"""Test an ideal waveform with simple settings guaranteed to trigger"""
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
# First set settings to be simple, we want to trigger pretty much immediately
yield from set_settings(dut, 800, 10, 0, 0)
# Enable device
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_scrunched_simple_waveform():
"""Test a smaller waveform with smaller peaks"""
(_, signal) = create_waveform(scale=0.4)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 200, 10, 0, 0)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_decay_simple_waveform():
"""Test that simple case of decay works """
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800, 10, 5, 10)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_decay_simple_waveform_too_much():
"""Test that we can overuse decay and discard valid waveforms"""
(_, signal) = create_waveform()
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800, 10, 40, 0)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in, ensuring we don't trigger
for i, val in enumerate(signal):
assert (yield dut.triggered) == 0, "Must not trigger!"
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
run_simulation(dut, test_fn())
def test_decay_compensates_bias():
signal = [800] * int(20e3)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 800 - 512, 20, 1, 0)
# Enable
(yield dut.enable.eq(1))
yield
for i, val in enumerate(signal):
assert (yield dut.triggered) == 0, "Must not trigger!"
assert i < 500, "Decay must not take too long to work!"
# min val won't necessarily match max_val because of the crossover, but as long
# as it's close enough, we're good
if (yield dut._max_val) == 800 and (yield dut._min_val) >= 798:
# Test pass
return
# Load data in
(yield dut.data.eq(val))
(yield dut.data_valid.eq(1))
yield # Tick
run_simulation(dut, test_fn())
def test_biased_simple_waveform():
"""Test a biased waveform. This does also need to test decay, else we get invalid results
TODO this test is slightly broken. I think it's passing on the initial bias, not later stuff.
It should get fixed, but I think it mostly covers area already covered in other tests so I'm
fine with this for now.
"""
(_, signal) = create_waveform(dc_bias=200)
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 200, 20, 20, 1)
# Enable
(yield dut.enable.eq(1))
yield
# Load data in until we trigger
for i, val in enumerate(signal):
if (yield dut.triggered) == 1:
# Test passed!
return
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
assert False, "No trigger, test has failed..."
run_simulation(dut, test_fn())
def test_noise_spike():
"""Test that appropriate filtering and decay can filter out a spike in noise"""
signal = [512, 512, 512, 1024, 1024] + [512] * 1000
dut = PeakDetector(10)
def test_fn():
yield from set_settings(dut, 300, 20, 20, 0)
# Enable
(yield dut.enable.eq(1))
yield # Tick
# Load data in until we trigger
for val in signal:
assert (yield dut.triggered) == 0, "Can't trigger!"
# Load in data, set valid
(yield dut.data.eq(int(val)))
(yield dut.data_valid.eq(1))
yield # Tick
# Test success
return
run_simulation(dut, test_fn())

View File

@ -2,7 +2,7 @@
Helper functions for a basic test suite Helper functions for a basic test suite
""" """
from typing import Callable from typing import Callable, List
from enum import StrEnum from enum import StrEnum
from dataclasses import dataclass from dataclasses import dataclass
from traceback import print_exc from traceback import print_exc
@ -25,15 +25,21 @@ class TestInfo:
return f"[{self.suite_name}.{self.test_name}] {self.result}" return f"[{self.suite_name}.{self.test_name}] {self.result}"
skipped_suites: List[str] = []
def skip_suite(suite_name: str):
"""Skips running tests from a specific suite"""
skipped_suites.append(suite_name)
def run_test(suite_name: str, test_fn: Callable, do_skip = False) -> TestResult: def run_test(suite_name: str, test_fn: Callable, do_skip = False) -> TestResult:
test_name = test_fn.__name__ test_name = test_fn.__name__
print(f"[{suite_name}.{test_name}] Running...") if do_skip or suite_name in skipped_suites:
if do_skip:
res = TestInfo(suite_name, test_name, TestResult.SKIP) res = TestInfo(suite_name, test_name, TestResult.SKIP)
else: else:
print(f"[{suite_name}.{test_name}] Running...")
try: try:
test_fn() test_fn()
res = TestInfo(suite_name, test_name, TestResult.PASS) res = TestInfo(suite_name, test_name, TestResult.PASS)
@ -41,7 +47,5 @@ def run_test(suite_name: str, test_fn: Callable, do_skip = False) -> TestResult:
res = TestInfo(suite_name, test_name, TestResult.FAIL) res = TestInfo(suite_name, test_name, TestResult.FAIL)
print_exc() print_exc()
print(res) print(res)
return res return res