Skip to content

Commit

Permalink
chore: code refactor to add assertion (aquasecurity#4014)
Browse files Browse the repository at this point in the history
* chore: enable unchecked-type-assertion

* chore: code refactor to add assertion

Refactored code to fix the unchecked type cast issue by
incorporating error handling after the type assertion.
  • Loading branch information
rscampos authored May 22, 2024
1 parent ba4db8b commit 656eb97
Show file tree
Hide file tree
Showing 13 changed files with 95 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .revive.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ enableAllRules = true
[rule.var-declaration]
Disabled = false
[rule.unchecked-type-assertion]
Disabled = true
Disabled = false
[rule.unconditional-recursion]
Disabled = false
[rule.unexported-naming]
Expand Down
6 changes: 5 additions & 1 deletion pkg/ebpf/events_pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ func (t *Tracee) decodeEvents(ctx context.Context, sourceChan chan []byte) (<-ch
}

// get an event pointer from the pool
evt := t.eventsPool.Get().(*trace.Event)
evt, ok := t.eventsPool.Get().(*trace.Event)
if !ok {
t.handleError(errfmt.Errorf("failed to get event from pool"))
continue
}

// populate all the fields of the event used in this stage, and reset the rest

Expand Down
11 changes: 8 additions & 3 deletions pkg/ebpf/events_pipeline_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ func BenchmarkGetEventFromPool(b *testing.B) {
for i := 0; i < b.N; i++ {
for j := 0; j < decodeEvts; j++ {
ctx := <-decodeChan
evt := evtPool.Get().(*trace.Event)

evt, ok := evtPool.Get().(*trace.Event)
if !ok {
b.Error("Failed to get event from pool")
}
evt.Timestamp = int(ctx.Ts)
evt.ThreadStartTime = int(ctx.StartTime)
evt.ProcessorID = int(ctx.ProcessorId)
Expand Down Expand Up @@ -125,7 +127,10 @@ func BenchmarkGetEventFromPool(b *testing.B) {

// get an event from the pool, fill it with data and
// pass it to the other stages
evtCopy := evtPool.Get().(*trace.Event)
evtCopy, ok := evtPool.Get().(*trace.Event)
if !ok {
b.Error("Failed to get event from pool")
}
*evtCopy = *evt // shallow copy
sinkChan <- evt
if j < deriveEvts {
Expand Down
11 changes: 7 additions & 4 deletions pkg/events/derive/symbols_collision.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,10 @@ func (procLoadedObjsCache *loadedObjsPerProcessCache) GetLoadedObjsPerProcess(

loadedObjsIface, ok := procLoadedObjsCache.cache.Get(pid) // loaded objs per process (ObjInfo)
if ok {
loadedObjs = loadedObjsIface.([]sharedobjs.ObjInfo)
return loadedObjs, true // true if process existed in the cache
if objs, ok := loadedObjsIface.([]sharedobjs.ObjInfo); ok {
loadedObjs = objs
return loadedObjs, true
}
}

return nil, false
Expand Down Expand Up @@ -336,8 +338,9 @@ func (socCache collisionChecksCache) getObjCollisionsAndCollisionKey(
collisionsIface, ok = socCache.cache.Get(key)
}
if ok {
collisions := collisionsIface.([]string)
return key, collisions, true
if collisions, ok := collisionsIface.([]string); ok {
return key, collisions, true
}
}

return collisionsKey{}, nil, false // no collisions found
Expand Down
10 changes: 8 additions & 2 deletions pkg/events/derive/symbols_collision_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,17 @@ func TestSymbolsCollisionArgsGenerator_deriveArgs(t *testing.T) {
assert.Equal(t, testCase.loadingSO.Path, args[0])
path := args[1]
require.IsType(t, "", path)
path = path.(string)
path, ok := path.(string)
if !ok {
t.Error("Failed to cast path's value")
}
if path == lso.so.Path {
col := args[2]
require.IsType(t, []string{}, col)
col = col.([]string)
col, ok = col.([]string)
if !ok {
t.Error("Failed to cast path arg's value")
}
assert.ElementsMatch(t, col, lso.expectedCollisions)
found = true
break
Expand Down
5 changes: 4 additions & 1 deletion pkg/events/ftrace.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ func checkFtraceHooks(eventsCounter counter.Counter, out chan *trace.Event, base
}

args[countIndex].Value = newCount
symbol := args[symbolIndex].Value.(string)
symbol, ok := args[symbolIndex].Value.(string)
if !ok {
return errors.New("failed to cast symbol's value")
}

// Verify that we didn't report this symbol already, and it wasn't changed.
// If we reported the symbol in the past, and now the count has reduced - report only if the callback was changed
Expand Down
31 changes: 17 additions & 14 deletions pkg/server/grpc/tracee.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,20 +345,23 @@ func getThreat(description string, metadata map[string]interface{}) *pb.Threat {
}

func getSeverity(metadata map[string]interface{}) pb.Severity {
switch metadata["Severity"].(int) {
case 0:
return pb.Severity_INFO
case 1:
return pb.Severity_LOW
case 2:
return pb.Severity_MEDIUM
case 3:
return pb.Severity_HIGH
case 4:
return pb.Severity_CRITICAL
}

return -1
severityValue, ok := metadata["Severity"].(int)
if ok {
switch severityValue {
case 0:
return pb.Severity_INFO
case 1:
return pb.Severity_LOW
case 2:
return pb.Severity_MEDIUM
case 3:
return pb.Severity_HIGH
case 4:
return pb.Severity_CRITICAL
}
}

return pb.Severity_INFO
}

func getStackAddress(stackAddresses []uint64) []*pb.StackAddress {
Expand Down
5 changes: 4 additions & 1 deletion pkg/signatures/benchmark/signature/golang/anti_debugging.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ func (sig *antiDebugging) OnEvent(event protocol.Event) error {
if err != nil {
return err
}
requestString := request.Value.(string)
requestString, ok := request.Value.(string)
if !ok {
return fmt.Errorf("failed to cast request's value")
}
if requestString != "PTRACE_TRACEME" {
return nil
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/signatures/benchmark/signature/golang/code_injection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ func (sig *codeInjection) OnEvent(event protocol.Event) error {
if err != nil {
return err
}
requestString := request.Value.(string)

requestString, ok := request.Value.(string)
if !ok {
return fmt.Errorf("failed to cast request's value")
}

if requestString == "PTRACE_POKETEXT" || requestString == "PTRACE_POKEDATA" {
sig.cb(&detect.Finding{
// Signature: sig,
Expand Down
5 changes: 3 additions & 2 deletions pkg/utils/sharedobjs/host_symbols_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ type dynamicSymbolsLRUCache struct {
func (soCache *dynamicSymbolsLRUCache) Get(objID ObjID) (*DynamicSymbols, bool) {
objInfoIface, ok := soCache.lru.Get(objID)
if ok {
objInfo := objInfoIface.(*DynamicSymbols)
return objInfo, true
if objInfo, ok := objInfoIface.(*DynamicSymbols); ok {
return objInfo, true
}
}

return nil, false
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/capture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ func readWriteCaptureTest(t *testing.T, captureDir string, workingDir string) er
return err
}

statInfo := fi.Sys().(*syscall.Stat_t)
statInfo, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
t.Logf("type assertion failed: expected *syscall.Stat_t")
}
inode := statInfo.Ino

// Write "Hello World" into the file
Expand Down Expand Up @@ -172,7 +175,10 @@ func readWritevCaptureTest(t *testing.T, captureDir string, workingDir string) e
return err
}

statInfo := fi.Sys().(*syscall.Stat_t)
statInfo, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
t.Logf("type assertion failed: expected *syscall.Stat_t")
}
inode := statInfo.Ino

// Strings to write
Expand Down Expand Up @@ -251,7 +257,10 @@ func readWritePipe(t *testing.T, captureDir string, workingDir string) error {

finfo, err := pipe.Stat()
require.NoError(t, err)
statInfo := finfo.Sys().(*syscall.Stat_t)
statInfo, ok := finfo.Sys().(*syscall.Stat_t)
if !ok {
t.Logf("type assertion failed: expected *syscall.Stat_t")
}
inode := statInfo.Ino

// Write "Hello World!" to the named pipe
Expand Down
20 changes: 16 additions & 4 deletions tests/integration/event_filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2172,7 +2172,10 @@ func ExpectAtLeastOneForEach(t *testing.T, cmdEvents []cmdEvents, actual *eventB
}
switch v := expArg.Value.(type) {
case string:
actVal := actArg.Value.(string)
actVal, ok := actArg.Value.(string)
if !ok {
return fmt.Errorf("failed to cast arg's value")
}
if strings.Contains(v, "*") {
v = strings.ReplaceAll(v, "*", "")
if !strings.Contains(actVal, v) {
Expand Down Expand Up @@ -2305,7 +2308,10 @@ func ExpectAnyOfEvts(t *testing.T, cmdEvents []cmdEvents, actual *eventBuffer, u
}
switch v := expArg.Value.(type) {
case string:
actVal := actArg.Value.(string)
actVal, ok := actArg.Value.(string)
if !ok {
return fmt.Errorf("failed to cast arg's value")
}
if strings.Contains(v, "*") {
v = strings.ReplaceAll(v, "*", "")
if !strings.Contains(actVal, v) {
Expand Down Expand Up @@ -2424,7 +2430,10 @@ func ExpectAllEvtsEqualToOne(t *testing.T, cmdEvents []cmdEvents, actual *eventB
}
switch v := expArg.Value.(type) {
case string:
actVal := actArg.Value.(string)
actVal, ok := actArg.Value.(string)
if !ok {
return fmt.Errorf("failed to cast arg's value")
}
if strings.Contains(v, "*") {
v = strings.ReplaceAll(v, "*", "")
if !strings.Contains(actVal, v) {
Expand Down Expand Up @@ -2523,7 +2532,10 @@ func ExpectAllInOrderSequentially(t *testing.T, cmdEvents []cmdEvents, actual *e
}
switch v := expArg.Value.(type) {
case string:
actVal := actArg.Value.(string)
actVal, ok := actArg.Value.(string)
if !ok {
return fmt.Errorf("failed to cast arg's value")
}
if strings.Contains(v, "*") {
v = strings.ReplaceAll(v, "*", "")
if !strings.Contains(actVal, v) {
Expand Down
5 changes: 4 additions & 1 deletion types/trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,10 @@ func (arg *Argument) UnmarshalJSON(b []byte) error {
switch arg.Type {
case "const char*const*", "const char**":
if arg.Value != nil {
argValue := arg.Value.([]interface{})
argValue, ok := arg.Value.([]interface{})
if !ok {
return fmt.Errorf("const char*const*: type error")
}
arg.Value = jsonConvertToStringSlice(argValue)
} else {
arg.Value = []string{}
Expand Down

0 comments on commit 656eb97

Please sign in to comment.