Skip to content

Commit

Permalink
cmsis_dap: Refactor to use amaranth.lib.wiring.
Browse files Browse the repository at this point in the history
  • Loading branch information
zyp committed Jan 13, 2024
1 parent dbdd8a5 commit da6002c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 64 deletions.
89 changes: 49 additions & 40 deletions orbtrace/debug/cmsis_dap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from amaranth import *
from .dbgIF import DBGIF

from amaranth.lib.wiring import Component, In, Out

from zyp_amaranth_libs.stream import StreamSignature

from .dbgIF_wrapper import dbgIFSignature

from ..git_version import get_version

# Principle of operation
Expand Down Expand Up @@ -147,17 +153,22 @@ def elaborate(self, platform):
# This is the CMSIS-DAP handler itself
# ====================================

class CMSIS_DAP(Elaboratable):
def __init__(self, streamIn, streamOut, dbgif, v2Indication):
# External interface (generally LEDs)
self.running = Signal() # Flag for if target is running
self.connected = Signal() # Flag for if target is connected
self.can = Signal() # Canary
class CMSIS_DAP(Component):
# External interface (generally LEDs)
running: Out(1) # Flag for if target is running
connected: Out(1) # Flag for if target is connected
can: Out(1) # Canary

# Nature of the USB connection
self.isV2 = v2Indication
self.streamIn = streamIn
self.streamOut = streamOut
# Nature of the USB connection
isV2: In(1)
streamIn: In(StreamSignature(8, first = True, last = True))
streamOut: Out(StreamSignature(8, first = True, last = True))

# Debug interface
dbgif: Out(dbgIFSignature)

def __init__(self):
super().__init__()

# Receive block construction
self.rxBlock = Signal( 7*8 ) # Longest message we pickup is 6 bytes + command
Expand Down Expand Up @@ -219,8 +230,6 @@ def __init__(self, streamIn, streamOut, dbgif, v2Indication):
self.waitRetry = Signal(16,reset=4096) # Number of transfer retries after WAIT response
self.matchRetry = Signal(16,reset=16) # Number of retries on reads with Value Match in DAP_Transfer

self.dbgif = dbgif

# ----------------------------------------------------------------------------------
def RESP_Invalid(self, m):
# Simply transmit an 'invalid' packet back
Expand Down Expand Up @@ -424,7 +433,7 @@ def RESP_SWJ_Sequence_Process(self, m):
with m.Case(0):
with m.If(self.streamOut.valid & self.streamOut.ready):
m.d.sync += [
self.tfrData.eq(self.streamOut.payload),
self.tfrData.eq(self.streamOut.data),
self.swj_txb.eq(1)
]
with m.Else():
Expand Down Expand Up @@ -496,7 +505,7 @@ def RESP_JTAG_Configure_Process(self, m):
# Collect octets representing the irlength for each member of the chain
with m.If(self.streamOut.valid & self.streamOut.ready):
m.d.sync += [
self.dbgif.dwrite.bit_select( self.jtag_ircount,5 ).eq(self.streamOut.payload.bit_select(0,5)),
self.dbgif.dwrite.bit_select( self.jtag_ircount,5 ).eq(self.streamOut.data.bit_select(0,5)),
self.jtag_ircount.eq(self.jtag_ircount+5)
]
with m.If(self.streamOut.last):
Expand Down Expand Up @@ -626,18 +635,18 @@ def RESP_Transfer_Process(self, m):
# Rule for JTAG is any read, for SWD it's any AP read
# If it doesn't then we need to collect these data before progressing
with m.If(self.readDelay &
((self.isJTAG & (~self.streamOut.payload.bit_select(1,1))) |
(~self.isJTAG & (self.streamOut.payload.bit_select(0,2)!=3)))):
((self.isJTAG & (~self.streamOut.data.bit_select(1,1))) |
(~self.isJTAG & (self.streamOut.data.bit_select(0,2)!=3)))):
m.d.sync += [
self.tfr_txb.eq(6),
self.tfrReq.eq(0x0e),
self.readDelay.eq(0),
self.readAgain.eq(1),
self.PendPayload.eq(self.streamOut.payload),
self.PendPayload.eq(self.streamOut.data),
]
with m.Else():
m.d.sync += [
self.tfrReq.eq(self.streamOut.payload),
self.tfrReq.eq(self.streamOut.data),
self.tfr_txb.eq(1)
]

