"""SPI controller building blocks (bit serializer, bit deserializer, enable
generator) and a loopback simulation that exercises them together."""

from amaranth import *
from amaranth.lib import data, wiring, stream, io
from amaranth.lib.wiring import In, Out
from amaranth.sim import Simulator


class BitSerializer(wiring.Component):
    """Shift an array of ``length`` elements, ``width`` bits each, out of an
    FF buffer, advancing by one element on every cycle where ``en`` is high.

    Ports:
        en:     advance strobe; nothing changes while it is low.
        stream: input stream whose payload carries ``o`` (the array to shift
                out) and ``oe`` (the output-enable value to hold meanwhile).
        buffer: FF buffer interface; ``o`` is driven with element 0 of the
                shift register and ``oe`` with the latched ``oe`` value.
    """

    def __init__(self, *, width, length):
        self._length = length

        payload_layout = data.StructLayout({
            "o": data.ArrayLayout(width, length),
            "oe": 1
        })
        super().__init__({
            "en": In(1),
            "stream": In(stream.Signature(payload_layout)),
            "buffer": Out(io.FFBuffer.Signature("o", width)),
        })

    def elaborate(self, platform):
        m = Module()

        o_reg = Signal.like(self.stream.p.o)
        oe_reg = Signal.like(self.stream.p.oe)
        m.d.comb += [
            self.buffer.o.eq(o_reg[0]),
            self.buffer.oe.eq(oe_reg),
        ]

        # Plain Python integers, evaluated once at elaboration time.
        final_index = self._length - 1
        elem_width = len(o_reg[0])

        count = Signal(range(self._length))
        with m.If(self.en):
            with m.If(count == 0):
                # Register is drained: accept a new payload if one is offered.
                with m.If(self.stream.valid):
                    m.d.comb += self.stream.ready.eq(1)
                    m.d.sync += [
                        count.eq(final_index),
                        o_reg.eq(self.stream.p.o),
                        oe_reg.eq(self.stream.p.oe),
                    ]
            with m.Else():
                # Drop element 0 so that the next element becomes the output.
                m.d.sync += [
                    count.eq(count - 1),
                    o_reg.eq(o_reg.as_value()[elem_width:]),
                ]

        return m


class BitDeserializer(wiring.Component):
    """Collect ``length`` samples of ``buffer.i`` (``width`` bits each) into
    an array and present the array on an output stream.

    One sample is taken on each cycle where ``en`` is high and no completed
    payload is pending. The newest sample enters at the top of the register
    while older samples move down by one element. After ``length`` samples
    ``stream.valid`` is set; it is cleared once ``stream.ready`` is seen, and
    ``en`` is ignored for as long as ``stream.valid`` stays set.
    """

    def __init__(self, *, width, length):
        self._length = length

        payload_layout = data.StructLayout({
            "i": data.ArrayLayout(width, length),
        })
        super().__init__({
            "en": In(1),
            "stream": Out(stream.Signature(payload_layout)),
            "buffer": Out(io.FFBuffer.Signature("i", width)),
        })

    def elaborate(self, platform):
        m = Module()

        i_reg = Signal.like(self.stream.p.i)
        m.d.comb += self.stream.p.i.eq(i_reg)

        # Plain Python integers, evaluated once at elaboration time.
        final_index = self._length - 1
        elem_width = len(i_reg[0])

        count = Signal(range(self._length))
        with m.If(self.stream.valid):
            # A full payload is waiting for the consumer; hold it until taken.
            with m.If(self.stream.ready):
                m.d.sync += self.stream.valid.eq(0)
        with m.Elif(self.en):
            with m.If(count == final_index):
                m.d.sync += [
                    count.eq(0),
                    self.stream.valid.eq(1),
                ]
            with m.Else():
                m.d.sync += count.eq(count + 1)
            # Shift on every accepted sample, including the final one.
            m.d.sync += i_reg.eq(
                Cat(i_reg.as_value()[elem_width:], self.buffer.i))

        return m


class BitEnableGenerator(wiring.Component):
    """Produce the enable strobes for a requested number of clock cycles.

    A request is accepted from ``cycles`` only while the generator is idle
    (its cycle counter is zero). Each requested cycle consists of two phases
    of ``half_period`` ticks; a one-tick strobe is emitted on the last tick
    of each phase.
    """

    # Number of cycles to produce.
    cycles: In(stream.Signature(8))
    # High for both the clock negedge and the clock posedge.
    clk_en: Out(1)
    # High for the clock negedge.
    o_en: Out(1)
    # High for the clock posedge, delayed by `latency` cycles.
    i_en: Out(1)

    def __init__(self, *, half_period, latency):
        self._half_period = half_period
        self._latency = latency

        super().__init__()

    def elaborate(self, platform):
        m = Module()

        negedge = Signal()
        posedge = Signal()
        m.d.comb += [
            self.clk_en.eq(negedge | posedge),
            self.o_en.eq(negedge),
        ]

        # Pass the posedge strobe through `latency` registers in series.
        sample_strobe = posedge
        for _ in range(self._latency):
            i_en_delay = Signal()
            m.d.sync += i_en_delay.eq(sample_strobe)
            sample_strobe = i_en_delay
        m.d.comb += self.i_en.eq(sample_strobe)

        final_tick = self._half_period - 1

        count = Signal.like(self.cycles.payload)
        timer = Signal(range(self._half_period))
        phase = Signal()
        with m.If(count == 0):
            # Idle: ready for a new request, which restarts timer and phase.
            m.d.comb += self.cycles.ready.eq(1)
            with m.If(self.cycles.valid):
                m.d.sync += [
                    count.eq(self.cycles.payload),
                    timer.eq(0),
                    phase.eq(0),
                ]
        with m.Else():
            with m.If(timer == final_tick):
                # End of a half period: strobe the edge that matches the
                # current phase; a full cycle is counted off after phase 1.
                m.d.comb += [
                    negedge.eq(phase == 0),
                    posedge.eq(phase == 1),
                ]
                m.d.sync += [
                    count.eq(Mux(phase, count - 1, count)),
                    timer.eq(0),
                    phase.eq(~phase),
                ]
            with m.Else():
                m.d.sync += timer.eq(timer + 1)

        return m


