Skip to content

Commit

Permalink
mark filter fix (#54)
Browse files Browse the repository at this point in the history
Co-authored-by: Florian Lehner <[email protected]>
  • Loading branch information
dvomartin and florianl authored Nov 24, 2023
1 parent 176ac90 commit 45c7b44
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 9 deletions.
143 changes: 134 additions & 9 deletions bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package conntrack
import (
"encoding/binary"
"errors"
"sort"

"fmt"
"github.com/florianl/go-conntrack/internal/unix"
"golang.org/x/net/bpf"
"sort"
)

// Various errors which may occur when processing filters
Expand All @@ -22,6 +22,8 @@ var (
const (
// NOTE(review): presumably the maximum number of instructions allowed in a
// filter program - confirm against the code that checks it.
bpfMAXINSTR = 4096

// failedJump is a placeholder jump offset. Instructions are first emitted
// with this value and patched to the real offset once the program length
// is known.
failedJump = uint8(255)

// Filter verdicts. bpfVerdictReject is emitted as the operand of a BPF_RET
// instruction; bpfVerdictAccept is presumably used the same way.
bpfVerdictAccept = 0xffffffff
bpfVerdictReject = 0x00000000
)
Expand Down Expand Up @@ -185,7 +187,6 @@ func compareValues(filters []ConnAttr) []bpf.RawInstruction {
func filterAttribute(filters []ConnAttr) []bpf.RawInstruction {
var raw []bpf.RawInstruction
nested := len(filterCheck[filters[0].Type].nest)
failed := uint8(255)

// sizeof(nlmsghdr) + sizeof(nfgenmsg) = 20
tmp := bpf.RawInstruction{Op: unix.BPF_LD | unix.BPF_IMM, K: 0x14}
Expand All @@ -200,7 +201,7 @@ func filterAttribute(filters []ConnAttr) []bpf.RawInstruction {
raw = append(raw, tmp)

// jump, if nest not found
tmp = bpf.RawInstruction{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, K: 0, Jt: failed}
tmp = bpf.RawInstruction{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, K: 0, Jt: failedJump}
raw = append(raw, tmp)

tmp = bpf.RawInstruction{Op: unix.BPF_ALU | unix.BPF_ADD | unix.BPF_K, K: 4}
Expand All @@ -214,7 +215,7 @@ func filterAttribute(filters []ConnAttr) []bpf.RawInstruction {
tmp = bpf.RawInstruction{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_ABS, K: 0xfffff00c}
raw = append(raw, tmp)

tmp = bpf.RawInstruction{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, K: 0, Jt: failed}
tmp = bpf.RawInstruction{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, K: 0, Jt: failedJump}
raw = append(raw, tmp)

tmp = bpf.RawInstruction{Op: unix.BPF_MISC | unix.BPF_TAX}
Expand All @@ -234,9 +235,9 @@ func filterAttribute(filters []ConnAttr) []bpf.RawInstruction {
// Failed jumps are set to 255. Now we correct them to the actual failed jump instruction
j := uint8(1)
for i := len(raw) - 1; i > 0; i-- {
if (raw[i].Jt == 255) && (raw[i].Op == unix.BPF_JMP|unix.BPF_JEQ|unix.BPF_K) {
if (raw[i].Jt == failedJump) && (raw[i].Op == unix.BPF_JMP|unix.BPF_JEQ|unix.BPF_K) {
raw[i].Jt = j - jump
} else if (raw[i].Jf == 255) && (raw[i].Op == unix.BPF_JMP|unix.BPF_JEQ|unix.BPF_K) {
} else if (raw[i].Jf == failedJump) && (raw[i].Op == unix.BPF_JMP|unix.BPF_JEQ|unix.BPF_K) {
raw[i].Jf = j - 1
}
j++
Expand Down Expand Up @@ -302,7 +303,12 @@ func constructFilter(subsys Table, filters []ConnAttr) ([]bpf.RawInstruction, er
// We cannot simply range over the map, because the order of selected items can vary
for key := 0; key <= int(attrMax); key++ {
if x, ok := filterMap[ConnAttrType(key)]; ok {
tmp = filterAttribute(x)
switch key {
case int(AttrMark):
tmp = filterMarkAttribute(x)
default:
tmp = filterAttribute(x)
}
raw = append(raw, tmp...)
}
}
Expand All @@ -317,15 +323,134 @@ func constructFilter(subsys Table, filters []ConnAttr) ([]bpf.RawInstruction, er
return raw, nil
}

func (nfct *Nfct) attachFilter(subsys Table, filters []ConnAttr) error {
// filterMarkAttribute builds the BPF instructions that compare the mark
// attribute of a message against every entry in filters.
//
// Each entry contributes one masked comparison per 4-byte word of its Data.
// A successful comparison jumps past the trailing reject instruction; if no
// comparison succeeds, execution falls through to the reject. When
// filters[0].Negate is set the outcome is inverted: an unconditional jump
// skips the reject on fall-through, while successful comparisons land on it.
//
// filters must be non-empty, and each Mask is sliced over the same ranges as
// the corresponding Data.
func filterMarkAttribute(filters []ConnAttr) []bpf.RawInstruction {
	jeqK := uint16(unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K)
	first := filters[0]

	prog := []bpf.RawInstruction{
		// sizeof(nlmsghdr) + sizeof(nfgenmsg) = 20
		{Op: unix.BPF_LD | unix.BPF_IMM, K: 0x14},
		// find final attribute
		{Op: unix.BPF_LDX | unix.BPF_IMM, K: uint32(filterCheck[first.Type].ct)},
		{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_ABS, K: 0xfffff00c},
		{Op: jeqK, K: 0, Jt: 2},
		{Op: unix.BPF_MISC | unix.BPF_TAX},
		{Op: unix.BPF_LD | unix.BPF_W | unix.BPF_IND, K: uint32(len(first.Data))},
		// keep the loaded word in X so it can be restored after each masking
		{Op: unix.BPF_MISC | unix.BPF_TAX},
	}

	for _, f := range filters {
		words := len(f.Data) / 4
		for w := 0; w < words; w++ {
			lo, hi := w*4, (w+1)*4
			mask := encodeValue(f.Mask[lo:hi])
			want := encodeValue(f.Data[lo:hi]) & mask
			prog = append(prog,
				bpf.RawInstruction{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, K: mask},
				bpf.RawInstruction{Op: jeqK, K: want, Jt: failedJump},
				// restore the unmasked word for the next comparison
				bpf.RawInstruction{Op: unix.BPF_MISC | unix.BPF_TXA},
			)
		}
	}

	// Failed jumps are set to 255. Now we correct them to the actual jump
	// distance, which is the number of instructions from the jump to the end
	// of the program built so far (truncated to uint8).
	n := len(prog)
	for idx := n - 1; idx > 0; idx-- {
		if prog[idx].Op == jeqK && prog[idx].Jt == failedJump {
			prog[idx].Jt = uint8(n - idx)
		}
	}

	// negate filter
	if first.Negate {
		prog = append(prog, bpf.RawInstruction{Op: unix.BPF_JMP | unix.BPF_JA, K: 1})
	}

	// reject
	return append(prog, bpf.RawInstruction{Op: unix.BPF_RET | unix.BPF_K, K: bpfVerdictReject})
}

// attachFilter builds the BPF program for subsys and filters with
// constructFilter and installs it on the connection via SetBPF. In debug mode
// the program is written to the logger before it is installed. An error from
// constructFilter is returned unchanged.
func (nfct *Nfct) attachFilter(subsys Table, filters []ConnAttr) error {
	prog, err := constructFilter(subsys, filters)
	if err != nil {
		return err
	}

	if nfct.debug {
		nfct.logger.Println("---BPF filter start---")
		nfct.logger.Print(fmtRawInstructions(prog))
		nfct.logger.Println("---BPF filter end---")
	}

	return nfct.Con.SetBPF(prog)
}

// removeFilter detaches the BPF filter by calling RemoveBPF on the underlying
// connection and returns that call's error.
func (nfct *Nfct) removeFilter() error {
return nfct.Con.RemoveBPF()
}

// fmtRawInstruction renders one BPF instruction as a single newline-terminated
// line: the instruction index, the symbolic opcode name from code2str, both
// jump offsets and the K operand, all numbers in hexadecimal.
func fmtRawInstruction(index int, raw bpf.RawInstruction) string {
	const layout = "(%.4x) code=%30s\tjt=%.2x jf=%.2x k=%.8x\n"

	opName := code2str(raw.Op)
	return fmt.Sprintf(layout, index, opName, raw.Jt, raw.Jf, raw.K)
}

// fmtRawInstructions renders a whole BPF program, one instruction per line,
// using fmtRawInstruction with the instruction's position as its index.
// An empty program yields the empty string.
func fmtRawInstructions(raw []bpf.RawInstruction) string {
	// Accumulate into a byte slice instead of repeated string concatenation,
	// which would copy the whole output on every iteration.
	var buf []byte
	for i, instr := range raw {
		buf = append(buf, fmtRawInstruction(i, instr)...)
	}

	return string(buf)
}

// code2str returns the symbolic name of a BPF opcode, written as its
// constituent flags joined by "|". Opcodes not in the table are reported as
// "UNKNOWN_INSTRUCTION".
func code2str(op uint16) string {
	names := []struct {
		code uint16
		text string
	}{
		{uint16(unix.BPF_LD | unix.BPF_IMM), "BPF_LD|BPF_IMM"},
		{uint16(unix.BPF_LDX | unix.BPF_IMM), "BPF_LDX|BPF_IMM"},
		{uint16(unix.BPF_LD | unix.BPF_B | unix.BPF_ABS), "BPF_LD|BPF_B|BPF_ABS"},
		{uint16(unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K), "BPF_JMP|BPF_JEQ|BPF_K"},
		{uint16(unix.BPF_ALU | unix.BPF_AND | unix.BPF_K), "BPF_ALU|BPF_AND|BPF_K"},
		{uint16(unix.BPF_JMP | unix.BPF_JA), "BPF_JMP|BPF_JA"},
		{uint16(unix.BPF_RET | unix.BPF_K), "BPF_RET|BPF_K"},
		{uint16(unix.BPF_ALU | unix.BPF_ADD | unix.BPF_K), "BPF_ALU|BPF_ADD|BPF_K"},
		{uint16(unix.BPF_MISC | unix.BPF_TAX), "BPF_MISC|BPF_TAX"},
		{uint16(unix.BPF_MISC | unix.BPF_TXA), "BPF_MISC|BPF_TXA"},
		{uint16(unix.BPF_LD | unix.BPF_B | unix.BPF_IND), "BPF_LD|BPF_B|BPF_IND"},
		{uint16(unix.BPF_LD | unix.BPF_H | unix.BPF_IND), "BPF_LD|BPF_H|BPF_IND"},
		{uint16(unix.BPF_LD | unix.BPF_W | unix.BPF_IND), "BPF_LD|BPF_W|BPF_IND"},
	}

	for _, entry := range names {
		if entry.code == op {
			return entry.text
		}
	}
	return "UNKNOWN_INSTRUCTION"
}
130 changes: 130 additions & 0 deletions bpf_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package conntrack

import (
"encoding/binary"
"errors"
"github.com/florianl/go-conntrack/internal/unix"
"testing"

"golang.org/x/net/bpf"
Expand Down Expand Up @@ -170,3 +173,130 @@ func TestConstructFilter(t *testing.T) {
})
}
}

// TestAttrMarkFilter checks the BPF program that constructFilter produces for
// AttrMark filters. It covers a single positive match, several positive
// matches, and a negated filter with two values, comparing the generated
// instructions one by one against the expected program.
func TestAttrMarkFilter(t *testing.T) {
// Mark values are passed to the filter as 4-byte big-endian slices.
mark1ByteValue := make([]byte, 4)
binary.BigEndian.PutUint32(mark1ByteValue, 1)
mark10ByteValue := make([]byte, 4)
binary.BigEndian.PutUint32(mark10ByteValue, 10)
mark11ByteValue := make([]byte, 4)
binary.BigEndian.PutUint32(mark11ByteValue, 11)
mark50ByteValue := make([]byte, 4)
binary.BigEndian.PutUint32(mark50ByteValue, 50)
mark1000ByteValue := make([]byte, 4)
binary.BigEndian.PutUint32(mark1000ByteValue, 1000)

// No case sets err, so every case expects constructFilter to succeed.
tests := []struct {
name string
table Table
filters []ConnAttr
rawInstr []bpf.RawInstruction
err error
}{
{name: "mark positive filter: [1]", table: Conntrack, filters: []ConnAttr{
{Type: AttrMark, Data: mark1ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: false},
}, rawInstr: []bpf.RawInstruction{
//--- check subsys ---
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x01, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
//--- check mark ---
{Op: unix.BPF_LD | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000014},
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000008},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_ABS, Jt: 0x00, Jf: 0x00, K: 0xfffff00c},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_LD | unix.BPF_W | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0x00000000},
//---- final verdict ----
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
}},
{name: "mark positive filter: [10,50,1000]", table: Conntrack, filters: []ConnAttr{
{Type: AttrMark, Data: mark10ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: false},
{Type: AttrMark, Data: mark50ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: false},
{Type: AttrMark, Data: mark1000ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: false},
}, rawInstr: []bpf.RawInstruction{
//--- check subsys ---
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x01, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
//--- check mark ---
{Op: unix.BPF_LD | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000014},
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000008},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_ABS, Jt: 0x00, Jf: 0x00, K: 0xfffff00c},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_LD | unix.BPF_W | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x08, Jf: 0x00, K: 0x0000000a},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x05, Jf: 0x00, K: 0x00000032},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x000003e8},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0x00000000},
//---- final verdict ----
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
}},
{name: "mark negative filter: [10,11]", table: Conntrack, filters: []ConnAttr{
{Type: AttrMark, Data: mark10ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: true},
{Type: AttrMark, Data: mark11ByteValue, Mask: []byte{255, 255, 255, 255}, Negate: true},
}, rawInstr: []bpf.RawInstruction{
//--- check subsys ---
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x01, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
//--- check mark ---
{Op: unix.BPF_LD | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000014},
{Op: unix.BPF_LDX | unix.BPF_IMM, Jt: 0x00, Jf: 0x00, K: 0x00000008},
{Op: unix.BPF_LD | unix.BPF_B | unix.BPF_ABS, Jt: 0x00, Jf: 0x00, K: 0xfffff00c},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_LD | unix.BPF_W | unix.BPF_IND, Jt: 0x00, Jf: 0x00, K: 0x00000004},
{Op: unix.BPF_MISC | unix.BPF_TAX, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x05, Jf: 0x00, K: 0x0000000a},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_ALU | unix.BPF_AND | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
{Op: unix.BPF_JMP | unix.BPF_JEQ | unix.BPF_K, Jt: 0x02, Jf: 0x00, K: 0x0000000b},
{Op: unix.BPF_MISC | unix.BPF_TXA, Jt: 0x00, Jf: 0x00, K: 0x00000000},
{Op: unix.BPF_JMP | unix.BPF_JA, Jt: 0x00, Jf: 0x00, K: 0x00000001},
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0x00000000},
//---- final verdict ----
{Op: unix.BPF_RET | unix.BPF_K, Jt: 0x00, Jf: 0x00, K: 0xffffffff},
}},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
rawInstr, err := constructFilter(tc.table, tc.filters)
if !errors.Is(err, tc.err) {
t.Fatal(err)
}
// A length mismatch makes the per-instruction comparison meaningless,
// so stop here and dump both programs.
if len(rawInstr) != len(tc.rawInstr) {
t.Fatalf("different length:\n- want:\n%s\n- got:\n%s", fmtRawInstructions(tc.rawInstr), fmtRawInstructions(rawInstr))
}
// Report every differing instruction before failing, so that all
// mismatches are visible in a single run.
var isErr bool
for i, v := range rawInstr {
if v != tc.rawInstr[i] {
t.Errorf("unexpected instruction:\n- want:\n%s\n- got:\n%s", fmtRawInstruction(i, tc.rawInstr[i]), fmtRawInstruction(i, rawInstr[i]))
isErr = true
}
}

if isErr {
t.Fatalf("unexpected reply:\n- want:\n%s\n- got:\n%s", fmtRawInstructions(tc.rawInstr), fmtRawInstructions(rawInstr))
}
})
}
}
5 changes: 5 additions & 0 deletions conntrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,11 @@ func (nfct *Nfct) RegisterFiltered(ctx context.Context, t Table, group NetlinkGr
return nfct.register(ctx, t, group, filter, fn)
}

// EnableDebug turns on debug mode. While enabled, attachFilter writes the
// constructed BPF filter to the logger before attaching it.
func (nfct *Nfct) EnableDebug() {
nfct.debug = true
}

func (nfct *Nfct) register(ctx context.Context, t Table, groups NetlinkGroup, filter []ConnAttr, fn func(c Con) int) error {
nfct.ctx, nfct.ctxCancel = context.WithCancel(ctx)
nfct.shutdown = make(chan struct{})
Expand Down
2 changes: 2 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ type Nfct struct {

errChan chan error

debug bool

setWriteTimeout func() error

ctx context.Context
Expand Down

0 comments on commit 45c7b44

Please sign in to comment.