From 9f59db3a9c1e4437dcd94e3aba02f1ee89da7ad1 Mon Sep 17 00:00:00 2001 From: xiaoweihao Date: Mon, 6 Jan 2025 11:05:43 +0800 Subject: [PATCH 1/2] test push --- test.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 test.txt diff --git a/test.txt b/test.txt new file mode 100644 index 00000000..a9b9cd5f --- /dev/null +++ b/test.txt @@ -0,0 +1 @@ +push From afcc79e2a3dce66e26c5ad0a40228fb2f9ba169a Mon Sep 17 00:00:00 2001 From: xiaoweihao Date: Mon, 6 Jan 2025 11:11:49 +0800 Subject: [PATCH 2/2] add Mongodb protocol ParseStream() --- agent/protocol/decoder.go | 56 ++++++ agent/protocol/mongodb/mongodb.go | 309 ++++++++++++++++++++++++++++++ agent/protocol/mongodb/types.go | 63 ++++++ bpf/protocol_inference.h | 55 ++++++ 4 files changed, 483 insertions(+) create mode 100644 agent/protocol/mongodb/mongodb.go create mode 100644 agent/protocol/mongodb/types.go diff --git a/agent/protocol/decoder.go b/agent/protocol/decoder.go index 5dc07801..181b41a1 100644 --- a/agent/protocol/decoder.go +++ b/agent/protocol/decoder.go @@ -48,6 +48,10 @@ func (d *BinaryDecoder) ReadBytes() int { var ResourceNotAvailble = NewResourceNotAvailbleError("Insufficient number of bytes.") var NotFound = NewNotFoundError("Could not find sentinel character") +func (d *BinaryDecoder) RemovePrefix(size int32) { + d.str = d.str[size:] +} + /* Extract until encounter the input string. @@ -83,3 +87,55 @@ func (d *BinaryDecoder) ExtractByte() (byte, error) { d.readBytes++ return x, nil } + +func ExtractBEInt[TIntType int32 | uint32 | uint8](d *BinaryDecoder) (TIntType, error) { + typeSize := int(reflect.TypeOf(TIntType(0)).Size()) + if len(d.str) < typeSize { + return 0, ResourceNotAvailble + } + var x TIntType = 0 + for i := 0; i < typeSize; i++ { + x = TIntType(d.str[i]) | (x << 8) + } + d.str = d.str[typeSize:] + d.readBytes += typeSize + return x, nil +} + +func ExtractLEInt[TIntType int32 | uint32 | uint8](d *BinaryDecoder) (TIntType, error) { + typeSize := int(reflect.TypeOf(TIntType(0)).Size()) + if len(d.str) < 4 { + return 0, ResourceNotAvailble + } + var x TIntType = 0 + for i := 0; i < typeSize; i++ { + x = TIntType(d.str[typeSize-1-i]) | (x << 8) + } + d.str = d.str[typeSize:] + d.readBytes += typeSize + return x, nil +} + +func BEndianBytesToInt[TIntType int32 | uint32 | uint8](d *BinaryDecoder) TIntType { + typeSize := int(reflect.TypeOf(TIntType(0)).Size()) + if len(d.str) < typeSize { + return 0 + } + var x TIntType = 0 + for i := 0; i < typeSize; i++ { + x = TIntType(d.str[i]) | (x << 8) + } + return x +} + +func LEndianBytesToInt[TIntType int32 | uint32 | uint8](d *BinaryDecoder) TIntType { + typeSize := int(reflect.TypeOf(TIntType(0)).Size()) + if len(d.str) < 4 { + return 0 + } + var x TIntType = 0 + for i := 0; i < typeSize; i++ { + x = TIntType(d.str[typeSize-1-i]) | (x << 8) + } + return x +} diff --git a/agent/protocol/mongodb/mongodb.go b/agent/protocol/mongodb/mongodb.go new file mode 100644 index 00000000..2aa7a7a0 --- /dev/null +++ b/agent/protocol/mongodb/mongodb.go @@ -0,0 +1,309 @@ +package mongodb + +import ( + "encoding/json" + "fmt" + "kyanos/agent/buffer" + . "kyanos/agent/protocol" + "kyanos/bpf" + + "go.mongodb.org/mongo-driver/bson" +) + +var _ ProtocolStreamParser = &MongoDBStreamParser{} +var _ ParsedMessage = &MongoDBFrame{} + +type MongoDBFrame struct { + FrameBase + // Message Header Fields + // Length of the mongodb header and the wire protocol data. + length int32 + requestId int32 + responseTo int32 + opCode int32 + + // OP_MSG Fields + // Relevant flag bits + checksumPresent bool + moreToCome bool + exhaustAllowed bool + sections []Section + opMsgType string + frame_body string + checksum uint32 + isHandshake bool + consumed bool + + cmd int32 + isReq bool +} + +// FormatToSummaryString implements protocol.ParsedMessage. +func (m *MongoDBFrame) FormatToSummaryString() string { + return fmt.Sprintf("MongoDB base=[%s]", m.FrameBase.String()) +} + +// FormatToString implements protocol.ParsedMessage. +func (m *MongoDBFrame) FormatToString() string { + return fmt.Sprintf("MongoDB base=[%s]", m.FrameBase.String()) +} + +// IsReq implements protocol.ParsedMessage. +func (m *MongoDBFrame) IsReq() bool { + return false +} + +type MongoDBStreamParser struct { +} + +func (m *MongoDBStreamParser) ParseStream(streamBuffer *buffer.StreamBuffer, messageType MessageType) ParseResult { + result := ParseResult{} + if messageType != Request && messageType != Response { + result.ParseState = Invalid + return result + } + + head := streamBuffer.Head().Buffer() + decoder := NewBinaryDecoder(head) + + if uint8(len(head)) < kHeaderLength { + result.ParseState = NeedsMoreData + return result + } + + // Get the length of the packet. This length contains the size of the field containing the + // message's length itself. + length, err := ExtractLEInt[int32](decoder) + if err != nil { + result.ParseState = Invalid + return result + } + if int32(len(head)) < length-int32(kMessageLengthSize) { + result.ParseState = NeedsMoreData + return result + } + + // Get the Request ID. + requestId, err := ExtractLEInt[int32](decoder) + if err != nil { + result.ParseState = Invalid + return result + } + // Get the Response To. + respondTo, err := ExtractLEInt[int32](decoder) + if err != nil { + result.ParseState = Invalid + return result + } + // Get the message's op code (type). + opCode, err := ExtractLEInt[int32](decoder) + if err != nil { + result.ParseState = Invalid + return result + } + if !(opCode == kOPMsg || opCode == kOPReply || + opCode == kOPUpdate || opCode == kOPInsert || + opCode == kReserved || opCode == kOPQuery || + opCode == kOPGetMore || opCode == kOPDelete || + opCode == kOPKillCursors || opCode == kOPCompressed) { + result.ParseState = Invalid + return result + } + + // Parser will ignore Op Codes that have been deprecated/removed from version 5.0 onwards as well + // as kOPCompressed and kReserved which are not supported by the parser yet. + if opCode != kOPMsg { + decoder.RemovePrefix(length) // int32 + result.ParseState = Ignore + return result + } + + mongoDBFrame := &MongoDBFrame{ + length: length, + requestId: requestId, + responseTo: respondTo, + opCode: opCode, + } + + result.ParseState = ProcessPayload(decoder, mongoDBFrame) + result.ReadBytes = decoder.ReadBytes() + result.ParsedMessages = []ParsedMessage{mongoDBFrame} + return result +} + +func (m *MongoDBStreamParser) FindBoundary(streamBuffer *buffer.StreamBuffer, messageType MessageType, startPos int) int { + // 待实现 + return 0 +} + +func (m *MongoDBStreamParser) Match(reqStream *[]ParsedMessage, respStream *[]ParsedMessage) []Record { + // 待实现 + records := make([]Record, 0) + return records +} + +func init() { + ParsersMap[bpf.AgentTrafficProtocolTKProtocolMongo] = func() ProtocolStreamParser { + return &MongoDBStreamParser{} + } +} + +func ProcessPayload(decoder *BinaryDecoder, mongoDBFrame *MongoDBFrame) ParseState { + switch mongoDBFrame.opCode { + case kOPMsg: + return ProcessOpMsg(decoder, mongoDBFrame) + case kOPCompressed: + return Ignore + case kReserved: + return Ignore + default: + return Invalid + } +} + +func ProcessOpMsg(decoder *BinaryDecoder, mongoDBFrame *MongoDBFrame) ParseState { + flagBits, err := ExtractLEInt[uint32](decoder) + if err != nil { + return Invalid + } + + // Find relevant flag bit information and ensure remaining bits are not set. + // Bits 0-15 are required and bits 16-31 are optional. + mongoDBFrame.checksumPresent = (flagBits & kChecksumBitmask) == kChecksumBitmask + mongoDBFrame.moreToCome = (flagBits & kMoreToComeBitmask) == kMoreToComeBitmask + mongoDBFrame.exhaustAllowed = (flagBits & kExhaustAllowedBitmask) == kExhaustAllowedBitmask + if flagBits&kRequiredUnsetBitmask != 0 { + return Invalid + } + + // Determine the number of checksum bytes in the buffer. + var checksumBytes int32 + if mongoDBFrame.checksumPresent { + checksumBytes = 4 + } else { + checksumBytes = 0 + } + + // Get the section(s) data from the buffer. + allSectionsLength := mongoDBFrame.length - int32(kHeaderAndFlagSize) - int32(checksumBytes) + for allSectionsLength > 0 { + var section Section + section.kind, err = ExtractLEInt[uint8](decoder) + if err != nil { + return Invalid + } + // Length of the current section still remaining in the buffer. + var remainingSectionLength int32 = 0 + + if section.kind == kSectionKindZero { + // Check the length but don't extract it since the later logic requires the buffer to retain it. + section.length = LEndianBytesToInt[int32](decoder) + if section.length < int32(kSectionLengthSize) { + return Invalid + } + remainingSectionLength = section.length + } else if section.kind == kSectionKindOne { + section.length, err = ExtractLEInt[int32](decoder) //pixie uint32? + if err != nil { + return Invalid + } + if section.length < int32(kSectionLengthSize) { + return Invalid + } + // Get the sequence identifier (command argument). + seqIdentifier, err := decoder.ExtractStringUntil("\\0") //pixie '\0'? + if err != nil { + return Invalid + } + // Make sure the sequence identifier is a valid OP_MSG kind 1 command argument. + if seqIdentifier != "documents" && seqIdentifier != "updates" && seqIdentifier != "deletes" { + return Invalid + } + remainingSectionLength = section.length - int32(kSectionLengthSize) - int32(len(seqIdentifier)) - int32(kSectionKindSize) + } else { + return Invalid + } + + // Extract the document(s) from the section and convert it from type BSON to a JSON string. + for remainingSectionLength > 0 { + // We can't extract the length bytes since bson_new_from_data() expects those bytes in + // the data as well as the expected length in another parameter. + documentLength := LEndianBytesToInt[int32](decoder) + if documentLength > kMaxBSONObjSize { + return Invalid + } + sectionBody, err := decoder.ExtractString(int(documentLength)) + if err != nil { + return Invalid + } + + // Check if section_body contains an empty document. + if len(sectionBody) == int(kSectionLengthSize) { + section.documents = append(section.documents, "") + remainingSectionLength -= documentLength + continue + } + + // Convert the BSON document to a JSON string. + var bsonDoc bson.M + if err := bson.Unmarshal([]byte(sectionBody), &bsonDoc); err != nil { + return Invalid + } + jsonDoc, err := bson.MarshalExtJSON(bsonDoc, true, false) + if err != nil { + return Invalid + } + + var doc map[string]interface{} + if err := json.Unmarshal(jsonDoc, &doc); err != nil { + return Invalid + } + + // Find the type of command argument from the kind 0 section. + if section.kind == kSectionKindZero { + var opMsgType string + for key := range doc { + opMsgType = key + break + } + switch opMsgType { + case kInsert, kDelete, kUpdate, kFind, kCursor: + mongoDBFrame.opMsgType = opMsgType + case kHello, kIsMaster, kIsMasterAlternate: + // The frame is a handshaking message. + mongoDBFrame.opMsgType = opMsgType + mongoDBFrame.isHandshake = true + default: + // The frame is a response message, find the "ok" key and its value. + if okValue, ok := doc["ok"]; ok { + switch v := okValue.(type) { + case map[string]interface{}: + for key, value := range v { + mongoDBFrame.opMsgType = fmt.Sprintf("ok: {%s: %v}", key, value) + break + } + case float64: + mongoDBFrame.opMsgType = fmt.Sprintf("ok: %d", int(v)) + } + } else { + return Invalid + } + + } + } + + section.documents = append(section.documents, string(jsonDoc)) + remainingSectionLength -= documentLength + } + mongoDBFrame.sections = append(mongoDBFrame.sections, section) + allSectionsLength -= (section.length + int32(kSectionKindSize)) + } + // Get the checksum data, if necessary. + if mongoDBFrame.checksumPresent { + mongoDBFrame.checksum, err = ExtractLEInt[uint32](decoder) + if err != nil { + return Invalid + } + } + return Success +} diff --git a/agent/protocol/mongodb/types.go b/agent/protocol/mongodb/types.go new file mode 100644 index 00000000..116c579a --- /dev/null +++ b/agent/protocol/mongodb/types.go @@ -0,0 +1,63 @@ +package mongodb + +type opType int32 + +const ( + kHeaderLength uint8 = 16 + kMessageLengthSize uint8 = 4 + kSectionLengthSize uint8 = 4 + kHeaderAndFlagSize uint8 = 20 + + kChecksumBitmask uint32 = 1 + kMoreToComeBitmask uint32 = 1 << 1 + kExhaustAllowedBitmask uint32 = 1 << 16 + kRequiredUnsetBitmask uint32 = 0xFFFC +) + +const ( + kOPReply int32 = 1 + kOPUpdate int32 = 2001 + kOPInsert int32 = 2002 + kReserved int32 = 2003 + kOPQuery int32 = 2004 + kOPGetMore int32 = 2005 + kOPDelete int32 = 2006 + kOPKillCursors int32 = 2007 + kOPCompressed int32 = 2012 + kOPMsg int32 = 2013 +) + +type Section struct { + kind uint8 + length int32 + documents []string +} + +const ( + kSectionKindSize uint8 = 1 +) + +const ( + kSectionKindZero uint8 = iota + kSectionKindOne +) + +// Types of OP_MSG requests/responses +const ( + kInsert = "insert" + kDelete = "delete" + kUpdate = "update" + kFind = "find" + kCursor = "cursor" + kOk = "ok" +) + +// Types of top level keys for handshaking messages +const ( + kHello = "hello" + kIsMaster = "isMaster" + kIsMasterAlternate = "ismaster" +) + +// Max BSON object size in bytes +const kMaxBSONObjSize = 16000000 diff --git a/bpf/protocol_inference.h b/bpf/protocol_inference.h index 0912c333..9dc844ba 100644 --- a/bpf/protocol_inference.h +++ b/bpf/protocol_inference.h @@ -127,12 +127,67 @@ static __always_inline enum message_type_t is_http_protocol(const char *old_buf, return kUnknown; } + +// MongoDB protocol +static __inline enum message_type_t is_mongo_protocol(const char* buf, size_t count) { + // Reference: + // https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#std-label-wp-request-opcodes. + // Note: Response side inference for Mongo is not robust, and is not attempted to avoid + // confusion with other protocols, especially MySQL. + static const int32_t kOPUpdate = 2001; + static const int32_t kOPInsert = 2002; + static const int32_t kReserved = 2003; + static const int32_t kOPQuery = 2004; + static const int32_t kOPGetMore = 2005; + static const int32_t kOPDelete = 2006; + static const int32_t kOPKillCursors = 2007; + static const int32_t kOPCompressed = 2012; + static const int32_t kOPMsg = 2013; + + static const int32_t kMongoHeaderLength = 16; + + if (count < kMongoHeaderLength) { + return kUnknown; + } + + int32_t* buf4 = (int32_t*)buf; + int32_t message_length = buf4[0]; + + if (message_length < kMongoHeaderLength) { + return kUnknown; + } + + int32_t request_id = buf4[1]; + + if (request_id < 0) { + return kUnknown; + } + + int32_t response_to = buf4[2]; + int32_t opcode = buf4[3]; + + if (opcode == kOPUpdate || opcode == kOPInsert || opcode == kReserved || opcode == kOPQuery || + opcode == kOPGetMore || opcode == kOPDelete || opcode == kOPKillCursors || + opcode == kOPCompressed || opcode == kOPMsg) { + if (response_to == 0) { + return kRequest; + } + } + + return kUnknown; +} + + + static __always_inline struct protocol_message_t infer_protocol(const char *buf, size_t count, struct conn_info_t *conn_info) { struct protocol_message_t protocol_message; protocol_message.protocol = kProtocolUnknown; protocol_message.type = kUnknown; if ((protocol_message.type = is_http_protocol(buf, count)) != kUnknown) { protocol_message.protocol = kProtocolHTTP; + // MongoDB protocol + } else if ((protocol_message.type = is_mongo_protocol(buf, count)) != kUnknown) { + protocol_message.protocol = kProtocolMongo; } else if ((protocol_message.type = is_mysql_protocol(buf, count, conn_info)) != kUnknown) { protocol_message.protocol = kProtocolMySQL; } else if (is_redis_protocol(buf, count)) {