diff --git a/backend/onetime/seed-data/main.go b/backend/onetime/seed-data/main.go index 8fa150a7..f93eef40 100644 --- a/backend/onetime/seed-data/main.go +++ b/backend/onetime/seed-data/main.go @@ -3,11 +3,14 @@ package main import ( + "context" + "os" + "github.com/joho/godotenv" statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" - "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" + dbhandler "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" + "go.mongodb.org/mongo-driver/mongo/options" "gopkg.in/yaml.v3" - "os" ) type Station string @@ -26,8 +29,11 @@ func main() { if err := godotenv.Load(".env"); err != nil { panic(err) } - db.Open() - defer db.C() + db, err := dbhandler.Open(context.TODO(), options.Client().ApplyURI(os.Getenv("MONGODB_URI"))) + if err != nil { + return + } + defer db.Close() data := &Seed{} b, _ := os.ReadFile("./data/nt-tokyo.yaml") if err := yaml.Unmarshal(b, data); err != nil { diff --git a/backend/state-manager/cmd/main.go b/backend/state-manager/cmd/main.go index b77e3700..b26069ed 100644 --- a/backend/state-manager/cmd/main.go +++ b/backend/state-manager/cmd/main.go @@ -10,13 +10,16 @@ import ( "os/signal" "time" + mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httplog/v2" "github.com/joho/godotenv" "github.com/ueckoken/plarail2023/backend/spec/state/v1/statev1connect" connectHandler "github.com/ueckoken/plarail2023/backend/state-manager/pkg/connect" + "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" "github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler" + "go.mongodb.org/mongo-driver/mongo/options" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" "golang.org/x/sync/errgroup" @@ -84,18 +87,44 @@ func main() { } baseCtx, cancel := context.WithCancel(context.Background()) defer cancel() - signalCtx, stop := signal.NotifyContext(baseCtx, os.Interrupt) + ctx, stop := signal.NotifyContext(baseCtx, os.Interrupt) defer stop() go func() { - <-signalCtx.Done() - slog.Default().Info("signal received") + <-ctx.Done() + slog.Default().Info("signal received or canceled") }() - eg, ctx := errgroup.WithContext(signalCtx) + eg, ctx := errgroup.WithContext(ctx) + + DBOpts := options.Client().ApplyURI(os.Getenv("MONGODB_URI")) + DBHandler, err := db.Open(ctx, DBOpts) + if err != nil { + slog.Default().Error("database connection failed", slog.Any("err", err)) + cancel() + return + } + mqttClientOpts := mqtt.NewClientOptions() + mqttClientOpts.AddBroker(os.Getenv("MQTT_BROKER_ADDR")) + mqttClientOpts.Username = os.Getenv("MQTT_USERNAME") + mqttClientOpts.Password = os.Getenv("MQTT_PASSWORD") + mqttClientOpts.ClientID = os.Getenv("MQTT_CLIENT_ID") + + mqttHandler, err := mqtt_handler.NewHandler(mqttClientOpts, DBHandler) + if err != nil { + slog.Default().Error("mqtt create client or handler failed,", slog.Any("err", err)) + cancel() + return + } + eg.Go(func() error { + slog.Default().Info("start mqtt handler") + return mqttHandler.Start(ctx) + }) r := chi.NewRouter() // r.Use(middleware.Recoverer) r.Use(middleware.Heartbeat("/debug/ping")) + r.Mount("/debug", middleware.Profiler()) + r.Handle(statev1connect.NewStateManagerServiceHandler(&connectHandler.StateManagerServer{DBHandler: DBHandler, MqttHandler: mqttHandler})) r.Use(httplog.RequestLogger( httplog.NewLogger( "http_server", @@ -142,7 +171,7 @@ func main() { //go operation.Handler() eg.Go(func() error { slog.Default().Info("start mqtt handler") - err := mqtt_handler.StartHandler(ctx) + err := mqttHandler.Start(ctx) return fmt.Errorf("mqtt handler error: %w", err) }) diff --git a/backend/state-manager/pkg/connect/connect_handler.go b/backend/state-manager/pkg/connect/connect_handler.go index ad6e1f55..cfc00599 100644 --- a/backend/state-manager/pkg/connect/connect_handler.go +++ b/backend/state-manager/pkg/connect/connect_handler.go @@ -8,11 +8,15 @@ import ( "connectrpc.com/connect" statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" - db "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" + "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" "github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler" ) -type StateManagerServer struct{} + +type StateManagerServer struct { + DBHandler *db.DBHandler + MqttHandler *mqtt_handler.Handler +} /* Block @@ -23,9 +27,7 @@ func (s *StateManagerServer) GetBlockStates( ctx context.Context, req *connect.Request[statev1.GetBlockStatesRequest], ) (*connect.Response[statev1.GetBlockStatesResponse], error) { - defer db.C() - db.Open() - blockStates, err := db.GetBlocks() + blockStates, err := s.DBHandler.GetBlocks() if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -55,9 +57,7 @@ func (s *StateManagerServer) UpdateBlockState( ctx context.Context, req *connect.Request[statev1.UpdateBlockStateRequest], ) (*connect.Response[statev1.UpdateBlockStateResponse], error) { - defer db.C() - db.Open() - err := db.UpdateBlock(req.Msg.State) + err := s.DBHandler.UpdateBlock(req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -77,9 +77,7 @@ func (s *StateManagerServer) UpdatePointState( ctx context.Context, req *connect.Request[statev1.UpdatePointStateRequest], ) (*connect.Response[statev1.UpdatePointStateResponse], error) { - defer db.C() - db.Open() - err := db.UpdatePoint(req.Msg.State) + err := s.DBHandler.UpdatePoint(req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -88,7 +86,7 @@ func (s *StateManagerServer) UpdatePointState( slog.Default().Error("db error", err) return nil, err } - mqtt_handler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String()) + s.MqttHandler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String()) return connect.NewResponse(&statev1.UpdatePointStateResponse{}), nil } @@ -112,9 +110,7 @@ func (s *StateManagerServer) UpdateStopState( ctx context.Context, req *connect.Request[statev1.UpdateStopStateRequest], ) (*connect.Response[statev1.UpdateStopStateResponse], error) { - db.Open() - defer db.C() - err := db.UpdateStop(req.Msg.State) + err := s.DBHandler.UpdateStop(req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -123,7 +119,7 @@ func (s *StateManagerServer) UpdateStopState( slog.Default().Error("db connection error", err) return nil, err } - mqtt_handler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String()) + s.MqttHandler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String()) return connect.NewResponse(&statev1.UpdateStopStateResponse{}), nil } diff --git a/backend/state-manager/pkg/db/db.go b/backend/state-manager/pkg/db/db.go index 03995226..80903aab 100644 --- a/backend/state-manager/pkg/db/db.go +++ b/backend/state-manager/pkg/db/db.go @@ -6,9 +6,9 @@ package db import ( "context" + "fmt" "log" "log/slog" - "os" statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" "go.mongodb.org/mongo-driver/bson" @@ -17,30 +17,33 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// mongodb connection -var client *mongo.Client +type DBHandler struct { + stateManagerDB *mongo.Database +} -func Open() { +func Open(ctx context.Context, opts *options.ClientOptions) (*DBHandler, error) { var err error slog.Default().Debug("Connecting to MongoDB...") - uri := os.Getenv("MONGODB_URI") - if uri == "" { - log.Fatal("No MONGODB_URI set") - } - // TODO: Open関数がctxを受けるようにして、そのctxの子contextをDBのコネクションに使う - client, err = mongo.Connect(context.TODO(), options.Client().ApplyURI(uri)) + client, err := mongo.Connect(ctx, opts) if err != nil { - // TODO: return err, do not panic! slog.Default().Error("database connection failed", slog.Any("err", err)) - panic(err) + return nil, err + } + + if err := client.Ping(ctx, nil); err != nil { + slog.Error("DB ping failed", slog.Any("err", err)) + return nil, fmt.Errorf("DB Ping failed `%w`", err) } slog.Default().Debug("connected to DB") + return &DBHandler{ + stateManagerDB: client.Database("state-manager"), + }, nil } -func C() { +func (db *DBHandler) Close() { slog.Default().Debug("Closing connection to DB...") // TODO: contextを受けて、その子contextをDBクライアントに渡す - if err := client.Disconnect(context.TODO()); err != nil { + if err := db.stateManagerDB.Client().Disconnect(context.TODO()); err != nil { slog.Default().Error("DB Connection Closing failed") log.Println(err) } @@ -51,8 +54,8 @@ func C() { Point */ -func UpdatePoint(PointAndState *statev1.PointAndState) error { - collection := client.Database("state-manager").Collection("points") +func (db *DBHandler) UpdatePoint(PointAndState *statev1.PointAndState) error { + collection := db.stateManagerDB.Collection("points") _, err := collection.UpdateOne( context.Background(), bson.M{"id": PointAndState.Id}, @@ -64,8 +67,8 @@ func UpdatePoint(PointAndState *statev1.PointAndState) error { return nil } -func AddPoint(PointAndState *statev1.PointAndState) error { - collection := client.Database("state-manager").Collection("points") +func (db *DBHandler) AddPoint(PointAndState *statev1.PointAndState) error { + collection := db.stateManagerDB.Collection("points") _, err := collection.InsertOne(context.Background(), PointAndState) if err != nil { return err @@ -73,8 +76,8 @@ func AddPoint(PointAndState *statev1.PointAndState) error { return nil } -func GetPoint(pointId string) (*statev1.PointAndState, error) { - collection := client.Database("state-manager").Collection("points") +func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) { + collection := db.stateManagerDB.Collection("points") var result *statev1.PointAndState err := collection.FindOne(context.Background(), bson.M{"id": pointId}).Decode(&result) if err != nil { @@ -83,8 +86,8 @@ func GetPoint(pointId string) (*statev1.PointAndState, error) { return result, nil } -func GetPoints() []*statev1.PointAndState { - collection := client.Database("state-manager").Collection("points") +func (db *DBHandler) GetPoints() []*statev1.PointAndState { + collection := db.stateManagerDB.Collection("points") cursor, err := collection.Find(context.Background(), bson.M{}) if err != nil { slog.Default().Warn("Get Points failed", slog.Any("err", err)) @@ -101,8 +104,8 @@ func GetPoints() []*statev1.PointAndState { Stop */ -func UpdateStop(stop *statev1.StopAndState) error { - collection := client.Database("state-manager").Collection("stops") +func (db *DBHandler) UpdateStop(stop *statev1.StopAndState) error { + collection := db.stateManagerDB.Collection("stops") _, err := collection.UpdateOne( context.Background(), @@ -116,8 +119,8 @@ func UpdateStop(stop *statev1.StopAndState) error { return nil } -func AddStop(stop *statev1.StopAndState) error { - collection := client.Database("state-manager").Collection("stops") +func (db *DBHandler) AddStop(stop *statev1.StopAndState) error { + collection := db.stateManagerDB.Collection("stops") _, err := collection.InsertOne(context.Background(), stop) if err != nil { return err @@ -125,8 +128,8 @@ func AddStop(stop *statev1.StopAndState) error { return nil } -func GetStop(stopId string) (*statev1.StopAndState, error) { - collection := client.Database("state-manager").Collection("stops") +func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) { + collection := db.stateManagerDB.Collection("stops") var result *statev1.StopAndState err := collection.FindOne(context.Background(), bson.M{"id": stopId}).Decode(&result) if err != nil { @@ -135,8 +138,8 @@ func GetStop(stopId string) (*statev1.StopAndState, error) { return result, nil } -func GetStops() []*statev1.StopAndState { - collection := client.Database("state-manager").Collection("stops") +func (db *DBHandler) GetStops() []*statev1.StopAndState { + collection := db.stateManagerDB.Collection("stops") cursor, err := collection.Find(context.Background(), bson.M{}) if err != nil { panic(err) @@ -152,8 +155,8 @@ func GetStops() []*statev1.StopAndState { Block */ -func AddBlock(block *statev1.BlockState) error { - collection := client.Database("state-manager").Collection("blocks") +func (db *DBHandler) AddBlock(block *statev1.BlockState) error { + collection := db.stateManagerDB.Collection("blocks") _, err := collection.InsertOne(context.Background(), block) if err != nil { return err @@ -161,8 +164,8 @@ func AddBlock(block *statev1.BlockState) error { return nil } -func UpdateBlock(block *statev1.BlockState) error { - collection := client.Database("state-manager").Collection("blocks") +func (db *DBHandler) UpdateBlock(block *statev1.BlockState) error { + collection := db.stateManagerDB.Collection("blocks") _, err := collection.UpdateOne( context.Background(), bson.M{"blockid": block.BlockId}, @@ -174,8 +177,8 @@ func UpdateBlock(block *statev1.BlockState) error { return nil } -func GetBlock(blockId string) (*statev1.BlockState, error) { - collection := client.Database("state-manager").Collection("blocks") +func (db *DBHandler) GetBlock(blockId string) (*statev1.BlockState, error) { + collection := db.stateManagerDB.Collection("blocks") var result *statev1.BlockState err := collection.FindOne(context.Background(), bson.M{"blockid": blockId}).Decode(&result) if err != nil { @@ -184,8 +187,8 @@ func GetBlock(blockId string) (*statev1.BlockState, error) { return result, nil } -func GetBlocks() ([]*statev1.BlockState, error) { - collection := client.Database("state-manager").Collection("blocks") +func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) { + collection := db.stateManagerDB.Collection("blocks") cursor, err := collection.Find(context.Background(), bson.M{}) if err != nil { return nil, err diff --git a/backend/state-manager/pkg/db/db_test.go b/backend/state-manager/pkg/db/db_test.go index 59183621..7e71f5f7 100644 --- a/backend/state-manager/pkg/db/db_test.go +++ b/backend/state-manager/pkg/db/db_test.go @@ -1,136 +1,136 @@ package db -import ( - "testing" +// import ( +// "testing" - "github.com/joho/godotenv" - statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" -) +// "github.com/joho/godotenv" +// statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" +// ) -func Test_ConnectDB(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() -} +// func Test_ConnectDB(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// } -func Test_SetPoint(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() - SetPoint(&statev1.PointAndState{ - Id: "test", - State: statev1.PointStateEnum_POINT_STATE_REVERSE, - }) - point, err := GetPoint("test") - if err != nil { - t.Fatal("error") - } - if point.State != statev1.PointStateEnum_POINT_STATE_REVERSE { - t.Fatal("point state is not reverse") - } -} +// func Test_SetPoint(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// SetPoint(&statev1.PointAndState{ +// Id: "test", +// State: statev1.PointStateEnum_POINT_STATE_REVERSE, +// }) +// point, err := GetPoint("test") +// if err != nil { +// t.Fatal("error") +// } +// if point.State != statev1.PointStateEnum_POINT_STATE_REVERSE { +// t.Fatal("point state is not reverse") +// } +// } -func Test_SetStop(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() - SetStop(&statev1.StopAndState{ - Id: "test", - State: statev1.StopStateEnum_STOP_STATE_GO, - }) - stop, err := GetStop("test") - if err != nil { - t.Fatal("error") - } - if stop.State != statev1.StopStateEnum_STOP_STATE_GO { - t.Fatal("point state is not stop") - } -} +// func Test_SetStop(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// SetStop(&statev1.StopAndState{ +// Id: "test", +// State: statev1.StopStateEnum_STOP_STATE_GO, +// }) +// stop, err := GetStop("test") +// if err != nil { +// t.Fatal("error") +// } +// if stop.State != statev1.StopStateEnum_STOP_STATE_GO { +// t.Fatal("point state is not stop") +// } +// } -func Test_GetPoints(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() - SetPoint(&statev1.PointAndState{ - Id: "test", - State: statev1.PointStateEnum_POINT_STATE_REVERSE, - }) - SetPoint(&statev1.PointAndState{ - Id: "test2", - State: statev1.PointStateEnum_POINT_STATE_NORMAL, - }) - points := GetPoints() - if len(points) < 2 { - t.Fatal("points length is not larger than 2") - } -} +// func Test_GetPoints(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// SetPoint(&statev1.PointAndState{ +// Id: "test", +// State: statev1.PointStateEnum_POINT_STATE_REVERSE, +// }) +// SetPoint(&statev1.PointAndState{ +// Id: "test2", +// State: statev1.PointStateEnum_POINT_STATE_NORMAL, +// }) +// points := GetPoints() +// if len(points) < 2 { +// t.Fatal("points length is not larger than 2") +// } +// } -func Test_GetStops(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() - SetStop(&statev1.StopAndState{ - Id: "test", - State: statev1.StopStateEnum_STOP_STATE_GO, - }) - SetStop(&statev1.StopAndState{ - Id: "test2", - State: statev1.StopStateEnum_STOP_STATE_STOP, - }) - stops := GetStops() - if len(stops) < 2 { - t.Fatal("stops length is not larger than 2") - } -} +// func Test_GetStops(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// SetStop(&statev1.StopAndState{ +// Id: "test", +// State: statev1.StopStateEnum_STOP_STATE_GO, +// }) +// SetStop(&statev1.StopAndState{ +// Id: "test2", +// State: statev1.StopStateEnum_STOP_STATE_STOP, +// }) +// stops := GetStops() +// if len(stops) < 2 { +// t.Fatal("stops length is not larger than 2") +// } +// } -func Test_GetBlocks(t *testing.T) { - defer C() - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - Open() - err = SetBlock(&statev1.BlockState{ - BlockId: "test", - State: statev1.BlockStateEnum_BLOCK_STATE_OPEN, - }) - if err != nil { - t.Fatal("error") - } - err = SetBlock(&statev1.BlockState{ - BlockId: "test2", - State: statev1.BlockStateEnum_BLOCK_STATE_CLOSE, - }) - if err != nil { - t.Fatal("error") - } - block, err := GetBlock("test") - if err != nil { - t.Fatal("error") - } - if block.State != statev1.BlockStateEnum_BLOCK_STATE_OPEN { - t.Fatal("block state is not open") - } - blocks, err := GetBlocks() - if err != nil { - t.Fatal("error") - } - if len(blocks) < 2 { - t.Fatal("blocks length is not larger than 2") - } -} +// func Test_GetBlocks(t *testing.T) { +// defer Close() +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// Open() +// err = SetBlock(&statev1.BlockState{ +// BlockId: "test", +// State: statev1.BlockStateEnum_BLOCK_STATE_OPEN, +// }) +// if err != nil { +// t.Fatal("error") +// } +// err = SetBlock(&statev1.BlockState{ +// BlockId: "test2", +// State: statev1.BlockStateEnum_BLOCK_STATE_CLOSE, +// }) +// if err != nil { +// t.Fatal("error") +// } +// block, err := GetBlock("test") +// if err != nil { +// t.Fatal("error") +// } +// if block.State != statev1.BlockStateEnum_BLOCK_STATE_OPEN { +// t.Fatal("block state is not open") +// } +// blocks, err := GetBlocks() +// if err != nil { +// t.Fatal("error") +// } +// if len(blocks) < 2 { +// t.Fatal("blocks length is not larger than 2") +// } +// } diff --git a/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go b/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go index 8f23b994..92cf4c43 100644 --- a/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go +++ b/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go @@ -14,66 +14,61 @@ import ( "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" ) -var cc mqtt.Client - -func MakeClient() mqtt.Client { - var opts = mqtt.NewClientOptions() - opts.AddBroker(os.Getenv("MQTT_BROKER_ADDR")) - opts.Username = os.Getenv("MQTT_USERNAME") - opts.Password = os.Getenv("MQTT_PASSWORD") - opts.ClientID = os.Getenv("MQTT_CLIENT_ID") - - cc = mqtt.NewClient(opts) - - if token := cc.Connect(); token.Wait() && token.Error() != nil { - log.Fatalf("Mqtt error: %s", token.Error()) - } - - return cc +type Handler struct { + client mqtt.Client + dbHandler *db.DBHandler } -func Subscribe(cc mqtt.Client, topic []string, f mqtt.MessageHandler) { - qos := byte(1) - - filters := make(map[string]byte) - for _, t := range topic { - filters[t] = qos - } +func NewHandler(clientOpts *mqtt.ClientOptions, dbHandler *db.DBHandler) (*Handler, error) { + cc := mqtt.NewClient(clientOpts) - subscribeToken := cc.SubscribeMultiple(filters, f) - if subscribeToken.Wait() && subscribeToken.Error() != nil { - log.Fatal(subscribeToken.Error()) + if token := cc.Connect(); token.Wait() && token.Error() != nil { + return nil, fmt.Errorf("mqtt error: %w", token.Error()) } + return &Handler{client: cc, dbHandler: dbHandler}, nil } -func Send(cc mqtt.Client, topic string, payload string) { - token := cc.Publish(topic, 0, false, payload) - token.Wait() -} - -func StartHandler(ctx context.Context) error { +func (h *Handler) Start(ctx context.Context) error { msgCh := make(chan mqtt.Message) var f mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { msgCh <- msg } - cc := MakeClient() - Subscribe(cc, []string{"point/#", "stop/#", "block/#", "train/#", "setting/#"}, f) + h.Subscribe([]string{"point/#", "stop/#", "block/#", "train/#", "setting/#"}, f) for { select { case msg := <-msgCh: // if topic start with "point/" log.Printf("Received message: %s from topic: %s\n", msg.Payload(), msg.Topic()) - topicHandler(cc, msg) + h.topicHandler(msg) case <-ctx.Done(): slog.Default().Info("Interrupted at mqtt_handler") - cc.Disconnect(1000) + h.client.Disconnect(1000) slog.Default().Info("Disconnected from mqtt broker") - return ctx.Err() + return nil } } } +func (h *Handler) Subscribe(topic []string, f mqtt.MessageHandler) { + qos := byte(1) + + filters := make(map[string]byte) + for _, t := range topic { + filters[t] = qos + } + + subscribeToken := h.client.SubscribeMultiple(filters, f) + if subscribeToken.Wait() && subscribeToken.Error() != nil { + log.Fatal(subscribeToken.Error()) + } +} + +func (h *Handler) Send(topic string, payload string) { + token := h.client.Publish(topic, 0, false, payload) + token.Wait() +} + /* Endpoint {target}/{pointId}/get @@ -81,7 +76,7 @@ func StartHandler(ctx context.Context) error { {target}/{pointId}/update */ -func topicHandler(cc mqtt.Client, msg mqtt.Message) { +func (h *Handler) topicHandler(msg mqtt.Message) { // Handle by Path arr := strings.Split(msg.Topic(), "/") target := arr[0] @@ -94,49 +89,49 @@ func topicHandler(cc mqtt.Client, msg mqtt.Message) { switch method { case "get": - getState(cc, target, id) + h.getState(target, id) case "delta": - getDelta(cc, target, id) + h.getDelta(target, id) case "update": - updateState(cc, target, id, msg.Payload()) + h.updateState(target, id, msg.Payload()) } } -func NotifyStateUpdate(target string, id string, state string) { - token := cc.Publish(target+"/"+id+"/delta", 0, false, state) +func (h *Handler) NotifyStateUpdate(target string, id string, state string) { + token := h.client.Publish(target+"/"+id+"/delta", 0, false, state) token.Wait() } -func getState(cc mqtt.Client, target string, id string) { - defer db.C() - db.Open() - +func (h *Handler) getState(target string, id string) { switch target { case "point": - point, err := db.GetPoint(id) + point, err := h.dbHandler.GetPoint(id) if err != nil { log.Fatal(err) } log.Println(point) - token := cc.Publish("point/"+id+"/get/accepted", 0, false, point.State.String()) + token := h.client.Publish("point/"+id+"/get/accepted", 0, false, point.State.String()) token.Wait() case "stop": - stop, err := db.GetStop(id) + stop, err := h.dbHandler.GetStop(id) if err != nil { log.Fatal(err) } log.Println(stop) - token := cc.Publish("stop/"+id+"/get/accepted", 0, false, stop.State.String()) + token := h.client.Publish("stop/"+id+"/get/accepted", 0, false, stop.State.String()) token.Wait() case "block": - block, err := db.GetBlock(id) + block, err := h.dbHandler.GetBlock(id) if err != nil { log.Fatal(err) } res, err := json.Marshal(block) - token := cc.Publish("block/"+id+"/get/accepted", 0, false, res) + if err != nil { + slog.Default().Info("invaild json marshaled in mqtt_handler.NotifyStateUpdate", slog.Any("err", err)) + } + token := h.client.Publish("block/"+id+"/get/accepted", 0, false, res) token.Wait() case "setting": @@ -146,7 +141,7 @@ func getState(cc mqtt.Client, target string, id string) { if err != nil { log.Println(err.Error()) // Return error message - token := cc.Publish("setting/"+id+"/get/accepted", 0, false, "error") + token := h.client.Publish("setting/"+id+"/get/accepted", 0, false, "error") token.Wait() return } @@ -158,7 +153,7 @@ func getState(cc mqtt.Client, target string, id string) { // remove \n code raw = []byte(strings.Replace(string(raw), "\n", "", -1)) raw = []byte(strings.Replace(string(raw), " ", "", -1)) - token := cc.Publish("setting/"+id+"/get/accepted", 0, false, string(raw)) + token := h.client.Publish("setting/"+id+"/get/accepted", 0, false, string(raw)) token.Wait() case "train": @@ -166,13 +161,11 @@ func getState(cc mqtt.Client, target string, id string) { } } -func getDelta(cc mqtt.Client, target string, id string) { +func (h *Handler) getDelta(target string, id string) { } -func updateState(cc mqtt.Client, target string, id string, payload []byte) { - defer db.C() - db.Open() +func (h *Handler) updateState(target string, id string, payload []byte) { switch target { case "block": @@ -181,7 +174,7 @@ func updateState(cc mqtt.Client, target string, id string, payload []byte) { fmt.Print("newState: ") fmt.Println(newState) if newState == "OPEN" { - err := db.UpdateBlock(&statev1.BlockState{ + err := h.dbHandler.UpdateBlock(&statev1.BlockState{ BlockId: id, State: statev1.BlockStateEnum_BLOCK_STATE_OPEN, }) @@ -190,25 +183,25 @@ func updateState(cc mqtt.Client, target string, id string, payload []byte) { } // NT Tokyo if id == "yamashita_b1" { - err := db.UpdateStop(&statev1.StopAndState{ + err := h.dbHandler.UpdateStop(&statev1.StopAndState{ Id: "yamashita_s1", State: statev1.StopStateEnum_STOP_STATE_GO, }) if err != nil { log.Fatal(err) } - NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_GO.String()) - err = db.UpdateStop(&statev1.StopAndState{ + h.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_GO.String()) + err = h.dbHandler.UpdateStop(&statev1.StopAndState{ Id: "yamashita_s2", State: statev1.StopStateEnum_STOP_STATE_GO, }) if err != nil { log.Fatal(err) } - NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_GO.String()) + h.NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_GO.String()) // 今と逆にする - now, err := db.GetPoint("yamashita_p1") + now, err := h.dbHandler.GetPoint("yamashita_p1") if err != nil { log.Fatal(err) } @@ -218,7 +211,7 @@ func updateState(cc mqtt.Client, target string, id string, payload []byte) { } else { newS = statev1.PointStateEnum_POINT_STATE_NORMAL } - err = db.UpdatePoint(&statev1.PointAndState{ + err = h.dbHandler.UpdatePoint(&statev1.PointAndState{ Id: "yamashita_p1", State: newS, }) @@ -227,11 +220,11 @@ func updateState(cc mqtt.Client, target string, id string, payload []byte) { log.Fatal(err) } - NotifyStateUpdate("point", "yamashita_p1", newS.String()) + h.NotifyStateUpdate("point", "yamashita_p1", newS.String()) } } else if newState == "CLOSE" { - err := db.UpdateBlock(&statev1.BlockState{ + err := h.dbHandler.UpdateBlock(&statev1.BlockState{ BlockId: id, State: statev1.BlockStateEnum_BLOCK_STATE_CLOSE, }) @@ -240,22 +233,22 @@ func updateState(cc mqtt.Client, target string, id string, payload []byte) { } // NT Tokyo if id == "yamashita_b1" { - err := db.UpdateStop(&statev1.StopAndState{ + err := h.dbHandler.UpdateStop(&statev1.StopAndState{ Id: "yamashita_s1", State: statev1.StopStateEnum_STOP_STATE_STOP, }) if err != nil { log.Fatal(err) } - NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) - err = db.UpdateStop(&statev1.StopAndState{ + h.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) + err = h.dbHandler.UpdateStop(&statev1.StopAndState{ Id: "yamashita_s2", State: statev1.StopStateEnum_STOP_STATE_STOP, }) if err != nil { log.Fatal(err) } - NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_STOP.String()) + h.NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_STOP.String()) } } diff --git a/backend/state-manager/pkg/mqtt_handler/mqtt_handler_test.go b/backend/state-manager/pkg/mqtt_handler/mqtt_handler_test.go index c14c548f..c4f953be 100644 --- a/backend/state-manager/pkg/mqtt_handler/mqtt_handler_test.go +++ b/backend/state-manager/pkg/mqtt_handler/mqtt_handler_test.go @@ -1,15 +1,15 @@ package mqtt_handler -import ( - "github.com/joho/godotenv" - "testing" -) +// import ( +// "github.com/joho/godotenv" +// "testing" +// ) -func Test_SendMsg(t *testing.T) { - err := godotenv.Load("../../cmd/.env") - if err != nil { - panic(err) - } - client := MakeClient() - Send(client, "test", "test") -} +// func Test_SendMsg(t *testing.T) { +// err := godotenv.Load("../../cmd/.env") +// if err != nil { +// panic(err) +// } +// client := MakeClient() +// Send(client, "test", "test") +// } diff --git a/backend/state-manager/pkg/operation/operation.go b/backend/state-manager/pkg/operation/operation.go index 52ac1380..701e1e5d 100644 --- a/backend/state-manager/pkg/operation/operation.go +++ b/backend/state-manager/pkg/operation/operation.go @@ -1,40 +1,40 @@ package operation -import ( - "fmt" +// import ( +// "fmt" - statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" - "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" - "github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler" +// statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" +// "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db" +// "github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler" - "log" -) +// "log" +// ) -// シンプルなオペレーション用 +// // シンプルなオペレーション用 -func Check(change *statev1.StopAndState) { - defer db.C() - db.Open() - state, _ := db.GetBlock("yamashita_b1") - fmt.Println(state.State.String()) - if state.State == statev1.BlockStateEnum_BLOCK_STATE_CLOSE { - // 2. 閉塞が閉じていたらストップレールをあげる - err := db.UpdateStop(&statev1.StopAndState{ - Id: "yamashita_s1", - State: statev1.StopStateEnum_STOP_STATE_STOP, - }) - if err != nil { - log.Fatalln(err) - } - } else { - // 3. 閉塞が開いていたらストップレールを下げる - err := db.UpdateStop(&statev1.StopAndState{ - Id: "yamashita_s1", - State: statev1.StopStateEnum_STOP_STATE_GO, - }) - if err != nil { - log.Fatalln(err) - } - mqtt_handler.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) - } -} +// func Check(change *statev1.StopAndState) { +// defer db.Close() +// db.Open() +// state, _ := db.GetBlock("yamashita_b1") +// fmt.Println(state.State.String()) +// if state.State == statev1.BlockStateEnum_BLOCK_STATE_CLOSE { +// // 2. 閉塞が閉じていたらストップレールをあげる +// err := db.UpdateStop(&statev1.StopAndState{ +// Id: "yamashita_s1", +// State: statev1.StopStateEnum_STOP_STATE_STOP, +// }) +// if err != nil { +// log.Fatalln(err) +// } +// } else { +// // 3. 閉塞が開いていたらストップレールを下げる +// err := db.UpdateStop(&statev1.StopAndState{ +// Id: "yamashita_s1", +// State: statev1.StopStateEnum_STOP_STATE_GO, +// }) +// if err != nil { +// log.Fatalln(err) +// } +// mqtt_handler.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) +// } +// }