diff --git a/pkg/ebpf/c/tracee.bpf.c b/pkg/ebpf/c/tracee.bpf.c index f8470b3b5692..69427acc1a58 100644 --- a/pkg/ebpf/c/tracee.bpf.c +++ b/pkg/ebpf/c/tracee.bpf.c @@ -2490,12 +2490,28 @@ int BPF_KPROBE(trace_security_socket_connect) if (!should_submit(SECURITY_SOCKET_CONNECT, p.event)) return 0; - uint addr_len = (uint) PT_REGS_PARM3(ctx) & 0x7FFFFFFF; // verifier overflow checks + u64 addr_len = PT_REGS_PARM3(ctx); + + struct socket *sock = (struct socket *) PT_REGS_PARM1(ctx); + if (!sock) + return 0; struct sockaddr *address = (struct sockaddr *) PT_REGS_PARM2(ctx); if (!address) return 0; + // Check if the socket type is supported. + u32 type = BPF_CORE_READ(sock, type); + switch (type) { + // TODO: case SOCK_DCCP: + case SOCK_DGRAM: + case SOCK_SEQPACKET: + case SOCK_STREAM: + break; + default: + return 0; + } + // Check if the socket family is supported. sa_family_t sa_fam = get_sockaddr_family(address); switch (sa_fam) { @@ -2521,9 +2537,6 @@ int BPF_KPROBE(trace_security_socket_connect) to = (void *) sys->args.args[1]; #endif - // Save the socket fd argument to the event. - stsb(args_buf, to, sizeof(u32), 0); - // Save the socket fd, depending on the syscall. switch (sys->id) { case SYSCALL_CONNECT: @@ -2533,6 +2546,12 @@ int BPF_KPROBE(trace_security_socket_connect) return 0; } + // Save the socket fd argument to the event. + stsb(args_buf, to, sizeof(u32), 0); + + // Save the socket type argument to the event. + stsb(args_buf, &type, sizeof(u32), 1); + bool need_workaround = false; // Save the sockaddr struct, depending on the family. @@ -2556,14 +2575,14 @@ int BPF_KPROBE(trace_security_socket_connect) if (need_workaround) { // Workaround for sockaddr_un struct length (issue: #1129). struct sockaddr_un sockaddr = {0}; - bpf_probe_read(&sockaddr, (uint) addr_len, (void *) address); - stsb(args_buf, (void *) &sockaddr, sizeof(struct sockaddr_un), 1); + bpf_probe_read(&sockaddr, (u32) addr_len, (void *) address); + stsb(args_buf, (void *) &sockaddr, sizeof(struct sockaddr_un), 2); } #endif // Save the sockaddr struct argument to the event. if (!need_workaround) - stsb(args_buf, (void *) address, sockaddr_len, 1); + stsb(args_buf, (void *) address, sockaddr_len, 2); // Submit the event. return events_perf_submit(&p, SECURITY_SOCKET_CONNECT, 0); diff --git a/pkg/ebpf/c/vmlinux.h b/pkg/ebpf/c/vmlinux.h index 73b9ba6f8cbb..84d8eeccfa82 100644 --- a/pkg/ebpf/c/vmlinux.h +++ b/pkg/ebpf/c/vmlinux.h @@ -478,8 +478,9 @@ struct alloc_context { }; struct socket { - struct sock *sk; + short type; struct file *file; + struct sock *sk; }; typedef struct { diff --git a/pkg/events/core.go b/pkg/events/core.go index 96212af56df7..9a0de8882ea4 100644 --- a/pkg/events/core.go +++ b/pkg/events/core.go @@ -9811,6 +9811,7 @@ var CoreEvents = map[ID]Definition{ sets: []string{"default", "lsm_hooks", "net", "net_sock"}, params: []trace.ArgMeta{ {Type: "int", Name: "sockfd"}, + {Type: "int", Name: "type"}, {Type: "struct sockaddr*", Name: "remote_addr"}, }, }, diff --git a/pkg/events/derive/net_tcp.go b/pkg/events/derive/net_tcp.go index fb80007a95f6..017bee4408b8 100644 --- a/pkg/events/derive/net_tcp.go +++ b/pkg/events/derive/net_tcp.go @@ -57,6 +57,16 @@ func pickIpAndPort(event trace.Event, fieldName string) (string, int, error) { var err error // e.g: sockaddr: map[sa_family:AF_INET sin_addr:10.10.11.2 sin_port:1234] + // Check if socket is a TCP socket. + sType, err := parse.ArgVal[string](event.Args, "type") + if err != nil { + return "", 0, errfmt.WrapError(err) + } + if sType != "SOCK_STREAM" { + return "", 0, nil + } + + // Get sockaddr field. sockaddr, err := parse.ArgVal[map[string]string](event.Args, fieldName) if err != nil { return "", 0, errfmt.WrapError(err) diff --git a/pkg/events/parse_args.go b/pkg/events/parse_args.go index d7bc6d3be0b2..cf60252a9259 100644 --- a/pkg/events/parse_args.go +++ b/pkg/events/parse_args.go @@ -132,7 +132,7 @@ func ParseArgs(event *trace.Event) error { parseOrEmptyString(typeArg, socketTypeArgument, err) } } - case SecuritySocketCreate: + case SecuritySocketCreate, SecuritySocketConnect: if domArg := GetArg(event, "family"); domArg != nil { if dom, isInt32 := domArg.Value.(int32); isInt32 { socketDomainArgument, err := helpers.ParseSocketDomainArgument(uint64(dom))