Skip to content

Instantly share code, notes, and snippets.

@whitequark
Created June 18, 2024 02:49
Show Gist options
  • Select an option

  • Save whitequark/62e83edd46bb4fd78b7241377615f2f0 to your computer and use it in GitHub Desktop.

Select an option

Save whitequark/62e83edd46bb4fd78b7241377615f2f0 to your computer and use it in GitHub Desktop.

Revisions

  1. whitequark created this gist Jun 18, 2024.
    210 changes: 210 additions & 0 deletions spi_serdes.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,210 @@
    from amaranth import *
    from amaranth.lib import data, wiring, stream, io
    from amaranth.lib.wiring import In, Out
    from amaranth.sim import Simulator


    class BitSerializer(wiring.Component):
    def __init__(self, *, width, length):
    self._length = length

    super().__init__({
    "en": In(1),
    "stream": In(stream.Signature(data.StructLayout({
    "o": data.ArrayLayout(width, length),
    "oe": 1
    }))),
    "buffer": Out(io.FFBuffer.Signature("o", width)),
    })

    def elaborate(self, platform):
    m = Module()

    o_reg = Signal.like(self.stream.p.o)
    oe_reg = Signal.like(self.stream.p.oe)
    m.d.comb += self.buffer.o.eq(o_reg[0])
    m.d.comb += self.buffer.oe.eq(oe_reg)

    count = Signal(range(self._length))
    with m.If(self.en):
    with m.If(count == 0):
    with m.If(self.stream.valid):
    m.d.comb += self.stream.ready.eq(1)
    m.d.sync += count.eq(self._length - 1)
    m.d.sync += o_reg.eq(self.stream.p.o)
    m.d.sync += oe_reg.eq(self.stream.p.oe)
    with m.Else():
    m.d.sync += count.eq(count - 1)
    m.d.sync += o_reg.eq(o_reg.as_value()[len(o_reg[0]):])

    return m


    class BitDeserializer(wiring.Component):
    def __init__(self, *, width, length):
    self._length = length

    super().__init__({
    "en": In(1),
    "stream": Out(stream.Signature(data.StructLayout({
    "i": data.ArrayLayout(width, length),
    }))),
    "buffer": Out(io.FFBuffer.Signature("i", width)),
    })

    def elaborate(self, platform):
    m = Module()

    i_reg = Signal.like(self.stream.p.i)
    m.d.comb += self.stream.p.i.eq(i_reg)

    count = Signal(range(self._length))
    with m.If(self.stream.valid):
    with m.If(self.stream.ready):
    m.d.sync += self.stream.valid.eq(0)
    with m.Elif(self.en):
    with m.If(count == self._length - 1):
    m.d.sync += count.eq(0)
    m.d.sync += self.stream.valid.eq(1)
    with m.Else():
    m.d.sync += count.eq(count + 1)
    m.d.sync += i_reg.eq(Cat(i_reg.as_value()[len(i_reg[0]):], self.buffer.i))

    return m


    class BitEnableGenerator(wiring.Component):
    cycles: In(stream.Signature(8)) # how many cycles to produce

    clk_en: Out(1) # high for clock negedge and clock posedge
    o_en: Out(1) # high for clock negedge
    i_en: Out(1) # high for clock posedge, delayed by `latency` cycles

    def __init__(self, *, half_period, latency):
    self._half_period = half_period
    self._latency = latency

    super().__init__()

    def elaborate(self, platform):
    m = Module()

    negedge = Signal()
    posedge = Signal()
    m.d.comb += self.clk_en.eq(negedge | posedge)
    m.d.comb += self.o_en.eq(negedge)
    i_en = posedge
    for _ in range(self._latency):
    i_en_delay = Signal()
    m.d.sync += i_en_delay.eq(i_en)
    i_en = i_en_delay
    m.d.comb += self.i_en.eq(i_en)

    count = Signal.like(self.cycles.payload)
    timer = Signal(range(self._half_period))
    phase = Signal()

    with m.If(count == 0):
    m.d.comb += self.cycles.ready.eq(1)
    with m.If(self.cycles.valid):
    m.d.sync += count.eq(self.cycles.payload)
    m.d.sync += timer.eq(0)
    m.d.sync += phase.eq(0)
    with m.Else():
    # meow
    with m.If(timer == self._half_period - 1):
    m.d.sync += count.eq(Mux(phase, count - 1, count))
    m.d.sync += timer.eq(0)
    m.d.sync += phase.eq(~phase)
    m.d.comb += negedge.eq(phase == 0)
    m.d.comb += posedge.eq(phase == 1)
    with m.Else():
    m.d.sync += timer.eq(timer + 1)

    return m


    class SPIControllerBus(wiring.Component):
    o_stream: In(stream.Signature(8))
    i_stream: Out(stream.Signature(8))

    sck_buffer: Out(io.FFBuffer.Signature("o", 1))
    copi_buffer: Out(io.FFBuffer.Signature("o", 1))
    cipo_buffer: Out(io.FFBuffer.Signature("i", 1))

    def __init__(self, *, half_period):
    self._half_period = half_period

    super().__init__()

    def elaborate(self, platform):
    m = Module()

    m.submodules.en_gen = en_gen = \
    BitEnableGenerator(half_period=self._half_period, latency=1)
    m.d.comb += [
    en_gen.cycles.p.eq(8),
    en_gen.cycles.valid.eq(self.o_stream.valid & self.i_stream.ready),
    ]

    m.submodules.sck_ser = sck_ser = BitSerializer(width=1, length=16)
    wiring.connect(m, clk=sck_ser.buffer, buf=wiring.flipped(self.sck_buffer))
    m.d.comb += [
    sck_ser.en.eq(en_gen.clk_en),
    sck_ser.stream.p.o.eq(0b1010101010101010),
    sck_ser.stream.p.oe.eq(1),
    sck_ser.stream.valid.eq(1),
    ]

    m.submodules.copi_ser = copi_ser = BitSerializer(width=1, length=8)
    wiring.connect(m, ser=copi_ser.buffer, buf=wiring.flipped(self.copi_buffer))
    m.d.comb += [
    copi_ser.en.eq(en_gen.o_en),
    copi_ser.stream.p.o.eq(self.o_stream.p),
    copi_ser.stream.p.oe.eq(1),
    copi_ser.stream.valid.eq(self.o_stream.valid),
    self.o_stream.ready.eq(copi_ser.stream.ready),
    ]

    m.submodules.cipo_des = cipo_des = BitDeserializer(width=1, length=8)
    wiring.connect(m, des=cipo_des.buffer, buf=wiring.flipped(self.cipo_buffer))
    m.d.comb += [
    cipo_des.en.eq(en_gen.i_en),
    self.i_stream.p.eq(cipo_des.stream.p),
    self.i_stream.valid.eq(cipo_des.stream.valid),
    cipo_des.stream.ready.eq(self.i_stream.ready),
    ]

    return m


    dut = Module()
    dut.submodules.bus = bus = SPIControllerBus(half_period=10)
    dut.d.sync += bus.cipo_buffer.i.eq(bus.copi_buffer.o)

    async def stream_get(ctx, stream):
    ctx.set(stream.ready, 1)
    payload, = await ctx.tick().sample(stream.payload).until(stream.valid)
    ctx.set(stream.ready, 0)
    return payload

    async def stream_put(ctx, stream, payload):
    ctx.set(stream.valid, 1)
    ctx.set(stream.payload, payload)
    await ctx.tick().until(stream.ready)
    ctx.set(stream.valid, 0)

    async def testbench_ser(ctx):
    await stream_put(ctx, bus.o_stream, 0x69)
    await stream_put(ctx, bus.o_stream, 0xAA)

    async def testbench_des(ctx):
    assert await stream_get(ctx, bus.i_stream) == 0x69
    assert await stream_get(ctx, bus.i_stream) == 0xAA

    sim = Simulator(dut)
    sim.add_clock(1e-6)
    sim.add_testbench(testbench_ser)
    sim.add_testbench(testbench_des)
    with sim.write_vcd("test.vcd"):
    sim.run()