class SPIControllerBus(wiring.Component):
    """Byte-wide SPI controller assembled from the serializer, deserializer
    and enable generator components.

    Bytes taken from ``o_stream`` are shifted out on ``copi_buffer`` starting
    with bit 0, ``sck_buffer`` is driven from a fixed alternating pattern, and
    bits sampled from ``cipo_buffer`` are gathered into bytes on ``i_stream``.
    """

    o_stream: In(stream.Signature(8))
    i_stream: Out(stream.Signature(8))

    sck_buffer: Out(io.FFBuffer.Signature("o", 1))
    copi_buffer: Out(io.FFBuffer.Signature("o", 1))
    cipo_buffer: Out(io.FFBuffer.Signature("i", 1))

    def __init__(self, *, half_period):
        self._half_period = half_period

        super().__init__()

    def elaborate(self, platform):
        m = Module()

        # Enable generator: asked for 8 clock cycles whenever a byte is
        # offered on `o_stream` and `i_stream` is ready to take the result.
        en_gen = BitEnableGenerator(half_period=self._half_period, latency=1)
        m.submodules.en_gen = en_gen
        m.d.comb += en_gen.cycles.p.eq(8)
        m.d.comb += en_gen.cycles.valid.eq(
            self.o_stream.valid & self.i_stream.ready)

        # Clock: a constant 16-element alternating pattern that advances on
        # both edge strobes.
        sck_ser = BitSerializer(width=1, length=16)
        m.submodules.sck_ser = sck_ser
        wiring.connect(m, clk=sck_ser.buffer, buf=wiring.flipped(self.sck_buffer))
        m.d.comb += sck_ser.en.eq(en_gen.clk_en)
        m.d.comb += sck_ser.stream.p.o.eq(0b1010101010101010)
        m.d.comb += sck_ser.stream.p.oe.eq(1)
        m.d.comb += sck_ser.stream.valid.eq(1)

        # Controller-out data: advances on the negedge strobe.
        copi_ser = BitSerializer(width=1, length=8)
        m.submodules.copi_ser = copi_ser
        wiring.connect(m, ser=copi_ser.buffer, buf=wiring.flipped(self.copi_buffer))
        m.d.comb += copi_ser.en.eq(en_gen.o_en)
        m.d.comb += copi_ser.stream.p.o.eq(self.o_stream.p)
        m.d.comb += copi_ser.stream.p.oe.eq(1)
        m.d.comb += copi_ser.stream.valid.eq(self.o_stream.valid)
        m.d.comb += self.o_stream.ready.eq(copi_ser.stream.ready)

        # Controller-in data: sampled on the delayed posedge strobe.
        cipo_des = BitDeserializer(width=1, length=8)
        m.submodules.cipo_des = cipo_des
        wiring.connect(m, des=cipo_des.buffer, buf=wiring.flipped(self.cipo_buffer))
        m.d.comb += cipo_des.en.eq(en_gen.i_en)
        m.d.comb += self.i_stream.p.eq(cipo_des.stream.p)
        m.d.comb += self.i_stream.valid.eq(cipo_des.stream.valid)
        m.d.comb += cipo_des.stream.ready.eq(self.i_stream.ready)

        return m


# Loopback test: a registered copy of the COPI output is fed into the CIPO
# input, so each byte written to `o_stream` should come back on `i_stream`.
dut = Module()
bus = SPIControllerBus(half_period=10)
dut.submodules.bus = bus
dut.d.sync += bus.cipo_buffer.i.eq(bus.copi_buffer.o)


async def stream_get(ctx, stream):
    """Assert ``ready``, wait for a tick on which ``valid`` is high, and
    return the payload sampled at that tick."""
    ctx.set(stream.ready, 1)
    (payload,) = await ctx.tick().sample(stream.payload).until(stream.valid)
    ctx.set(stream.ready, 0)
    return payload


async def stream_put(ctx, stream, payload):
    """Offer ``payload`` with ``valid`` high until a tick on which ``ready``
    is high, then drop ``valid``."""
    ctx.set(stream.valid, 1)
    ctx.set(stream.payload, payload)
    await ctx.tick().until(stream.ready)
    ctx.set(stream.valid, 0)


async def testbench_ser(ctx):
    """Send the two test bytes into the controller."""
    for byte in (0x69, 0xAA):
        await stream_put(ctx, bus.o_stream, byte)


async def testbench_des(ctx):
    """Check that the two test bytes come back in order."""
    for expected in (0x69, 0xAA):
        assert await stream_get(ctx, bus.i_stream) == expected


sim = Simulator(dut)
sim.add_clock(1e-6)
sim.add_testbench(testbench_ser)
sim.add_testbench(testbench_des)
with sim.write_vcd("test.vcd"):
    sim.run()