diff --git a/internal/context/datapath.go b/internal/context/datapath.go index 292f9209..3b540813 100644 --- a/internal/context/datapath.go +++ b/internal/context/datapath.go @@ -444,13 +444,27 @@ func (dataPath *DataPath) ActivateTunnelAndPDR(smContext *SMContext, precedence logger.PduSessLog.Errorln("new QER failed") return } else { + var bitRateKbpsULMBR uint64 + var bitRateKbpsDLMBR uint64 + var bitRateConvertErr error + bitRateKbpsULMBR, bitRateConvertErr = util.BitRateTokbps(sessionRule.AuthSessAmbr.Uplink) + if bitRateConvertErr != nil { + logger.PduSessLog.Errorln("Cannot get the unit of ULMBR, please check the settings in web console") + return + } + bitRateKbpsDLMBR, bitRateConvertErr = util.BitRateTokbps(sessionRule.AuthSessAmbr.Downlink) + if bitRateConvertErr != nil { + logger.PduSessLog.Errorln("Cannot get the unit of DLMBR, please check the settings in web console") + return + } + newQER.QFI.QFI = sessionRule.DefQosQFI newQER.GateStatus = &pfcpType.GateStatus{ ULGate: pfcpType.GateOpen, DLGate: pfcpType.GateOpen, } newQER.MBR = &pfcpType.MBR{ - ULMBR: util.BitRateTokbps(sessionRule.AuthSessAmbr.Uplink), - DLMBR: util.BitRateTokbps(sessionRule.AuthSessAmbr.Downlink), + ULMBR: bitRateKbpsULMBR, + DLMBR: bitRateKbpsDLMBR, } ambrQER = newQER } @@ -849,19 +863,62 @@ func (p *DataPath) AddQoS(smContext *SMContext, qfi uint8, qos *models.QosData) DLGate: pfcpType.GateOpen, } if isGBRFlow(qos) { + var bitRateKbpsQoSGBRUL uint64 + var bitRateKbpsQoSGBRDL uint64 + var bitRateKbpsQoSMBRUL uint64 + var bitRateKbpsQoSMBRDL uint64 + var bitRateConvertErr error + bitRateKbpsQoSGBRUL, bitRateConvertErr = util.BitRateTokbps(qos.GbrUl) + if bitRateConvertErr != nil { + logger.PduSessLog.Panicln("Cannot get the unit of GBRUL, please check the settings in web console") + return + } + + bitRateKbpsQoSGBRDL, bitRateConvertErr = util.BitRateTokbps(qos.GbrDl) + if bitRateConvertErr != nil { + logger.PduSessLog.Panicln("Cannot get the unit of GBRDL, please check the settings in web console") + return + } + + bitRateKbpsQoSMBRUL, bitRateConvertErr = util.BitRateTokbps(qos.MaxbrUl) + if bitRateConvertErr != nil { + logger.PduSessLog.Panicln("Cannot get the unit of MBRUL, please check the settings in web console") + return + } + + bitRateKbpsQoSMBRDL, bitRateConvertErr = util.BitRateTokbps(qos.MaxbrDl) + if bitRateConvertErr != nil { + logger.PduSessLog.Panicln("Cannot get the unit of MBRDL, please check the settings in web console") + return + } + newQER.GBR = &pfcpType.GBR{ - ULGBR: util.BitRateTokbps(qos.GbrUl), - DLGBR: util.BitRateTokbps(qos.GbrDl), + ULGBR: bitRateKbpsQoSGBRUL, + DLGBR: bitRateKbpsQoSGBRDL, } newQER.MBR = &pfcpType.MBR{ - ULMBR: util.BitRateTokbps(qos.MaxbrUl), - DLMBR: util.BitRateTokbps(qos.MaxbrDl), + ULMBR: bitRateKbpsQoSMBRUL, + DLMBR: bitRateKbpsQoSMBRDL, } } else { + var bitRateKbpsSessionAmbrMBRUL uint64 + var bitRateKbpsSessionAmbrMBRDL uint64 + var bitRateConvertErr error + bitRateKbpsSessionAmbrMBRUL, bitRateConvertErr = util.BitRateTokbps(qos.MaxbrUl) + if bitRateConvertErr != nil { + logger.PduSessLog.Error("Cannot get the unit of MBRUL, please check the settings in web console") + return + } + bitRateKbpsSessionAmbrMBRDL, bitRateConvertErr = util.BitRateTokbps(qos.MaxbrDl) + + if bitRateConvertErr != nil { + logger.PduSessLog.Error("Cannot get the unit of MBRDL, please check the settings in web console") + return + } // Non-GBR flow should follows session-AMBR newQER.MBR = &pfcpType.MBR{ - ULMBR: util.BitRateTokbps(smContext.DnnConfiguration.SessionAmbr.Uplink), - DLMBR: util.BitRateTokbps(smContext.DnnConfiguration.SessionAmbr.Downlink), + ULMBR: bitRateKbpsSessionAmbrMBRUL, + DLMBR: bitRateKbpsSessionAmbrMBRDL, } } qer = newQER diff --git a/internal/sbi/processor/ulcl_procedure.go b/internal/sbi/processor/ulcl_procedure.go index 573cc54d..eecfc383 100644 --- a/internal/sbi/processor/ulcl_procedure.go +++ b/internal/sbi/processor/ulcl_procedure.go @@ -147,13 +147,13 @@ func EstablishULCL(smContext *context.SMContext) { DownLinkPDR := curDPNode.DownLinkTunnel.PDR UPLinkPDR.State = context.RULE_INITIAL - // new IPFilterRule with action:"permit" and diection:"out" + // new IPFilterRule with action:"permit" and direction:"out" FlowDespcription := flowdesc.NewIPFilterRule() - FlowDespcription.Dst = dest.DestinationIP + FlowDespcription.Src = dest.DestinationIP if dstPort, err := flowdesc.ParsePorts(dest.DestinationPort); err != nil { - FlowDespcription.DstPorts = dstPort + FlowDespcription.SrcPorts = dstPort } - FlowDespcription.Src = smContext.PDUAddress.To4().String() + FlowDespcription.Dst = smContext.PDUAddress.To4().String() FlowDespcriptionStr, err := flowdesc.Encode(FlowDespcription) if err != nil { @@ -305,13 +305,13 @@ func UpdateRANAndIUPFUpLink(smContext *context.SMContext) { if _, exist := bpMGR.UpdatedBranchingPoint[curDPNode.UPF]; exist { // add SDF Filter - // new IPFilterRule with action:"permit" and diection:"out" + // new IPFilterRule with action:"permit" and direction:"out" FlowDespcription := flowdesc.NewIPFilterRule() - FlowDespcription.Dst = dest.DestinationIP + FlowDespcription.Src = dest.DestinationIP if dstPort, err := flowdesc.ParsePorts(dest.DestinationPort); err != nil { - FlowDespcription.DstPorts = dstPort + FlowDespcription.SrcPorts = dstPort } - FlowDespcription.Src = smContext.PDUAddress.To4().String() + FlowDespcription.Dst = smContext.PDUAddress.To4().String() FlowDespcriptionStr, err := flowdesc.Encode(FlowDespcription) if err != nil { @@ -328,6 +328,7 @@ func UpdateRANAndIUPFUpLink(smContext *context.SMContext) { FlowDescription: []byte(FlowDespcriptionStr), } } + UPLinkPDR.Precedence = 30 pfcpState := &PFCPState{ upf: curDPNode.UPF, diff --git a/internal/util/qos_convert.go b/internal/util/qos_convert.go index 602814a1..c9adcd60 100644 --- a/internal/util/qos_convert.go +++ b/internal/util/qos_convert.go @@ -1,24 +1,29 @@ package util import ( + "errors" "strconv" "strings" "github.com/free5gc/ngap/ngapType" ) -func BitRateTokbps(bitrate string) uint64 { +func BitRateTokbps(bitrate string) (uint64, error) { s := strings.Split(bitrate, " ") var kbps uint64 var digit int if n, err := strconv.Atoi(s[0]); err != nil { - return 0 + return 0, nil } else { digit = n } + if len(s) == 1 { + return 0, errors.New("cannot get the unit of ULMBR/DLMBR/ULGBR/DLGBR, please check the settings in web console") + } + switch s[1] { case "bps": kbps = uint64(digit / 1000) @@ -31,7 +36,7 @@ func BitRateTokbps(bitrate string) uint64 { case "Tbps": kbps = uint64(digit * 1000000000) } - return kbps + return kbps, nil } func BitRateTombps(bitrate string) uint16 { diff --git a/internal/util/qos_convert_test.go b/internal/util/qos_convert_test.go new file mode 100644 index 00000000..b8aa4af0 --- /dev/null +++ b/internal/util/qos_convert_test.go @@ -0,0 +1,104 @@ +package util_test + +import ( + "testing" + + "github.com/free5gc/smf/internal/util" +) + +func TestBitRateToKbpsWithValidBpsBitRateShouldReturnValidKbpsBitRate(t *testing.T) { + var bitrate string = "1000 bps" + var correctBitRateKbps uint64 = 1 + + bitrateKbps, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should be nil since act should work correctly.") + if err != nil { + t.Errorf("Error: err should be nil but it returns %s", err) + } + t.Log("Check: convert should act correctly.") + if bitrateKbps != correctBitRateKbps { + t.Errorf("Error: bitrate convert failed. Expect: %d. Actually: %d", correctBitRateKbps, bitrateKbps) + } + t.Log("Passed.") +} + +func TestBitRateToKbpsWithValidKbpsBitRateShouldReturnValidKbpsBitRate(t *testing.T) { + var bitrate string = "1000 Kbps" + var correctBitRateKbps uint64 = 1000 + + bitrateKbps, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should be nil since act should work correctly.") + if err != nil { + t.Errorf("Error: err should be nil but it returns %s", err) + } + t.Log("Check: convert should act correctly.") + if bitrateKbps != correctBitRateKbps { + t.Errorf("Error: bitrate convert failed. Expect: %d. Actually: %d", correctBitRateKbps, bitrateKbps) + } + t.Log("Passed.") +} + +func TestBitRateToKbpsWithValidMbpsBitRateShouldReturnValidKbpsBitRate(t *testing.T) { + var bitrate string = "1000 Mbps" + var correctBitRateKbps uint64 = 1000000 + + bitrateKbps, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should be nil since act should work correctly.") + if err != nil { + t.Errorf("Error: err should be nil but it returns %s", err) + } + t.Log("Check: convert should act correctly.") + if bitrateKbps != correctBitRateKbps { + t.Errorf("Error: bitrate convert failed. Expect: %d. Actually: %d", correctBitRateKbps, bitrateKbps) + } + t.Log("Passed.") +} + +func TestBitRateToKbpsWithValidGbpsBitRateShouldReturnValidKbpsBitRate(t *testing.T) { + var bitrate string = "1000 Gbps" + var correctBitRateKbps uint64 = 1000000000 + + bitrateKbps, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should be nil since act should work correctly.") + if err != nil { + t.Errorf("Error: err should be nil but it returns %s", err) + } + t.Log("Check: convert should act correctly.") + if bitrateKbps != correctBitRateKbps { + t.Errorf("Error: bitrate convert failed. Expect: %d. Actually: %d", correctBitRateKbps, bitrateKbps) + } + t.Log("Passed.") +} + +func TestBitRateToKbpsWithValidTbpsBitRateShouldReturnValidKbpsBitRate(t *testing.T) { + var bitrate string = "1000 Tbps" + var correctBitRateKbps uint64 = 1000000000000 + + bitrateKbps, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should be nil since act should work correctly.") + if err != nil { + t.Errorf("Error: err should be nil but it returns %s", err) + } + t.Log("Check: convert should act correctly.") + if bitrateKbps != correctBitRateKbps { + t.Errorf("Error: bitrate convert failed. Expect: %d. Actually: %d", correctBitRateKbps, bitrateKbps) + } + t.Log("Passed.") +} + +func TestBitRateToKbpsWithInvalidBitRateShouldReturnError(t *testing.T) { + var bitrate string = "1000" // The unit is absent. It should raise error for `BitRateToKbps`. + + _, err := util.BitRateTokbps(bitrate) + + t.Log("Check: err should not be nil.") + if err == nil { + t.Error("Error: err should not be nil.") + } + t.Log("Passed.") +} diff --git a/pkg/factory/config.go b/pkg/factory/config.go index 6c7e9fdd..b0ee3267 100644 --- a/pkg/factory/config.go +++ b/pkg/factory/config.go @@ -132,7 +132,7 @@ func (c *Configuration) validate() (bool, error) { } for _, snssaiInfo := range c.SNssaiInfo { - if result, err := snssaiInfo.validate(); err != nil { + if result, err := snssaiInfo.Validate(); err != nil { return result, err } } @@ -166,17 +166,19 @@ type SnssaiInfoItem struct { DnnInfos []*SnssaiDnnInfoItem `yaml:"dnnInfos" valid:"required"` } -func (s *SnssaiInfoItem) validate() (bool, error) { +func (s *SnssaiInfoItem) Validate() (bool, error) { if snssai := s.SNssai; snssai != nil { if result := (snssai.Sst >= 0 && snssai.Sst <= 255); !result { err := errors.New("Invalid sNssai.Sst: " + strconv.Itoa(int(snssai.Sst)) + ", should be in range 0~255.") return false, err } - if result := govalidator.StringMatches(snssai.Sd, "^[0-9A-Fa-f]{6}$"); !result { - err := errors.New("Invalid sNssai.Sd: " + snssai.Sd + - ", should be 3 bytes hex string and in range 000000~FFFFFF.") - return false, err + if snssai.Sd != "" { + if result := govalidator.StringMatches(snssai.Sd, "^[0-9A-Fa-f]{6}$"); !result { + err := errors.New("Invalid sNssai.Sd: " + snssai.Sd + + ", should be 3 bytes hex string and in range 000000~FFFFFF.") + return false, err + } } } @@ -489,7 +491,7 @@ func (u *UPNode) validate() (bool, error) { }) for _, snssaiInfo := range u.SNssaiInfos { - if result, err := snssaiInfo.validate(); err != nil { + if result, err := snssaiInfo.Validate(); err != nil { return result, err } } @@ -547,17 +549,19 @@ type SnssaiUpfInfoItem struct { DnnUpfInfoList []*DnnUpfInfoItem `json:"dnnUpfInfoList" yaml:"dnnUpfInfoList" valid:"required"` } -func (s *SnssaiUpfInfoItem) validate() (bool, error) { +func (s *SnssaiUpfInfoItem) Validate() (bool, error) { if s.SNssai != nil { if result := (s.SNssai.Sst >= 0 && s.SNssai.Sst <= 255); !result { err := errors.New("Invalid sNssai.Sst: " + strconv.Itoa(int(s.SNssai.Sst)) + ", should be in range 0~255.") return false, err } - if result := govalidator.StringMatches(s.SNssai.Sd, "^[0-9A-Fa-f]{6}$"); !result { - err := errors.New("Invalid sNssai.Sd: " + s.SNssai.Sd + - ", should be 3 bytes hex string and in range 000000~FFFFFF.") - return false, err + if s.SNssai.Sd != "" { + if result := govalidator.StringMatches(s.SNssai.Sd, "^[0-9A-Fa-f]{6}$"); !result { + err := errors.New("Invalid sNssai.Sd: " + s.SNssai.Sd + + ", should be 3 bytes hex string and in range 000000~FFFFFF.") + return false, err + } } } diff --git a/pkg/factory/config_test.go b/pkg/factory/config_test.go new file mode 100644 index 00000000..67b56b95 --- /dev/null +++ b/pkg/factory/config_test.go @@ -0,0 +1,106 @@ +package factory_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/free5gc/openapi/models" + "github.com/free5gc/smf/pkg/factory" +) + +func TestSnssaiInfoItem(t *testing.T) { + testcase := []struct { + Name string + Snssai *models.Snssai + DnnInfos []*factory.SnssaiDnnInfoItem + }{ + { + Name: "Default", + Snssai: &models.Snssai{ + Sst: int32(1), + Sd: "010203", + }, + DnnInfos: []*factory.SnssaiDnnInfoItem{ + { + Dnn: "internet", + DNS: &factory.DNS{ + IPv4Addr: "8.8.8.8", + }, + }, + }, + }, + { + Name: "Empty SD", + Snssai: &models.Snssai{ + Sst: int32(1), + }, + DnnInfos: []*factory.SnssaiDnnInfoItem{ + { + Dnn: "internet2", + DNS: &factory.DNS{ + IPv4Addr: "1.1.1.1", + }, + }, + }, + }, + } + + for _, tc := range testcase { + t.Run(tc.Name, func(t *testing.T) { + snssaiInfoItem := factory.SnssaiInfoItem{ + SNssai: tc.Snssai, + DnnInfos: tc.DnnInfos, + } + + ok, err := snssaiInfoItem.Validate() + require.True(t, ok) + require.Nil(t, err) + }) + } +} + +func TestSnssaiUpfInfoItem(t *testing.T) { + testcase := []struct { + Name string + Snssai *models.Snssai + DnnInfos []*factory.DnnUpfInfoItem + }{ + { + Name: "Default", + Snssai: &models.Snssai{ + Sst: int32(1), + Sd: "010203", + }, + DnnInfos: []*factory.DnnUpfInfoItem{ + { + Dnn: "internet", + }, + }, + }, + { + Name: "Empty SD", + Snssai: &models.Snssai{ + Sst: int32(1), + }, + DnnInfos: []*factory.DnnUpfInfoItem{ + { + Dnn: "internet2", + }, + }, + }, + } + + for _, tc := range testcase { + t.Run(tc.Name, func(t *testing.T) { + snssaiInfoItem := factory.SnssaiUpfInfoItem{ + SNssai: tc.Snssai, + DnnUpfInfoList: tc.DnnInfos, + } + + ok, err := snssaiInfoItem.Validate() + require.True(t, ok) + require.Nil(t, err) + }) + } +}