Skip to content

Commit

Permalink
add rep in common testdial
Browse files Browse the repository at this point in the history
  • Loading branch information
apiresatos committed Dec 25, 2019
1 parent e6a7068 commit 8156cbe
Showing 1 changed file with 47 additions and 83 deletions.
130 changes: 47 additions & 83 deletions nftables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func cleanupSystemNFTConn(t *testing.T, newNS netns.NsHandle) {
}
}

func CheckNLReq(t *testing.T, wantMsg [][]byte, wantFlags []netlink.HeaderFlags) nltest.Func {
func CheckNLReq(t *testing.T, wantMsg [][]byte, wantFlags []netlink.HeaderFlags, replies [][]netlink.Message) nltest.Func {
return func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
Expand All @@ -132,13 +132,21 @@ func CheckNLReq(t *testing.T, wantMsg [][]byte, wantFlags []netlink.HeaderFlags)
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}

if len(wantFlags) > 0 && msg.Header.Flags != wantFlags[idx] {
t.Errorf("message %d: %s", idx, linediff(msg.Header.Flags.String(), wantFlags[idx].String()))
if len(wantFlags) > 0 {
if msg.Header.Flags != wantFlags[idx] {
t.Errorf("message %d: %s", idx, linediff(msg.Header.Flags.String(), wantFlags[idx].String()))
}
}

wantMsg = wantMsg[1:]
}
return req, nil

if len(replies) > 0 {
rep := replies[0]
replies = replies[1:]
return rep, nil
} else {
return req, nil
}
}
}

Expand Down Expand Up @@ -170,7 +178,7 @@ func TestConfigureNAT(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -364,7 +372,7 @@ func TestConfigureNATSourceAddress(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -427,29 +435,7 @@ func TestGetRule(t *testing.T) {
}

c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
rep := reply[0]
reply = reply[1:]
return rep, nil
},
TestDial: CheckNLReq(t, want, nil, reply),
}

rules, err := c.GetRule(
Expand Down Expand Up @@ -506,7 +492,7 @@ func TestAddCounter(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddObj(&nftables.CounterObj{
Expand Down Expand Up @@ -543,7 +529,7 @@ func TestDelRule(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.DelRule(&nftables.Rule{
Expand All @@ -568,7 +554,7 @@ func TestLog(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -598,7 +584,7 @@ func TestTProxy(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -641,7 +627,7 @@ func TestCt(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -674,7 +660,7 @@ func TestCtSet(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -713,7 +699,7 @@ func TestAddRuleWithPosition(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.AddRule(&nftables.Rule{
Expand Down Expand Up @@ -801,7 +787,7 @@ func TestAddChain(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -858,7 +844,7 @@ func TestDelChain(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

tt.chain.Table = &nftables.Table{
Expand Down Expand Up @@ -888,29 +874,7 @@ func TestGetObjReset(t *testing.T) {
}

c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
for idx, msg := range req {
b, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if len(b) < 16 {
continue
}
b = b[16:]
if len(want) == 0 {
t.Errorf("no want entry for message %d: %x", idx, b)
continue
}
if got, want := b, want[0]; !bytes.Equal(got, want) {
t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want)))
}
want = want[1:]
}
rep := reply[0]
reply = reply[1:]
return rep, nil
},
TestDial: CheckNLReq(t, want, nil, reply),
}

filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
Expand Down Expand Up @@ -968,7 +932,7 @@ func TestConfigureClamping(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -1081,7 +1045,7 @@ func TestDropVerdict(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -1161,7 +1125,7 @@ func TestCreateUseAnonymousSet(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -1898,7 +1862,7 @@ func TestConfigureNATRedirect(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -1983,7 +1947,7 @@ func TestConfigureJumpVerdict(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -2069,7 +2033,7 @@ func TestConfigureReturnVerdict(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -2134,7 +2098,7 @@ func TestConfigureRangePort(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -2212,7 +2176,7 @@ func TestConfigureRangeIPv4(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -2282,7 +2246,7 @@ func TestConfigureRangeIPv6(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -2375,7 +2339,7 @@ func TestSet4(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

tbl := &nftables.Table{
Expand Down Expand Up @@ -2541,7 +2505,7 @@ func TestMasq(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -2652,7 +2616,7 @@ func TestReject(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -2760,7 +2724,7 @@ func TestFib(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -2844,7 +2808,7 @@ func TestNumgen(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -2909,7 +2873,7 @@ func TestMap(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -3007,7 +2971,7 @@ func TestVmap(t *testing.T) {

for _, tt := range tests {
c := &nftables.Conn{
TestDial: CheckNLReq(t, tt.want, nil),
TestDial: CheckNLReq(t, tt.want, nil, nil),
}

filter := c.AddTable(&nftables.Table{
Expand Down Expand Up @@ -3047,7 +3011,7 @@ func TestJHash(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -3128,7 +3092,7 @@ func TestDup(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -3208,7 +3172,7 @@ func TestDupWoDev(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -3270,7 +3234,7 @@ func TestNotrack(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down Expand Up @@ -3324,7 +3288,7 @@ func TestStatelessNAT(t *testing.T) {
}

c := &nftables.Conn{
TestDial: CheckNLReq(t, want, nil),
TestDial: CheckNLReq(t, want, nil, nil),
}

c.FlushRuleset()
Expand Down

0 comments on commit 8156cbe

Please sign in to comment.