Created
June 18, 2024 02:49
-
-
Save whitequark/62e83edd46bb4fd78b7241377615f2f0 to your computer and use it in GitHub Desktop.
Revisions
-
whitequark created this gist
Jun 18, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,210 @@ from amaranth import * from amaranth.lib import data, wiring, stream, io from amaranth.lib.wiring import In, Out from amaranth.sim import Simulator class BitSerializer(wiring.Component): def __init__(self, *, width, length): self._length = length super().__init__({ "en": In(1), "stream": In(stream.Signature(data.StructLayout({ "o": data.ArrayLayout(width, length), "oe": 1 }))), "buffer": Out(io.FFBuffer.Signature("o", width)), }) def elaborate(self, platform): m = Module() o_reg = Signal.like(self.stream.p.o) oe_reg = Signal.like(self.stream.p.oe) m.d.comb += self.buffer.o.eq(o_reg[0]) m.d.comb += self.buffer.oe.eq(oe_reg) count = Signal(range(self._length)) with m.If(self.en): with m.If(count == 0): with m.If(self.stream.valid): m.d.comb += self.stream.ready.eq(1) m.d.sync += count.eq(self._length - 1) m.d.sync += o_reg.eq(self.stream.p.o) m.d.sync += oe_reg.eq(self.stream.p.oe) with m.Else(): m.d.sync += count.eq(count - 1) m.d.sync += o_reg.eq(o_reg.as_value()[len(o_reg[0]):]) return m class BitDeserializer(wiring.Component): def __init__(self, *, width, length): self._length = length super().__init__({ "en": In(1), "stream": Out(stream.Signature(data.StructLayout({ "i": data.ArrayLayout(width, length), }))), "buffer": Out(io.FFBuffer.Signature("i", width)), }) def elaborate(self, platform): m = Module() i_reg = Signal.like(self.stream.p.i) m.d.comb += self.stream.p.i.eq(i_reg) count = Signal(range(self._length)) with m.If(self.stream.valid): with m.If(self.stream.ready): m.d.sync += self.stream.valid.eq(0) with m.Elif(self.en): with m.If(count == self._length - 1): m.d.sync += count.eq(0) m.d.sync += self.stream.valid.eq(1) with m.Else(): m.d.sync += count.eq(count + 1) m.d.sync += i_reg.eq(Cat(i_reg.as_value()[len(i_reg[0]):], self.buffer.i)) return m class BitEnableGenerator(wiring.Component): cycles: In(stream.Signature(8)) # how many cycles to produce clk_en: Out(1) # high for clock negedge and clock posedge o_en: Out(1) # high for clock negedge i_en: Out(1) # high for clock posedge, delayed by `latency` cycles def __init__(self, *, half_period, latency): self._half_period = half_period self._latency = latency super().__init__() def elaborate(self, platform): m = Module() negedge = Signal() posedge = Signal() m.d.comb += self.clk_en.eq(negedge | posedge) m.d.comb += self.o_en.eq(negedge) i_en = posedge for _ in range(self._latency): i_en_delay = Signal() m.d.sync += i_en_delay.eq(i_en) i_en = i_en_delay m.d.comb += self.i_en.eq(i_en) count = Signal.like(self.cycles.payload) timer = Signal(range(self._half_period)) phase = Signal() with m.If(count == 0): m.d.comb += self.cycles.ready.eq(1) with m.If(self.cycles.valid): m.d.sync += count.eq(self.cycles.payload) m.d.sync += timer.eq(0) m.d.sync += phase.eq(0) with m.Else(): # meow with m.If(timer == self._half_period - 1): m.d.sync += count.eq(Mux(phase, count - 1, count)) m.d.sync += timer.eq(0) m.d.sync += phase.eq(~phase) m.d.comb += negedge.eq(phase == 0) m.d.comb += posedge.eq(phase == 1) with m.Else(): m.d.sync += timer.eq(timer + 1) return m class SPIControllerBus(wiring.Component): o_stream: In(stream.Signature(8)) i_stream: Out(stream.Signature(8)) sck_buffer: Out(io.FFBuffer.Signature("o", 1)) copi_buffer: Out(io.FFBuffer.Signature("o", 1)) cipo_buffer: Out(io.FFBuffer.Signature("i", 1)) def __init__(self, *, half_period): self._half_period = half_period super().__init__() def elaborate(self, platform): m = Module() m.submodules.en_gen = en_gen = \ BitEnableGenerator(half_period=self._half_period, latency=1) m.d.comb += [ en_gen.cycles.p.eq(8), en_gen.cycles.valid.eq(self.o_stream.valid & self.i_stream.ready), ] m.submodules.sck_ser = sck_ser = BitSerializer(width=1, length=16) wiring.connect(m, clk=sck_ser.buffer, buf=wiring.flipped(self.sck_buffer)) m.d.comb += [ sck_ser.en.eq(en_gen.clk_en), sck_ser.stream.p.o.eq(0b1010101010101010), sck_ser.stream.p.oe.eq(1), sck_ser.stream.valid.eq(1), ] m.submodules.copi_ser = copi_ser = BitSerializer(width=1, length=8) wiring.connect(m, ser=copi_ser.buffer, buf=wiring.flipped(self.copi_buffer)) m.d.comb += [ copi_ser.en.eq(en_gen.o_en), copi_ser.stream.p.o.eq(self.o_stream.p), copi_ser.stream.p.oe.eq(1), copi_ser.stream.valid.eq(self.o_stream.valid), self.o_stream.ready.eq(copi_ser.stream.ready), ] m.submodules.cipo_des = cipo_des = BitDeserializer(width=1, length=8) wiring.connect(m, des=cipo_des.buffer, buf=wiring.flipped(self.cipo_buffer)) m.d.comb += [ cipo_des.en.eq(en_gen.i_en), self.i_stream.p.eq(cipo_des.stream.p), self.i_stream.valid.eq(cipo_des.stream.valid), cipo_des.stream.ready.eq(self.i_stream.ready), ] return m dut = Module() dut.submodules.bus = bus = SPIControllerBus(half_period=10) dut.d.sync += bus.cipo_buffer.i.eq(bus.copi_buffer.o) async def stream_get(ctx, stream): ctx.set(stream.ready, 1) payload, = await ctx.tick().sample(stream.payload).until(stream.valid) ctx.set(stream.ready, 0) return payload async def stream_put(ctx, stream, payload): ctx.set(stream.valid, 1) ctx.set(stream.payload, payload) await ctx.tick().until(stream.ready) ctx.set(stream.valid, 0) async def testbench_ser(ctx): await stream_put(ctx, bus.o_stream, 0x69) await stream_put(ctx, bus.o_stream, 0xAA) async def testbench_des(ctx): assert await stream_get(ctx, bus.i_stream) == 0x69 assert await stream_get(ctx, bus.i_stream) == 0xAA sim = Simulator(dut) sim.add_clock(1e-6) sim.add_testbench(testbench_ser) sim.add_testbench(testbench_des) with sim.write_vcd("test.vcd"): sim.run()