Skip to content

Commit

Permalink
Merge pull request #229 from ueckoken/db-connection-pool
Browse files Browse the repository at this point in the history
  • Loading branch information
Azuki-bar authored Nov 21, 2023
2 parents 61c8192 + a9d752e commit 7f80f7e
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 308 deletions.
14 changes: 10 additions & 4 deletions backend/onetime/seed-data/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
package main

import (
"context"
"os"

"github.com/joho/godotenv"
statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1"
"github.com/ueckoken/plarail2023/backend/state-manager/pkg/db"
dbhandler "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db"
"go.mongodb.org/mongo-driver/mongo/options"
"gopkg.in/yaml.v3"
"os"
)

type Station string
Expand All @@ -26,8 +29,11 @@ func main() {
if err := godotenv.Load(".env"); err != nil {
panic(err)
}
db.Open()
defer db.C()
db, err := dbhandler.Open(context.TODO(), options.Client().ApplyURI(os.Getenv("MONGODB_URI")))
if err != nil {
return
}
defer db.Close()
data := &Seed{}
b, _ := os.ReadFile("./data/nt-tokyo.yaml")
if err := yaml.Unmarshal(b, data); err != nil {
Expand Down
39 changes: 34 additions & 5 deletions backend/state-manager/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"os/signal"
"time"

mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/httplog/v2"
"github.com/joho/godotenv"
"github.com/ueckoken/plarail2023/backend/spec/state/v1/statev1connect"
connectHandler "github.com/ueckoken/plarail2023/backend/state-manager/pkg/connect"
"github.com/ueckoken/plarail2023/backend/state-manager/pkg/db"
"github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler"
"go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -84,18 +87,44 @@ func main() {
}
baseCtx, cancel := context.WithCancel(context.Background())
defer cancel()
signalCtx, stop := signal.NotifyContext(baseCtx, os.Interrupt)
ctx, stop := signal.NotifyContext(baseCtx, os.Interrupt)
defer stop()
go func() {
<-signalCtx.Done()
slog.Default().Info("signal received")
<-ctx.Done()
slog.Default().Info("signal received or canceled")
}()

eg, ctx := errgroup.WithContext(signalCtx)
eg, ctx := errgroup.WithContext(ctx)

DBOpts := options.Client().ApplyURI(os.Getenv("MONGODB_URI"))
DBHandler, err := db.Open(ctx, DBOpts)
if err != nil {
slog.Default().Error("database connection failed", slog.Any("err", err))
cancel()
return
}
mqttClientOpts := mqtt.NewClientOptions()
mqttClientOpts.AddBroker(os.Getenv("MQTT_BROKER_ADDR"))
mqttClientOpts.Username = os.Getenv("MQTT_USERNAME")
mqttClientOpts.Password = os.Getenv("MQTT_PASSWORD")
mqttClientOpts.ClientID = os.Getenv("MQTT_CLIENT_ID")

mqttHandler, err := mqtt_handler.NewHandler(mqttClientOpts, DBHandler)
if err != nil {
slog.Default().Error("mqtt create client or handler failed,", slog.Any("err", err))
cancel()
return
}
eg.Go(func() error {
slog.Default().Info("start mqtt handler")
return mqttHandler.Start(ctx)
})

r := chi.NewRouter()
// r.Use(middleware.Recoverer)
r.Use(middleware.Heartbeat("/debug/ping"))
r.Mount("/debug", middleware.Profiler())
r.Handle(statev1connect.NewStateManagerServiceHandler(&connectHandler.StateManagerServer{DBHandler: DBHandler, MqttHandler: mqttHandler}))
r.Use(httplog.RequestLogger(
httplog.NewLogger(
"http_server",
Expand Down Expand Up @@ -142,7 +171,7 @@ func main() {
//go operation.Handler()
eg.Go(func() error {
slog.Default().Info("start mqtt handler")
err := mqtt_handler.StartHandler(ctx)
err := mqttHandler.Start(ctx)
return fmt.Errorf("mqtt handler error: %w", err)
})

Expand Down
28 changes: 12 additions & 16 deletions backend/state-manager/pkg/connect/connect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ import (
"connectrpc.com/connect"

statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1"
db "github.com/ueckoken/plarail2023/backend/state-manager/pkg/db"
"github.com/ueckoken/plarail2023/backend/state-manager/pkg/db"
"github.com/ueckoken/plarail2023/backend/state-manager/pkg/mqtt_handler"
)

type StateManagerServer struct{}

type StateManagerServer struct {
DBHandler *db.DBHandler
MqttHandler *mqtt_handler.Handler
}

/*
Block
Expand All @@ -23,9 +27,7 @@ func (s *StateManagerServer) GetBlockStates(
ctx context.Context,
req *connect.Request[statev1.GetBlockStatesRequest],
) (*connect.Response[statev1.GetBlockStatesResponse], error) {
defer db.C()
db.Open()
blockStates, err := db.GetBlocks()
blockStates, err := s.DBHandler.GetBlocks()
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand Down Expand Up @@ -55,9 +57,7 @@ func (s *StateManagerServer) UpdateBlockState(
ctx context.Context,
req *connect.Request[statev1.UpdateBlockStateRequest],
) (*connect.Response[statev1.UpdateBlockStateResponse], error) {
defer db.C()
db.Open()
err := db.UpdateBlock(req.Msg.State)
err := s.DBHandler.UpdateBlock(req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -77,9 +77,7 @@ func (s *StateManagerServer) UpdatePointState(
ctx context.Context,
req *connect.Request[statev1.UpdatePointStateRequest],
) (*connect.Response[statev1.UpdatePointStateResponse], error) {
defer db.C()
db.Open()
err := db.UpdatePoint(req.Msg.State)
err := s.DBHandler.UpdatePoint(req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -88,7 +86,7 @@ func (s *StateManagerServer) UpdatePointState(
slog.Default().Error("db error", err)
return nil, err
}
mqtt_handler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String())
s.MqttHandler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String())

return connect.NewResponse(&statev1.UpdatePointStateResponse{}), nil
}
Expand All @@ -112,9 +110,7 @@ func (s *StateManagerServer) UpdateStopState(
ctx context.Context,
req *connect.Request[statev1.UpdateStopStateRequest],
) (*connect.Response[statev1.UpdateStopStateResponse], error) {
db.Open()
defer db.C()
err := db.UpdateStop(req.Msg.State)
err := s.DBHandler.UpdateStop(req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -123,7 +119,7 @@ func (s *StateManagerServer) UpdateStopState(
slog.Default().Error("db connection error", err)
return nil, err
}
mqtt_handler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String())
s.MqttHandler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String())
return connect.NewResponse(&statev1.UpdateStopStateResponse{}), nil
}

Expand Down
79 changes: 41 additions & 38 deletions backend/state-manager/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ package db

import (
"context"
"fmt"
"log"
"log/slog"
"os"

statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1"
"go.mongodb.org/mongo-driver/bson"
Expand All @@ -17,30 +17,33 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

// mongodb connection
var client *mongo.Client
type DBHandler struct {
stateManagerDB *mongo.Database
}

func Open() {
func Open(ctx context.Context, opts *options.ClientOptions) (*DBHandler, error) {
var err error
slog.Default().Debug("Connecting to MongoDB...")
uri := os.Getenv("MONGODB_URI")
if uri == "" {
log.Fatal("No MONGODB_URI set")
}
// TODO: Open関数がctxを受けるようにして、そのctxの子contextをDBのコネクションに使う
client, err = mongo.Connect(context.TODO(), options.Client().ApplyURI(uri))
client, err := mongo.Connect(ctx, opts)
if err != nil {
// TODO: return err, do not panic!
slog.Default().Error("database connection failed", slog.Any("err", err))
panic(err)
return nil, err
}

if err := client.Ping(ctx, nil); err != nil {
slog.Error("DB ping failed", slog.Any("err", err))
return nil, fmt.Errorf("DB Ping failed `%w`", err)
}
slog.Default().Debug("connected to DB")
return &DBHandler{
stateManagerDB: client.Database("state-manager"),
}, nil
}

func C() {
func (db *DBHandler) Close() {
slog.Default().Debug("Closing connection to DB...")
// TODO: contextを受けて、その子contextをDBクライアントに渡す
if err := client.Disconnect(context.TODO()); err != nil {
if err := db.stateManagerDB.Client().Disconnect(context.TODO()); err != nil {
slog.Default().Error("DB Connection Closing failed")
log.Println(err)
}
Expand All @@ -51,8 +54,8 @@ func C() {
Point
*/

func UpdatePoint(PointAndState *statev1.PointAndState) error {
collection := client.Database("state-manager").Collection("points")
func (db *DBHandler) UpdatePoint(PointAndState *statev1.PointAndState) error {
collection := db.stateManagerDB.Collection("points")
_, err := collection.UpdateOne(
context.Background(),
bson.M{"id": PointAndState.Id},
Expand All @@ -64,17 +67,17 @@ func UpdatePoint(PointAndState *statev1.PointAndState) error {
return nil
}

func AddPoint(PointAndState *statev1.PointAndState) error {
collection := client.Database("state-manager").Collection("points")
func (db *DBHandler) AddPoint(PointAndState *statev1.PointAndState) error {
collection := db.stateManagerDB.Collection("points")
_, err := collection.InsertOne(context.Background(), PointAndState)
if err != nil {
return err
}
return nil
}

func GetPoint(pointId string) (*statev1.PointAndState, error) {
collection := client.Database("state-manager").Collection("points")
func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) {
collection := db.stateManagerDB.Collection("points")
var result *statev1.PointAndState
err := collection.FindOne(context.Background(), bson.M{"id": pointId}).Decode(&result)
if err != nil {
Expand All @@ -83,8 +86,8 @@ func GetPoint(pointId string) (*statev1.PointAndState, error) {
return result, nil
}

func GetPoints() []*statev1.PointAndState {
collection := client.Database("state-manager").Collection("points")
func (db *DBHandler) GetPoints() []*statev1.PointAndState {
collection := db.stateManagerDB.Collection("points")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
slog.Default().Warn("Get Points failed", slog.Any("err", err))
Expand All @@ -101,8 +104,8 @@ func GetPoints() []*statev1.PointAndState {
Stop
*/

func UpdateStop(stop *statev1.StopAndState) error {
collection := client.Database("state-manager").Collection("stops")
func (db *DBHandler) UpdateStop(stop *statev1.StopAndState) error {
collection := db.stateManagerDB.Collection("stops")

_, err := collection.UpdateOne(
context.Background(),
Expand All @@ -116,17 +119,17 @@ func UpdateStop(stop *statev1.StopAndState) error {
return nil
}

func AddStop(stop *statev1.StopAndState) error {
collection := client.Database("state-manager").Collection("stops")
func (db *DBHandler) AddStop(stop *statev1.StopAndState) error {
collection := db.stateManagerDB.Collection("stops")
_, err := collection.InsertOne(context.Background(), stop)
if err != nil {
return err
}
return nil
}

func GetStop(stopId string) (*statev1.StopAndState, error) {
collection := client.Database("state-manager").Collection("stops")
func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) {
collection := db.stateManagerDB.Collection("stops")
var result *statev1.StopAndState
err := collection.FindOne(context.Background(), bson.M{"id": stopId}).Decode(&result)
if err != nil {
Expand All @@ -135,8 +138,8 @@ func GetStop(stopId string) (*statev1.StopAndState, error) {
return result, nil
}

func GetStops() []*statev1.StopAndState {
collection := client.Database("state-manager").Collection("stops")
func (db *DBHandler) GetStops() []*statev1.StopAndState {
collection := db.stateManagerDB.Collection("stops")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
panic(err)
Expand All @@ -152,17 +155,17 @@ func GetStops() []*statev1.StopAndState {
Block
*/

func AddBlock(block *statev1.BlockState) error {
collection := client.Database("state-manager").Collection("blocks")
func (db *DBHandler) AddBlock(block *statev1.BlockState) error {
collection := db.stateManagerDB.Collection("blocks")
_, err := collection.InsertOne(context.Background(), block)
if err != nil {
return err
}
return nil
}

func UpdateBlock(block *statev1.BlockState) error {
collection := client.Database("state-manager").Collection("blocks")
func (db *DBHandler) UpdateBlock(block *statev1.BlockState) error {
collection := db.stateManagerDB.Collection("blocks")
_, err := collection.UpdateOne(
context.Background(),
bson.M{"blockid": block.BlockId},
Expand All @@ -174,8 +177,8 @@ func UpdateBlock(block *statev1.BlockState) error {
return nil
}

func GetBlock(blockId string) (*statev1.BlockState, error) {
collection := client.Database("state-manager").Collection("blocks")
func (db *DBHandler) GetBlock(blockId string) (*statev1.BlockState, error) {
collection := db.stateManagerDB.Collection("blocks")
var result *statev1.BlockState
err := collection.FindOne(context.Background(), bson.M{"blockid": blockId}).Decode(&result)
if err != nil {
Expand All @@ -184,8 +187,8 @@ func GetBlock(blockId string) (*statev1.BlockState, error) {
return result, nil
}

func GetBlocks() ([]*statev1.BlockState, error) {
collection := client.Database("state-manager").Collection("blocks")
func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) {
collection := db.stateManagerDB.Collection("blocks")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 7f80f7e

Please sign in to comment.