Expand Down Expand Up @@ -665,14 +674,14 @@ def RESP_Transfer_Process(self, m):
with m.If(self.streamOut.valid & self.streamOut.ready):
m.d.sync+=[
# Beware, state used to select byte in word construction
self.tfrData.word_select((self.tfr_txb.bit_select(0,3)-2).as_unsigned(),8).eq(self.streamOut.payload),
self.tfrData.word_select((self.tfr_txb.bit_select(0,3)-2).as_unsigned(),8).eq(self.streamOut.data),
self.tfr_txb.eq(self.tfr_txb+1)
]

with m.If(self.tfrReq.bit_select(5,1) & (self.tfr_txb==6)):
# This is a match register write
m.d.sync += [
self.mask.eq(Cat(self.streamOut.payload,self.tfrData.bit_select(0,24))),
self.mask.eq(Cat(self.streamOut.data,self.tfrData.bit_select(0,24))),
self.tfr_txb.eq(0)
]
with m.Else():
Expand Down Expand Up @@ -747,7 +756,7 @@ def RESP_Transfer_Process(self, m):
with m.Case(10,11,12):
with m.If(self.streamIn.ready):
m.d.sync += [
self.streamIn.payload.eq(self.txBlock.word_select((self.tfr_txb-10).as_unsigned(),8)),
self.streamIn.data.eq(self.txBlock.word_select((self.tfr_txb-10).as_unsigned(),8)),
self.streamIn.valid.eq(1),
self.tfr_txb.eq(self.tfr_txb+1),
self.streamIn.last.eq(self.isV2 & (self.tfr_txb==12) & (self.tfrram.adr==0))
Expand Down Expand Up @@ -821,7 +830,7 @@ def RESP_TransferBlock_Process(self, m):
with m.Case(0,1,2,3):
with m.If(self.streamOut.ready & self.streamOut.valid):
m.d.sync+=[
self.dbgif.dwrite.word_select(self.tfB_txb,8).eq(self.streamOut.payload),
self.dbgif.dwrite.word_select(self.tfB_txb,8).eq(self.streamOut.data),
self.tfB_txb.eq(self.tfB_txb+1),
]
with m.Else():
Expand Down Expand Up @@ -887,7 +896,7 @@ def RESP_TransferBlock_Process(self, m):
with m.If(self.streamIn.ready):
m.d.sync += [
# Beware, we use the bottom two bits of the state to select the byte to return
self.streamIn.payload.eq(self.txBlock.word_select(self.tfB_txb.bit_select(0,2),8)),
self.streamIn.data.eq(self.txBlock.word_select(self.tfB_txb.bit_select(0,2),8)),
self.streamIn.valid.eq(1),
self.tfB_txb.eq(self.tfB_txb+1),
# End of transfer if there are no data to return
Expand Down Expand Up @@ -931,7 +940,7 @@ def RESP_Transfer_Complete(self, m):
with m.Case(2):
m.d.sync += [
self.transferCCount.eq(self.transferCCount-1),
self.streamIn.payload.eq(self.tfrram.dat_r.word_select(0,8)),
self.streamIn.data.eq(self.tfrram.dat_r.word_select(0,8)),
self.txb.eq(3)
]

Expand All @@ -941,7 +950,7 @@ def RESP_Transfer_Complete(self, m):
with m.If(self.streamIn.ready & self.streamIn.valid):
m.d.sync += [
self.txb.eq(self.txb+1),
self.streamIn.payload.eq(self.tfrram.dat_r.word_select((self.txb-2).as_unsigned(),8)),
self.streamIn.data.eq(self.tfrram.dat_r.word_select((self.txb-2).as_unsigned(),8)),
# 5 because of pipeline
self.streamIn.last.eq(self.isV2 & (~self.transferCCount.bool()) & (self.txb==5)),
self.streamIn.valid.eq(self.txb!=6)
Expand Down Expand Up @@ -997,7 +1006,7 @@ def RESP_Sequence_PROCESS(self,m):
with m.If(self.streamIn.ready):
m.d.sync += [
# Send frontmatter for reponse
self.streamIn.payload.eq(Mux(self.isJTAG,DAP_JTAG_Sequence,DAP_SWD_Sequence)),
self.streamIn.data.eq(Mux(self.isJTAG,DAP_JTAG_Sequence,DAP_SWD_Sequence)),
self.streamIn.last.eq(0),
self.streamIn.valid.eq(1),

Expand All @@ -1013,19 +1022,19 @@ def RESP_Sequence_PROCESS(self,m):
with m.If(self.streamOut.ready & self.streamOut.valid):
m.d.sync += [
self.seqCount.eq(self.seqCount-1),
self.seqckCycles.eq(self.streamOut.payload.bit_select(0,6)),
self.seqckCycles.eq(self.streamOut.data.bit_select(0,6)),

# If we're reading then we don't want to write to SWD, we do if it's JTAG (TMS)
self.dbgif.pinsin[4].eq(Mux(self.isJTAG,1,~self.streamOut.payload[7])),
self.dbgif.pinsin[4].eq(Mux(self.isJTAG,1,~self.streamOut.data[7])),
self.seqIn.eq(0),

# Decide on correct state to move to and setup output for read or write condition
self.seqIsRead.eq(self.streamOut.payload[7]),
self.seq_txb.eq(Mux(self.isJTAG,2,Mux(self.streamOut.payload[7],3,2)))
self.seqIsRead.eq(self.streamOut.data[7]),
self.seq_txb.eq(Mux(self.isJTAG,2,Mux(self.streamOut.data[7],3,2)))
]
# If this is a JTAG sequence then set TMS
with m.If(self.isJTAG):
m.d.sync += self.dbgif.pinsin.bit_select(1,1).eq(self.streamOut.payload[6])
m.d.sync += self.dbgif.pinsin.bit_select(1,1).eq(self.streamOut.data[6])
with m.Else():
# If we're showing ~valid then this packet is foreshortened
with m.If(~self.streamOut.valid):
Expand All @@ -1037,7 +1046,7 @@ def RESP_Sequence_PROCESS(self,m):
with m.Case(2):
with m.If(self.streamOut.ready & self.streamOut.valid):
m.d.sync += [
self.seqOut.eq(self.streamOut.payload),
self.seqOut.eq(self.streamOut.data),
self.bitCount.eq(0),
self.seq_txb.eq(3)
]
Expand Down Expand Up @@ -1104,7 +1113,7 @@ def RESP_Sequence_PROCESS(self,m):

with m.If(self.seqIsRead):
m.d.sync += [
self.streamIn.payload.eq(self.seqPendingTX),
self.streamIn.data.eq(self.seqPendingTX),
self.streamIn.valid.eq(1),
self.seqPendingTX.eq(self.seqIn),
self.txedLen.eq(self.txedLen+1),
Expand All @@ -1117,7 +1126,7 @@ def RESP_Sequence_PROCESS(self,m):
with m.Case(7):
with m.If(self.streamIn.ready):
m.d.sync += [
self.streamIn.payload.eq(self.seqPendingTX),
self.streamIn.data.eq(self.seqPendingTX),
self.streamIn.last.eq(self.isV2),
self.streamIn.valid.eq(1),
self.seq_txb.eq(8)
Expand Down Expand Up @@ -1168,12 +1177,12 @@ def elaborate(self,platform):
self.txedLen.eq(0),

# Default return is packet name followed by 0 (no error)
self.txBlock.word_select(0,16).eq(Cat(self.streamOut.payload,C(0,8))),
self.txBlock.word_select(0,16).eq(Cat(self.streamOut.data,C(0,8))),
self.txLen.eq(2),

# Grab incoming from usb
self.rxedLen.eq(1),
self.rxBlock.word_select(0,8).eq(self.streamOut.payload),
self.rxBlock.word_select(0,8).eq(self.streamOut.data),
]

# Only process if this is the start of a packet (i.e. it's not overrrun or similar)
Expand Down Expand Up @@ -1229,7 +1238,7 @@ def elaborate(self,platform):
# Grab next byte in this packet
with m.If(self.streamOut.valid & self.streamOut.ready):
m.d.sync += [
self.rxBlock.word_select(self.rxedLen,8).eq(self.streamOut.payload),
self.rxBlock.word_select(self.rxedLen,8).eq(self.streamOut.data),
self.rxedLen.eq(self.rxedLen+1)
]

Expand Down Expand Up @@ -1377,7 +1386,7 @@ def elaborate(self,platform):
with m.State('RESPOND'):
m.d.sync += [
self.streamIn.valid.eq(self.txedLen<self.txLen),
self.streamIn.payload.eq(self.txBlock.word_select(self.txedLen,8)),
self.streamIn.data.eq(self.txBlock.word_select(self.txedLen,8)),

# This is the end of the packet if we've filled the length and it's v2
self.streamIn.last.eq(self.isV2 & (self.txedLen==self.txLen-1))
Expand All @@ -1399,7 +1408,7 @@ def elaborate(self,platform):
with m.State('V1PACKETFILL'):
m.d.sync += [
self.streamIn.valid.eq(self.txedLen<DAP_V1_MAX_PACKET_SIZE),
self.streamIn.payload.eq(0),
self.streamIn.data.eq(0),
]

with m.If(self.streamIn.ready & self.streamIn.valid):
Expand Down
40 changes: 16 additions & 24 deletions orbtrace/debug/cmsis_dap_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from migen import *

import amaranth
from amaranth.hdl.rec import DIR_FANIN, DIR_FANOUT, DIR_NONE
from amaranth.lib import wiring

from luna.gateware.stream import StreamInterface

from . import cmsis_dap, dbgIF, dbgIF_wrapper
from . import cmsis_dap, dbgIF_wrapper

from litex.soc.interconnect.stream import Endpoint

from litex.build.io import SDRInput, DDRInput, SDROutput, SDRTristate

class CMSIS_DAP(Module):
def __init__(self, dbgif, glue):
self.source = Endpoint([('data', 8)])
Expand All @@ -20,30 +15,27 @@ def __init__(self, dbgif, glue):
self.connected = Signal()
self.running = Signal()

stream_in = StreamInterface()
stream_out = StreamInterface()

is_v2 = amaranth.Signal()

dbgif_wrapper = dbgIF_wrapper.DBGIF(dbgif, glue)
glue.m.submodules += dbgif_wrapper

dap = cmsis_dap.CMSIS_DAP(stream_in, stream_out, dbgif_wrapper, is_v2)
dap = cmsis_dap.CMSIS_DAP()
glue.m.submodules += dap

glue.connect(self.source.data, stream_in.payload)
glue.connect(self.source.first, stream_in.first)
glue.connect(self.source.last, stream_in.last)
glue.connect(self.source.valid, stream_in.valid)
glue.connect(self.source.ready, stream_in.ready)
wiring.connect(glue.m, dbgif_wrapper, dap.dbgif)

glue.connect(self.source.data, dap.streamIn.data)
glue.connect(self.source.first, dap.streamIn.first)
glue.connect(self.source.last, dap.streamIn.last)
glue.connect(self.source.valid, dap.streamIn.valid)
glue.connect(self.source.ready, dap.streamIn.ready)

glue.connect(self.sink.data, stream_out.payload)
glue.connect(self.sink.first, stream_out.first)
glue.connect(self.sink.last, stream_out.last)
glue.connect(self.sink.valid, stream_out.valid)
glue.connect(self.sink.ready, stream_out.ready)
glue.connect(self.sink.data, dap.streamOut.data)
glue.connect(self.sink.first, dap.streamOut.first)
glue.connect(self.sink.last, dap.streamOut.last)
glue.connect(self.sink.valid, dap.streamOut.valid)
glue.connect(self.sink.ready, dap.streamOut.ready)

glue.connect(self.is_v2, is_v2)
glue.connect(self.is_v2, dap.isV2)

glue.connect(self.can, dap.can)

Expand Down
21 changes: 21 additions & 0 deletions orbtrace/debug/dbgIF_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
from amaranth import *

from amaranth.lib.wiring import In, Out, Signature

dbgIFSignature = Signature({
'addr32': Out(2),
'rnw': Out(1),
'apndp': Out(1),
'dwrite': Out(32),
'dread': In(32),
'perr': In(1),
'go': Out(1),
'done': In(1),
'ack': In(3),
'pinsin': Out(16),
'pinsout': In(8),
'command': Out(5),
'dev': Out(3),
'is_jtag': Out(1),
})

class DBGIF(Elaboratable):
signature = dbgIFSignature.flip()

def __init__(self, dbgif, glue):
self.addr32 = glue.from_migen(dbgif.addr32)
self.rnw = glue.from_migen(dbgif.rnw)
Expand Down

0 comments on commit da6002c

Please sign in to comment.