Skip to content

Instantly share code, notes, and snippets.

@whitequark
Last active May 14, 2025 18:48
Show Gist options
  • Save whitequark/5ef7f42c82da32331f7bf72e2d78ceb1 to your computer and use it in GitHub Desktop.
Save whitequark/5ef7f42c82da32331f7bf72e2d78ceb1 to your computer and use it in GitHub Desktop.

Revisions

  1. whitequark revised this gist Jun 21, 2024. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion qspi_iostream.py
    Original file line number Diff line number Diff line change
    @@ -120,6 +120,7 @@ def elaborate(self, platform):
    with m.Case(QSPIMode.PutX1, QSPIMode.Swap):
    m.d.comb += self.frames.p.o.eq(rev_data.word_select(cycle, 1)[::-1])
    m.d.comb += self.frames.p.oe.eq(Cat(1, 0, 0, 0))
    m.d.comb += self.frames.p.i_en.eq(self.octets.p.mode == QSPIMode.Swap)
    with m.Case(QSPIMode.GetX1):
    m.d.comb += self.frames.p.oe.eq(Cat(1, 0, 0, 0))
    m.d.comb += self.frames.p.i_en.eq(1)
    @@ -431,7 +432,7 @@ async def bits_get(*, ox, oe, i_en, meta):
    assert (actual := await stream_get(ctx, dut.frames)) == expected, \
    f"(cycle {cycle}) {actual} != {expected}; o: {actual.o:04b} != {expected.o:04b}"

    await bits_get(ox=[1,0,1,1,1,0,1,0], oe=1, i_en=0, meta=QSPIMode.Swap)
    await bits_get(ox=[1,0,1,1,1,0,1,0], oe=1, i_en=1, meta=QSPIMode.Swap)

    await bits_get(ox=[1,0,1,0,1,0,1,0], oe=1, i_en=0, meta=QSPIMode.PutX1)
    await bits_get(ox=[0,1,0,1,0,1,0,1], oe=1, i_en=0, meta=QSPIMode.PutX1)
  2. whitequark revised this gist Jun 21, 2024. 1 changed file with 29 additions and 22 deletions.
    51 changes: 29 additions & 22 deletions qspi_iostream.py
    Original file line number Diff line number Diff line change
    @@ -75,17 +75,15 @@ def delay(value, name):
    return m


    class QSPIMode(enum.Enum):
    PutX1 = 0
    GetX1 = 1

    PutX2 = 2
    GetX2 = 3
    NopX2 = 4

    class QSPIMode(enum.Enum, shape=3):
    Swap = 0 # normal SPI
    PutX1 = 1
    GetX1 = 2
    PutX2 = 3
    GetX2 = 4
    PutX4 = 5
    GetX4 = 6
    NopX4 = 7
    Dummy = 7


    class QSPIEnframer(wiring.Component):
    @@ -107,17 +105,19 @@ def elaborate(self, platform):
    m.d.comb += self.frames.valid.eq(self.octets.valid)
    with m.If(self.octets.valid & self.frames.ready):
    with m.Switch(self.octets.p.mode):
    with m.Case(QSPIMode.PutX1, QSPIMode.GetX1):
    with m.Case(QSPIMode.PutX1, QSPIMode.GetX1, QSPIMode.Swap):
    m.d.comb += self.octets.ready.eq(cycle == 7)
    with m.Case(QSPIMode.PutX2, QSPIMode.GetX2, QSPIMode.NopX2):
    with m.Case(QSPIMode.PutX2, QSPIMode.GetX2):
    m.d.comb += self.octets.ready.eq(cycle == 3)
    with m.Case(QSPIMode.PutX4, QSPIMode.GetX4, QSPIMode.NopX4):
    with m.Case(QSPIMode.PutX4, QSPIMode.GetX4):
    m.d.comb += self.octets.ready.eq(cycle == 1)
    with m.Case(QSPIMode.Dummy):
    m.d.comb += self.octets.ready.eq(cycle == 0)
    m.d.sync += cycle.eq(Mux(self.octets.ready, 0, cycle + 1))

    rev_data = self.octets.p.data[::-1] # flipped to have MSB at 0; flipped back below
    with m.Switch(self.octets.p.mode):
    with m.Case(QSPIMode.PutX1):
    with m.Case(QSPIMode.PutX1, QSPIMode.Swap):
    m.d.comb += self.frames.p.o.eq(rev_data.word_select(cycle, 1)[::-1])
    m.d.comb += self.frames.p.oe.eq(Cat(1, 0, 0, 0))
    with m.Case(QSPIMode.GetX1):
    @@ -154,7 +154,7 @@ def elaborate(self, platform):
    m.d.comb += self.frames.ready.eq(~self.octets.valid | self.octets.ready)
    with m.If(self.frames.valid):
    with m.Switch(self.frames.p.meta):
    with m.Case(QSPIMode.GetX1):
    with m.Case(QSPIMode.GetX1, QSPIMode.Swap):
    m.d.comb += self.octets.valid.eq(cycle == 7)
    with m.Case(QSPIMode.GetX2):
    m.d.comb += self.octets.valid.eq(cycle == 3)
    @@ -165,7 +165,7 @@ def elaborate(self, platform):

    data_reg = Signal(8)
    with m.Switch(self.frames.p.meta):
    with m.Case(QSPIMode.GetX1): # unlike the enframer or the x2/x4 deframer, samples IO1
    with m.Case(QSPIMode.GetX1, QSPIMode.Swap): # note: samples IO1
    m.d.comb += self.octets.p.data.eq(Cat(self.frames.p.i[1], data_reg))
    with m.If(self.frames.valid & self.frames.ready):
    m.d.sync += data_reg.eq(Cat(self.frames.p.i[1], data_reg))
    @@ -202,7 +202,7 @@ def elaborate(self, platform):
    m.submodules.deframer = deframer = QSPIDeframer()
    connect(m, controller=flipped(self.i_octets), deframer=deframer.octets)

    latency = 0 if platform is None else 2 # amaranth-lang/amaranth#1417
    latency = 0 if platform is None else 2 # FIXME: amaranth-lang/amaranth#1417
    m.submodules.iostream = iostream = IOStream(5, meta_layout=QSPIMode, latency=latency)
    for n in range(4):
    connect(m, controller=flipped(self.io_buffers[n]), iostream=iostream.buffers[n])
    @@ -401,6 +401,8 @@ async def data_put(*, data, mode):
    # amaranth-lang/amaranth#1413
    await stream_put(ctx, dut.octets, {"data": data, "mode": mode.value})

    await data_put(data=0xBA, mode=QSPIMode.Swap)

    await data_put(data=0xAA, mode=QSPIMode.PutX1)
    await data_put(data=0x55, mode=QSPIMode.PutX1)
    await data_put(data=0xC1, mode=QSPIMode.PutX1)
    @@ -413,8 +415,8 @@ async def data_put(*, data, mode):
    await data_put(data=0x55, mode=QSPIMode.PutX4)
    await data_put(data=0xC1, mode=QSPIMode.PutX4)

    await data_put(data=0, mode=QSPIMode.NopX2)
    await data_put(data=0, mode=QSPIMode.NopX4)
    for _ in range(6):
    await data_put(data=0, mode=QSPIMode.Dummy)

    await data_put(data=0, mode=QSPIMode.GetX1)
    await data_put(data=0, mode=QSPIMode.GetX2)
    @@ -429,6 +431,8 @@ async def bits_get(*, ox, oe, i_en, meta):
    assert (actual := await stream_get(ctx, dut.frames)) == expected, \
    f"(cycle {cycle}) {actual} != {expected}; o: {actual.o:04b} != {expected.o:04b}"

    await bits_get(ox=[1,0,1,1,1,0,1,0], oe=1, i_en=0, meta=QSPIMode.Swap)

    await bits_get(ox=[1,0,1,0,1,0,1,0], oe=1, i_en=0, meta=QSPIMode.PutX1)
    await bits_get(ox=[0,1,0,1,0,1,0,1], oe=1, i_en=0, meta=QSPIMode.PutX1)
    await bits_get(ox=[1,1,0,0,0,0,0,1], oe=1, i_en=0, meta=QSPIMode.PutX1)
    @@ -441,8 +445,7 @@ async def bits_get(*, ox, oe, i_en, meta):
    await bits_get(ox=[0b0101,0b0101], oe=0b1111, i_en=0, meta=QSPIMode.PutX4)
    await bits_get(ox=[0b1100,0b0001], oe=0b1111, i_en=0, meta=QSPIMode.PutX4)

    await bits_get(ox=[0,0,0,0], oe=0, i_en=0, meta=QSPIMode.NopX2)
    await bits_get(ox=[0,0], oe=0, i_en=0, meta=QSPIMode.NopX4)
    await bits_get(ox=[0,0,0,0,0,0], oe=0, i_en=0, meta=QSPIMode.Dummy)

    await bits_get(ox=[0,0,0,0,0,0,0,0], oe=1, i_en=1, meta=QSPIMode.GetX1)
    await bits_get(ox=[0,0,0,0], oe=0, i_en=1, meta=QSPIMode.GetX2)
    @@ -465,6 +468,8 @@ async def bits_put(*, ix, meta):
    # amaranth-lang/amaranth#1413
    await stream_put(ctx, dut.frames, {"i": i, "meta": meta.value})

    await bits_put(ix=[i<<1 for i in [1,0,1,1,1,0,1,0]], meta=QSPIMode.Swap)

    await bits_put(ix=[i<<1 for i in [1,0,1,0,1,0,1,0]], meta=QSPIMode.GetX1)
    await bits_put(ix=[i<<1 for i in [0,1,0,1,0,1,0,1]], meta=QSPIMode.GetX1)
    await bits_put(ix=[i<<1 for i in [1,1,0,0,0,0,0,1]], meta=QSPIMode.GetX1)
    @@ -485,6 +490,8 @@ async def data_get(*, data):
    assert (actual := await stream_get(ctx, dut.octets)) == expected, \
    f"{actual} != {expected}; data: {actual.data:08b} != {expected.data:08b}"

    await data_get(data=0xBA)

    await data_get(data=0xAA)
    await data_get(data=0x55)
    await data_get(data=0xC1)
    @@ -616,8 +623,8 @@ async def ctrl_get(*, mode, count=1):
    await ctrl_put(mode=QSPIMode.PutX1, data=0x00)
    await ctrl_put(mode=QSPIMode.PutX1, data=0x00)
    await ctrl_put(mode=QSPIMode.PutX1, data=0x10)
    await ctrl_put(mode=QSPIMode.NopX4)
    await ctrl_put(mode=QSPIMode.NopX4)
    for _ in range(4):
    await ctrl_put(mode=QSPIMode.Dummy)
    assert (await ctrl_get(mode=QSPIMode.GetX4, count=8)) == b"nyaaaaan"

    sim = Simulator(dut)
  3. whitequark revised this gist Jun 21, 2024. No changes.
  4. whitequark renamed this gist Jun 21, 2024. 1 changed file with 158 additions and 5 deletions.
    163 changes: 158 additions & 5 deletions spi_iostream.py → qspi_iostream.py
    Original file line number Diff line number Diff line change
    @@ -152,15 +152,16 @@ def elaborate(self, platform):

    cycle = Signal(range(8))
    m.d.comb += self.frames.ready.eq(~self.octets.valid | self.octets.ready)
    with m.If(self.frames.valid & self.frames.ready):
    with m.If(self.frames.valid):
    with m.Switch(self.frames.p.meta):
    with m.Case(QSPIMode.GetX1):
    m.d.comb += self.octets.valid.eq(cycle == 7)
    with m.Case(QSPIMode.GetX2):
    m.d.comb += self.octets.valid.eq(cycle == 3)
    with m.Case(QSPIMode.GetX4):
    m.d.comb += self.octets.valid.eq(cycle == 1)
    m.d.sync += cycle.eq(Mux(self.octets.valid, 0, cycle + 1))
    with m.If(self.frames.ready):
    m.d.sync += cycle.eq(Mux(self.octets.valid, 0, cycle + 1))

    data_reg = Signal(8)
    with m.Switch(self.frames.p.meta):
    @@ -180,7 +181,7 @@ def elaborate(self, platform):
    return m


    class QSPIBusController(wiring.Component):
    class QSPIController(wiring.Component):
    o_octets: In(stream.Signature(data.StructLayout({
    "mode": QSPIMode,
    "data": 8
    @@ -195,7 +196,36 @@ class QSPIBusController(wiring.Component):
    def elaborate(self, platform):
    m = Module()


    m.submodules.enframer = enframer = QSPIEnframer()
    connect(m, controller=flipped(self.o_octets), enframer=enframer.octets)

    m.submodules.deframer = deframer = QSPIDeframer()
    connect(m, controller=flipped(self.i_octets), deframer=deframer.octets)

    latency = 0 if platform is None else 2 # amaranth-lang/amaranth#1417
    m.submodules.iostream = iostream = IOStream(5, meta_layout=QSPIMode, latency=latency)
    for n in range(4):
    connect(m, controller=flipped(self.io_buffers[n]), iostream=iostream.buffers[n])
    m.d.comb += self.sck_buffer.o.eq(iostream.buffers[4].o)

    phase = Signal()
    with m.If(enframer.frames.valid):
    m.d.sync += phase.eq(~phase)
    m.d.comb += [
    iostream.o_stream.p.o.eq(Cat(enframer.frames.p.o, phase)),
    iostream.o_stream.p.oe.eq(Cat(enframer.frames.p.oe, 1)),
    iostream.o_stream.p.i_en.eq(enframer.frames.p.i_en & phase),
    iostream.o_stream.p.meta.eq(enframer.frames.p.meta),
    iostream.o_stream.valid.eq(enframer.frames.valid),
    enframer.frames.ready.eq(iostream.o_stream.ready & phase),
    ]

    m.d.comb += [
    deframer.frames.p.i.eq(iostream.i_stream.p.i[:4]),
    deframer.frames.p.meta.eq(iostream.i_stream.p.meta),
    deframer.frames.valid.eq(iostream.i_stream.valid),
    iostream.i_stream.ready.eq(deframer.frames.ready),
    ]

    return m

    @@ -467,7 +497,6 @@ async def data_get(*, data):
    await data_get(data=0x55)
    await data_get(data=0xC1)


    sim = Simulator(dut)
    sim.add_clock(1e-6)
    sim.add_testbench(testbench_in)
    @@ -476,8 +505,132 @@ async def data_get(*, data):
    sim.run()


    async def dev_get(ctx, dut, *, x):
        """Device-side helper: capture one octet driven by the controller.

        Runs ``range(0, 8, x)`` iterations.  In each one it first waits for
        ``dut.sck_buffer.o`` to be low (if it is currently high), then ticks until
        it is high, sampling the ``oe``/``o`` signals of all four ``dut.io_buffers``.
        The sampled output bits are shifted into the result, most significant first,
        and the output-enable pattern is asserted to match the bus width ``x``.
        Returns the assembled integer.
        """
        clock = dut.sck_buffer.o
        pin0, pin1, pin2, pin3 = dut.io_buffers
        # Sampled in (oe, o) pairs, lowest-numbered pin first.
        watched = (pin0.oe, pin0.o, pin1.oe, pin1.o, pin2.oe, pin2.o, pin3.oe, pin3.o)

        result = 0
        for _ in range(0, 8, x):
            if ctx.get(clock):
                # Still in the high phase of the previous clock; wait for it to end.
                await ctx.tick().until(~clock)
            oe0, o0, oe1, o1, oe2, o2, oe3, o3 = \
                await ctx.tick().sample(*watched).until(clock)
            enables = (oe0, oe1, oe2, oe3)
            if x == 1:
                assert enables == (1, 0, 0, 0)
                result = (result << 1) | o0
            elif x == 2:
                assert enables == (1, 1, 0, 0)
                result = (result << 2) | (o1 << 1) | o0
            elif x == 4:
                assert enables == (1, 1, 1, 1)
                result = (result << 4) | (o3 << 3) | (o2 << 2) | (o1 << 1) | o0
        return result

    async def dev_nop(ctx, dut, *, x, cycles):
        """Device-side helper: sit through ``cycles`` clock cycles without data.

        For each cycle, waits for ``dut.sck_buffer.o`` to be low (if currently high)
        and then ticks until it is high, sampling the four output enables.  Asserts
        that only IO0 is driven when ``x == 1`` and that nothing is driven otherwise.
        """
        clock = dut.sck_buffer.o
        pin0, pin1, pin2, pin3 = dut.io_buffers
        enable_signals = (pin0.oe, pin1.oe, pin2.oe, pin3.oe)
        for _ in range(cycles):
            if ctx.get(clock):
                await ctx.tick().until(~clock)
            oe0, oe1, oe2, oe3 = await ctx.tick().sample(*enable_signals).until(clock)
            assert (oe0, oe1, oe2, oe3) == (x == 1, 0, 0, 0)

    async def dev_put(ctx, dut, word, *, x):
        """Device-side helper: present one octet on the I/O buffer inputs.

        Runs ``range(0, 8, x)`` iterations.  In each one, once ``dut.sck_buffer.o``
        is low, the top ``x`` bits of ``word`` are written to the buffer inputs and
        ``word`` is shifted left by ``x``; the helper then ticks until the clock is
        high and asserts that only IO0 is driven when ``x == 1`` and that nothing is
        driven otherwise.
        """
        clock = dut.sck_buffer.o
        pin0, pin1, pin2, pin3 = dut.io_buffers
        # Input signals written for each bus width; note that x1 uses IO1 only.
        lanes = {
            1: (pin1.i,),
            2: (pin0.i, pin1.i),
            4: (pin0.i, pin1.i, pin2.i, pin3.i),
        }
        enable_signals = (pin0.oe, pin1.oe, pin2.oe, pin3.oe)

        for _ in range(0, 8, x):
            if ctx.get(clock):
                await ctx.tick().until(~clock)
            if x in lanes:
                ctx.set(Cat(*lanes[x]), word >> (8 - x))
                word = (word << x) & 0xff
            oe0, oe1, oe2, oe3 = await ctx.tick().sample(*enable_signals).until(clock)
            assert (oe0, oe1, oe2, oe3) == (x == 1, 0, 0, 0)


    def simulate_flash(dut, memory=b"nya nya nya nya nyaaaaan"):
        """Build a testbench coroutine that answers read commands like a flash chip.

        The returned testbench receives one command octet in x1 mode.  For opcodes
        0x0B, 0x3B and 0x6B it receives three address octets (most significant
        first), waits through the dummy cycles, and then sends bytes of ``memory``
        forever, wrapping to offset 0 when the address reaches ``len(memory)``.
        Any other opcode ends the testbench.
        """
        # opcode -> (bus width used for the dummy and data phases, dummy cycle count)
        read_commands = {
            0x0B: (1, 8),
            0x3B: (2, 4),
            0x6B: (4, 4),
        }

        async def testbench(ctx):
            opcode = await dev_get(ctx, dut, x=1)
            if opcode not in read_commands:
                return
            width, dummy_cycles = read_commands[opcode]

            address = 0
            for _ in range(3):
                address = (address << 8) | await dev_get(ctx, dut, x=1)
            await dev_nop(ctx, dut, x=width, cycles=dummy_cycles)

            while True:
                if address >= len(memory):
                    address = 0
                await dev_put(ctx, dut, memory[address], x=width)
                address += 1

        return testbench


    def test_qspi_controller():
        # End-to-end simulation: drives a QSPIController with a 0x6B read command and checks
        # the bytes returned by the simulated flash running as a background testbench.
        dut = QSPIController()

        async def testbench(ctx):
            async def ctrl_put(*, mode, data=0):
                # Push a single octet (with its mode) into the controller's output stream.
                await stream_put(ctx, dut.o_octets, {"data": data, "mode": mode.value})

            async def ctrl_get(*, mode, count=1):
                # Keep `o_octets` valid with the given mode until `count` octets were accepted,
                # while collecting `count` octets from `i_octets`; returns them as a bytearray.
                ctx.set(dut.o_octets.p.mode, mode)
                ctx.set(dut.o_octets.valid, 1)
                ctx.set(dut.i_octets.ready, 1)
                words = bytearray()
                o_count = i_count = 0
                while True:
                    _, _, o_octets_ready, i_octets_valid, i_octets_p_data = \
                        await ctx.tick().sample(dut.o_octets.ready,
                            dut.i_octets.valid, dut.i_octets.p.data)
                    if o_octets_ready:
                        o_count += 1
                        if o_count == count:
                            # Enough requests were issued; stop offering more.
                            ctx.set(dut.o_octets.valid, 0)
                    if i_octets_valid:
                        words.append(i_octets_p_data)
                        if len(words) == count:
                            ctx.set(dut.i_octets.ready, 0)
                            # All requests must have been accepted before the last response.
                            assert not ctx.get(dut.o_octets.valid)
                            break
                return words

            # Command 0x6B followed by the 24-bit address 0x000010.
            await ctrl_put(mode=QSPIMode.PutX1, data=0x6B)
            await ctrl_put(mode=QSPIMode.PutX1, data=0x00)
            await ctrl_put(mode=QSPIMode.PutX1, data=0x00)
            await ctrl_put(mode=QSPIMode.PutX1, data=0x10)
            # Two NopX4 octets; the simulated flash waits 4 dummy cycles for 0x6B.
            await ctrl_put(mode=QSPIMode.NopX4)
            await ctrl_put(mode=QSPIMode.NopX4)
            # Offset 16 of the default flash contents is b"nyaaaaan".
            assert (await ctrl_get(mode=QSPIMode.GetX4, count=8)) == b"nyaaaaan"

        sim = Simulator(dut)
        sim.add_clock(1e-6)
        sim.add_testbench(testbench)
        sim.add_testbench(simulate_flash(dut), background=True)
        with sim.write_vcd("qspi_controller.vcd"):
            sim.run()


    if __name__ == "__main__":
        # Run every simulation in order when the file is executed as a script.
        for run_test in (
            test_iostream_basic,
            test_iostream_skid,
            test_qspi_enframer,
            test_qspi_deframer,
            test_qspi_controller,
        ):
            run_test()
  5. whitequark created this gist Jun 21, 2024.
    483 changes: 483 additions & 0 deletions spi_iostream.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,483 @@
    from amaranth import *
    from amaranth.lib import enum, data, wiring, stream, io
    from amaranth.lib.wiring import In, Out, connect, flipped
    from amaranth.sim import Simulator


    class IOStream(wiring.Component):
        """Stream interface around an array of ``width`` I/O buffers.

        Payloads transferred on ``o_stream`` drive the ``o``/``oe`` signals of ``buffers``.
        For transfers with ``i_en`` set, the buffer inputs are reported on ``i_stream``
        ``latency`` cycles later, together with the ``meta`` value delayed by the same amount.
        """

        def __init__(self, width, *, meta_layout=0, latency=0):
            # Number of register stages inserted on `meta` and `i_en` in `elaborate()`.
            self._latency = latency

            super().__init__({
                "o_stream": In(stream.Signature(data.StructLayout({
                    "o": width,
                    "oe": width,
                    "i_en": 1,
                    "meta": meta_layout,
                }))),
                "i_stream": Out(stream.Signature(data.StructLayout({
                    "i": width,
                    "meta": meta_layout,
                }))),
                "buffers": Out(io.FFBuffer.Signature("io", 1)).array(width),
            })

        def elaborate(self, platform):
            latency = self._latency # FIXME: should be platform dependent

            m = Module()

            # Views of all buffers' signals as single `width`-bit values.
            buffer_i = Cat(buffer.i for buffer in self.buffers)
            buffer_o = Cat(buffer.o for buffer in self.buffers)
            buffer_oe = Cat(buffer.oe for buffer in self.buffers)

            # On a transfer the payload drives the buffers directly and is also latched;
            # otherwise the buffers keep being driven with the last latched values.
            o_reg = Signal.like(self.o_stream.p.o)
            oe_reg = Signal.like(self.o_stream.p.oe)
            with m.If(self.o_stream.valid & self.o_stream.ready):
                m.d.sync += o_reg.eq(self.o_stream.p.o)
                m.d.sync += oe_reg.eq(self.o_stream.p.oe)
                m.d.comb += buffer_o.eq(self.o_stream.p.o)
                m.d.comb += buffer_oe.eq(self.o_stream.p.oe)
            with m.Else():
                m.d.comb += buffer_o.eq(o_reg)
                m.d.comb += buffer_oe.eq(oe_reg)

            def delay(value, name):
                # Pass `value` through `latency` chained registers named `{name}_{stage}`
                # and return the output of the last one (or `value` itself if latency is 0).
                for stage in range(latency):
                    next_value = Signal.like(value, name=f"{name}_{stage}")
                    m.d.sync += next_value.eq(value)
                    value = next_value
                return value

            meta = delay(self.o_stream.p.meta, name="meta")
            # Delayed "a transfer with i_en set happened" flag.
            i_en = delay(self.o_stream.valid & self.o_stream.ready &
                         self.o_stream.p.i_en, name="i_en")

            # skid[0] is the live sample; skid[1..latency] are registers that samples are
            # shifted into when the consumer is not ready.
            skid = Array(Signal(self.i_stream.payload.shape(), name=f"skid_{stage}")
                         for stage in range(1 + latency))
            m.d.comb += skid[0].i.eq(buffer_i)
            m.d.comb += skid[0].meta.eq(meta)

            # Index of the skid entry currently presented on `i_stream`.
            skid_at = Signal(range(1 + latency))
            with m.If(i_en & ~self.i_stream.ready):
                # m.d.sync += Assert(skid_at != latency)
                m.d.sync += skid_at.eq(skid_at + 1)
                for n_shift in range(latency):
                    m.d.sync += skid[n_shift + 1].eq(skid[n_shift])
            with m.Elif((skid_at != 0) & self.i_stream.ready):
                m.d.sync += skid_at.eq(skid_at - 1)

            m.d.comb += self.i_stream.payload.i.eq(skid[skid_at].i)
            m.d.comb += self.i_stream.payload.meta.eq(skid[skid_at].meta)
            m.d.comb += self.i_stream.valid.eq(i_en | (skid_at != 0))
            # New output payloads are accepted only while no sample is held in the skid registers.
            m.d.comb += self.o_stream.ready.eq(self.i_stream.ready & (skid_at == 0))

            return m


    class QSPIMode(enum.Enum):
        # Transfer mode attached to each octet. As used by QSPIEnframer in this file:
        # Put* modes drive data bits on the outputs, Get* modes set `i_en`, Nop* modes
        # leave `o`, `oe` and `i_en` at their defaults; the XN suffix selects N bits per frame.
        PutX1 = 0
        GetX1 = 1

        PutX2 = 2
        GetX2 = 3
        NopX2 = 4

        PutX4 = 5
        GetX4 = 6
        NopX4 = 7


    class QSPIEnframer(wiring.Component):
        """Splits each incoming octet into a sequence of per-cycle frames.

        An octet produces 8, 4 or 2 frames for the X1, X2 and X4 modes respectively; every
        frame carries the output bits, output enables, ``i_en`` flag and the octet's mode
        (as ``meta``).
        """
        octets: In(stream.Signature(data.StructLayout({
            "mode": QSPIMode,
            "data": 8
        })))
        frames: Out(stream.Signature(data.StructLayout({
            "o": 4,
            "oe": 4,
            "i_en": 1,
            "meta": QSPIMode,
        })))

        def elaborate(self, platform):
            m = Module()

            # Index of the frame currently being emitted for the pending octet.
            cycle = Signal(range(8))
            m.d.comb += self.frames.valid.eq(self.octets.valid)
            with m.If(self.octets.valid & self.frames.ready):
                # The octet is consumed together with its last frame.
                with m.Switch(self.octets.p.mode):
                    with m.Case(QSPIMode.PutX1, QSPIMode.GetX1):
                        m.d.comb += self.octets.ready.eq(cycle == 7)
                    with m.Case(QSPIMode.PutX2, QSPIMode.GetX2, QSPIMode.NopX2):
                        m.d.comb += self.octets.ready.eq(cycle == 3)
                    with m.Case(QSPIMode.PutX4, QSPIMode.GetX4, QSPIMode.NopX4):
                        m.d.comb += self.octets.ready.eq(cycle == 1)
                m.d.sync += cycle.eq(Mux(self.octets.ready, 0, cycle + 1))

            rev_data = self.octets.p.data[::-1] # flipped to have MSB at 0; flipped back below
            with m.Switch(self.octets.p.mode):
                with m.Case(QSPIMode.PutX1):
                    m.d.comb += self.frames.p.o.eq(rev_data.word_select(cycle, 1)[::-1])
                    m.d.comb += self.frames.p.oe.eq(Cat(1, 0, 0, 0))
                with m.Case(QSPIMode.GetX1):
                    # IO0 stays driven while receiving in x1 mode.
                    m.d.comb += self.frames.p.oe.eq(Cat(1, 0, 0, 0))
                    m.d.comb += self.frames.p.i_en.eq(1)
                with m.Case(QSPIMode.PutX2):
                    m.d.comb += self.frames.p.o.eq(rev_data.word_select(cycle, 2)[::-1])
                    m.d.comb += self.frames.p.oe.eq(Cat(1, 1, 0, 0))
                with m.Case(QSPIMode.GetX2):
                    m.d.comb += self.frames.p.i_en.eq(1)
                with m.Case(QSPIMode.PutX4):
                    m.d.comb += self.frames.p.o.eq(rev_data.word_select(cycle, 4)[::-1])
                    m.d.comb += self.frames.p.oe.eq(Cat(1, 1, 1, 1))
                with m.Case(QSPIMode.GetX4):
                    m.d.comb += self.frames.p.i_en.eq(1)
            m.d.comb += self.frames.p.meta.eq(self.octets.p.mode)

            return m


    class QSPIDeframer(wiring.Component): # meow :3
        """Assembles sampled frames back into octets.

        Frames tagged GetX1, GetX2 or GetX4 contribute 1, 2 or 4 bits each (x1 samples IO1,
        x2/x4 sample IO0 upwards); an octet is presented on ``octets`` together with the
        8th, 4th or 2nd frame respectively, most significant bits first.
        """
        frames: In(stream.Signature(data.StructLayout({
            "i": 4,
            "meta": QSPIMode
        })))
        octets: Out(stream.Signature(data.StructLayout({
            "data": 8
        })))

        def elaborate(self, platform):
            m = Module()

            # Index of the frame being received for the current octet.
            cycle = Signal(range(8))
            m.d.comb += self.frames.ready.eq(~self.octets.valid | self.octets.ready)
            # `octets.valid` must depend on `frames.valid` only: `frames.ready` is itself derived
            # from `octets.valid`, so gating on it as well creates a combinational loop whenever
            # the consumer is not ready on the last frame of an octet.
            with m.If(self.frames.valid):
                with m.Switch(self.frames.p.meta):
                    with m.Case(QSPIMode.GetX1):
                        m.d.comb += self.octets.valid.eq(cycle == 7)
                    with m.Case(QSPIMode.GetX2):
                        m.d.comb += self.octets.valid.eq(cycle == 3)
                    with m.Case(QSPIMode.GetX4):
                        m.d.comb += self.octets.valid.eq(cycle == 1)
                # The frame counter advances only when a frame is actually transferred.
                with m.If(self.frames.ready):
                    m.d.sync += cycle.eq(Mux(self.octets.valid, 0, cycle + 1))

            # Shift register holding the bits received so far; the octet output combines it
            # with the bits of the frame currently on the input (truncated to 8 bits).
            data_reg = Signal(8)
            with m.Switch(self.frames.p.meta):
                with m.Case(QSPIMode.GetX1): # unlike the enframer or the x2/x4 deframer, samples IO1
                    m.d.comb += self.octets.p.data.eq(Cat(self.frames.p.i[1], data_reg))
                    with m.If(self.frames.valid & self.frames.ready):
                        m.d.sync += data_reg.eq(Cat(self.frames.p.i[1], data_reg))
                with m.Case(QSPIMode.GetX2):
                    m.d.comb += self.octets.p.data.eq(Cat(self.frames.p.i[:2], data_reg))
                    with m.If(self.frames.valid & self.frames.ready):
                        m.d.sync += data_reg.eq(Cat(self.frames.p.i[:2], data_reg))
                with m.Case(QSPIMode.GetX4):
                    m.d.comb += self.octets.p.data.eq(Cat(self.frames.p.i[:4], data_reg))
                    with m.If(self.frames.valid & self.frames.ready):
                        m.d.sync += data_reg.eq(Cat(self.frames.p.i[:4], data_reg))

            return m


    class QSPIBusController(wiring.Component):
        # Interface declaration for a QSPI bus controller: an octet stream in each direction,
        # one output buffer for SCK and four bidirectional buffers for IO0..IO3.
        # NOTE(review): `elaborate()` is a stub here and returns an empty module, so none of
        # the ports are driven yet.
        o_octets: In(stream.Signature(data.StructLayout({
            "mode": QSPIMode,
            "data": 8
        })))
        i_octets: Out(stream.Signature(data.StructLayout({
            "data": 8
        })))

        sck_buffer: Out(io.FFBuffer.Signature("o", 1))
        io_buffers: Out(io.FFBuffer.Signature("io", 1)).array(4)

        def elaborate(self, platform):
            m = Module()



            return m


    async def stream_get(ctx, stream):
        """Receive one payload from ``stream`` in a simulation testbench.

        Raises ``stream.ready``, ticks until ``stream.valid`` is observed, lowers
        ``stream.ready`` again and returns the payload sampled on that tick.
        """
        handshake = stream.ready
        ctx.set(handshake, 1)
        transfer = ctx.tick().sample(stream.payload).until(stream.valid)
        (received,) = await transfer
        ctx.set(handshake, 0)
        return received


    async def stream_put(ctx, stream, payload):
        """Send one payload into ``stream`` in a simulation testbench.

        Presents ``payload`` with ``stream.valid`` raised, ticks until
        ``stream.ready`` is observed, then lowers ``stream.valid`` again.
        """
        handshake = stream.valid
        ctx.set(stream.payload, payload)
        ctx.set(handshake, 1)
        accepted = ctx.tick().until(stream.ready)
        await accepted
        ctx.set(handshake, 0)


    def test_iostream_basic():
        # Cycle-by-cycle check of a 1-bit IOStream with latency 2 whose pin is looped back.
        dut = IOStream(1, meta_layout=4, latency=2)

        m = Module()
        m.submodules.dut = dut
        # Loopback through two registers: the input sees the output value when `oe` is set,
        # and 0 otherwise.
        io_reg = Signal()
        m.d.sync += io_reg.eq(Mux(dut.buffers[0].oe, dut.buffers[0].o, 0))
        m.d.sync += dut.buffers[0].i.eq(io_reg)

        async def testbench(ctx):
            await ctx.tick()

            # First transfer: o=1 but not driven (oe=0), sampling requested, meta=1.
            ctx.set(dut.o_stream.p.o, 1)
            ctx.set(dut.o_stream.p.oe, 0)
            ctx.set(dut.o_stream.p.i_en, 1)
            ctx.set(dut.o_stream.p.meta, 1)
            ctx.set(dut.o_stream.valid, 1)
            ctx.set(dut.i_stream.ready, 1)
            assert ctx.get(dut.buffers[0].o) == 1
            assert ctx.get(dut.buffers[0].oe) == 0
            await ctx.tick()
            assert ctx.get(dut.i_stream.valid) == 0

            # Second transfer: same output value, now driven (oe=1), meta=2.
            ctx.set(dut.o_stream.p.oe, 1)
            ctx.set(dut.o_stream.p.meta, 2)
            assert ctx.get(dut.buffers[0].o) == 1
            assert ctx.get(dut.buffers[0].oe) == 1
            await ctx.tick()
            # The sample for the first transfer arrives: the pin was not driven, so it reads 0.
            assert ctx.get(dut.i_stream.valid) == 1
            assert ctx.get(dut.i_stream.p.i) == 0
            assert ctx.get(dut.i_stream.p.meta) == 1

            # Third transfer: o=0, no sampling requested.
            ctx.set(dut.o_stream.p.o, 0)
            ctx.set(dut.o_stream.p.i_en, 0)
            assert ctx.get(dut.buffers[0].o) == 0
            assert ctx.get(dut.buffers[0].oe) == 1
            await ctx.tick()
            # The sample for the second transfer arrives: the pin was driven with 1.
            assert ctx.get(dut.i_stream.valid) == 1
            assert ctx.get(dut.i_stream.p.i) == 1
            assert ctx.get(dut.i_stream.p.meta) == 2

            # No sample is expected for the third transfer, nor after the stream goes idle.
            ctx.set(dut.o_stream.valid, 0)
            await ctx.tick()
            assert ctx.get(dut.i_stream.valid) == 0

            await ctx.tick()
            assert ctx.get(dut.i_stream.valid) == 0

        sim = Simulator(m)
        sim.add_clock(1e-6)
        sim.add_testbench(testbench)
        with sim.write_vcd("iostream_basic.vcd"):
            sim.run()


    def test_iostream_skid():
        # Checks the back-pressure behaviour of a 4-bit IOStream with latency 2: samples that
        # arrive while `i_stream.ready` is low must be held and delivered in order later.
        dut = IOStream(4, meta_layout=4, latency=2)

        async def testbench(ctx):
            await ctx.tick()

            dut_buffers_i = Cat(buffer.i for buffer in dut.buffers)

            ctx.set(dut.o_stream.valid, 1)
            ctx.set(dut.o_stream.p.i_en, 1)

            # `i_stream.ready` has not been raised yet, so nothing is accepted or produced.
            _, _, o_stream_ready, i_stream_valid = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid)
            assert o_stream_ready == 0
            assert i_stream_valid == 0

            _, _, o_stream_ready, i_stream_valid = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid)
            assert o_stream_ready == 0
            assert i_stream_valid == 0

            # Consumer becomes ready: the transfer tagged meta=0b0101 is accepted.
            ctx.set(dut.o_stream.p.meta, 0b0101)
            ctx.set(dut.i_stream.ready, 1)
            assert ctx.get(dut.o_stream.ready) == 1
            _, _, o_stream_ready, i_stream_valid = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid)
            assert o_stream_ready == 1
            assert i_stream_valid == 0

            # Second transfer (meta=0b1100) is accepted, then the consumer stalls.
            ctx.set(dut.o_stream.p.meta, 0b1100)
            _, _, o_stream_ready, i_stream_valid = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid)
            assert o_stream_ready == 1
            assert i_stream_valid == 0
            ctx.set(dut.i_stream.ready, 0)
            assert ctx.get(dut.o_stream.ready) == 0

            # Both in-flight samples arrive while the consumer is stalled; the first one
            # stays visible on `i_stream`.
            ctx.set(dut_buffers_i, 0b0101)
            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 0
            assert i_stream_valid == 1
            assert i_stream_p.i == 0b0101
            assert i_stream_p.meta == 0b0101

            ctx.set(dut_buffers_i, 0b1100)
            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 0
            assert i_stream_valid == 1
            assert i_stream_p.i == 0b0101
            assert i_stream_p.meta == 0b0101

            # Consumer resumes: held samples are delivered one per cycle, oldest first, and
            # no new transfers are accepted until they have drained.
            ctx.set(dut.i_stream.ready, 1)
            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 0
            assert i_stream_valid == 1
            assert i_stream_p.i == 0b0101
            assert i_stream_p.meta == 0b0101

            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 0
            assert i_stream_valid == 1
            assert i_stream_p.i == 0b1100
            assert i_stream_p.meta == 0b1100

            # Normal operation resumes: a new transfer (meta=0b1001) is accepted and its
            # sample appears two cycles later.
            ctx.set(dut.o_stream.p.meta, 0b1001)
            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 1
            assert i_stream_valid == 0

            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 1
            assert i_stream_valid == 0

            ctx.set(dut_buffers_i, 0b1001)
            _, _, o_stream_ready, i_stream_valid, i_stream_p = \
                await ctx.tick().sample(dut.o_stream.ready, dut.i_stream.valid, dut.i_stream.p)
            assert o_stream_ready == 1
            assert i_stream_valid == 1
            assert i_stream_p.i == 0b1001
            assert i_stream_p.meta == 0b1001

        sim = Simulator(dut)
        sim.add_clock(1e-6)
        sim.add_testbench(testbench)
        with sim.write_vcd("iostream_skid.vcd", fs_per_delta=1):
            sim.run()


    def test_qspi_enframer():
        """Feed octets in every mode to a QSPIEnframer and check each emitted frame."""
        dut = QSPIEnframer()

        # Octets pushed into the enframer, in order, as (data, mode).
        stimulus = [
            (0xAA, QSPIMode.PutX1),
            (0x55, QSPIMode.PutX1),
            (0xC1, QSPIMode.PutX1),

            (0xAA, QSPIMode.PutX2),
            (0x55, QSPIMode.PutX2),
            (0xC1, QSPIMode.PutX2),

            (0xAA, QSPIMode.PutX4),
            (0x55, QSPIMode.PutX4),
            (0xC1, QSPIMode.PutX4),

            (0, QSPIMode.NopX2),
            (0, QSPIMode.NopX4),

            (0, QSPIMode.GetX1),
            (0, QSPIMode.GetX2),
            (0, QSPIMode.GetX4),
        ]

        # Frames expected for each octet above, as (per-cycle `o` values, oe, i_en, meta).
        expectations = [
            ([1,0,1,0,1,0,1,0], 1, 0, QSPIMode.PutX1),
            ([0,1,0,1,0,1,0,1], 1, 0, QSPIMode.PutX1),
            ([1,1,0,0,0,0,0,1], 1, 0, QSPIMode.PutX1),

            ([0b10,0b10,0b10,0b10], 0b11, 0, QSPIMode.PutX2),
            ([0b01,0b01,0b01,0b01], 0b11, 0, QSPIMode.PutX2),
            ([0b11,0b00,0b00,0b01], 0b11, 0, QSPIMode.PutX2),

            ([0b1010,0b1010], 0b1111, 0, QSPIMode.PutX4),
            ([0b0101,0b0101], 0b1111, 0, QSPIMode.PutX4),
            ([0b1100,0b0001], 0b1111, 0, QSPIMode.PutX4),

            ([0,0,0,0], 0, 0, QSPIMode.NopX2),
            ([0,0], 0, 0, QSPIMode.NopX4),

            ([0,0,0,0,0,0,0,0], 1, 1, QSPIMode.GetX1),
            ([0,0,0,0], 0, 1, QSPIMode.GetX2),
            ([0,0], 0, 1, QSPIMode.GetX4),
        ]

        async def testbench_in(ctx):
            for octet, mode in stimulus:
                # amaranth-lang/amaranth#1413
                await stream_put(ctx, dut.octets, {"data": octet, "mode": mode.value})

        async def testbench_out(ctx):
            for ox, oe, i_en, meta in expectations:
                for cycle, o in enumerate(ox):
                    # amaranth-lang/amaranth#1413,#1414
                    expected = Const({"o": o, "oe": oe, "i_en": i_en, "meta": meta.value},
                                     dut.frames.p.shape())
                    actual = await stream_get(ctx, dut.frames)
                    assert actual == expected, \
                        f"(cycle {cycle}) {actual} != {expected}; o: {actual.o:04b} != {expected.o:04b}"

        sim = Simulator(dut)
        sim.add_clock(1e-6)
        sim.add_testbench(testbench_in)
        sim.add_testbench(testbench_out)
        with sim.write_vcd("qspi_enframer.vcd"):
            sim.run()


    def test_qspi_deframer():
        """Feed frames in every Get mode to a QSPIDeframer and check the assembled octets."""
        dut = QSPIDeframer()

        # Frames pushed into the deframer, one octet per entry, as (per-cycle `i` values, meta).
        # The x1 bits are shifted left by one because that mode is sampled from IO1.
        stimulus = [
            ([i<<1 for i in [1,0,1,0,1,0,1,0]], QSPIMode.GetX1),
            ([i<<1 for i in [0,1,0,1,0,1,0,1]], QSPIMode.GetX1),
            ([i<<1 for i in [1,1,0,0,0,0,0,1]], QSPIMode.GetX1),

            ([0b10,0b10,0b10,0b10], QSPIMode.GetX2),
            ([0b01,0b01,0b01,0b01], QSPIMode.GetX2),
            ([0b11,0b00,0b00,0b01], QSPIMode.GetX2),

            ([0b1010,0b1010], QSPIMode.GetX4),
            ([0b0101,0b0101], QSPIMode.GetX4),
            ([0b1100,0b0001], QSPIMode.GetX4),
        ]

        # The same three octets are expected once per bus width.
        expected_octets = [0xAA, 0x55, 0xC1] * 3

        async def testbench_in(ctx):
            for ix, meta in stimulus:
                for i in ix:
                    # amaranth-lang/amaranth#1413
                    await stream_put(ctx, dut.frames, {"i": i, "meta": meta.value})

        async def testbench_out(ctx):
            for octet in expected_octets:
                # amaranth-lang/amaranth#1413,#1414
                expected = Const({"data": octet},
                                 dut.octets.p.shape())
                actual = await stream_get(ctx, dut.octets)
                assert actual == expected, \
                    f"{actual} != {expected}; data: {actual.data:08b} != {expected.data:08b}"

        sim = Simulator(dut)
        sim.add_clock(1e-6)
        sim.add_testbench(testbench_in)
        sim.add_testbench(testbench_out)
        with sim.write_vcd("qspi_deframer.vcd"):
            sim.run()


    if __name__ == "__main__":
        # Run every simulation in order when the file is executed as a script.
        for run_test in (
            test_iostream_basic,
            test_iostream_skid,
            test_qspi_enframer,
            test_qspi_deframer,
        ):
            run_test()