diff --git a/cmd/internal/grpc/README.tmpl b/cmd/internal/grpc/README.tmpl index 9146df7..f01f75b 100644 --- a/cmd/internal/grpc/README.tmpl +++ b/cmd/internal/grpc/README.tmpl @@ -9,7 +9,8 @@ Go gRPC项目快速开发脚手架 - 命令行使用 [cobra](https://github.com/spf13/cobra) - 工具包使用 [yiigo](https://github.com/shenghui0779/yiigo) - 使用 [grpc-gateway](https://github.com/grpc-ecosystem/grpc-gateway) 同时支持 grpc 和 http 服务 +- http服务支持跨域 - 支持 proto 参数验证 - 支持 swagger.json 生成 -- 包含 TraceId、请求日志 等中间价 +- 包含 TraceId、请求日志 等中间件 - 简单好用的 Result Status 统一输出方式 diff --git a/cmd/internal/grpc/app/README.tmpl b/cmd/internal/grpc/app/README.tmpl index 95b3de2..583b6fc 100644 --- a/cmd/internal/grpc/app/README.tmpl +++ b/cmd/internal/grpc/app/README.tmpl @@ -23,9 +23,9 @@ go install github.com/go-swagger/go-swagger/cmd/swagger@latest ### 配置运行 -1. 配置文件: `config.toml` +1. 配置文件 `config.toml` 2. 执行 `buf generate` 编译proto文件 -3. 表 `t_demo` 对应 `ent/schema/demo.go` +3. 执行 `go mod tidy` 下载依赖 4. 执行 `ent/generate.go` 生成ORM代码 (只要 `ent/schema` 目录下有变动都需要执行) 5. 执行 `go run main.go` 运行 6. 执行 `go run main.go -h` 查看命令 diff --git a/cmd/internal/grpc/app/pkg_app_cmd_root.tmpl b/cmd/internal/grpc/app/pkg_app_cmd_root.tmpl index 97d9082..cb9c98a 100644 --- a/cmd/internal/grpc/app/pkg_app_cmd_root.tmpl +++ b/cmd/internal/grpc/app/pkg_app_cmd_root.tmpl @@ -3,6 +3,8 @@ package cmd import ( "context" "os" + "os/signal" + "syscall" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -50,9 +52,30 @@ func preInit(ctx context.Context) { ent.Init(ctx) } +// CleanUp 清理资源 +func CleanUp() { + // 关闭数据库连接 + ent.Close() +} + func serving(ctx context.Context) { + go watchExit() // serve grpc go server.ServeGrpc(ctx) // serve http server.ServeHttp(ctx) } + +func watchExit() { + // 创建一个通道来监听信号 + ch := make(chan os.Signal, 1) + // 监听特定的系统信号 + signal.Notify(ch, syscall.SIGINT, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM) + // 处理信号 + sig := <-ch + log.Warn(context.TODO(), "Received Signal", zap.String("signal", sig.String())) + // 执行清理操作 + CleanUp() + // 退出程序 + os.Exit(0) +} diff --git a/cmd/internal/grpc/app/pkg_app_main.tmpl b/cmd/internal/grpc/app/pkg_app_main.tmpl index 6c5615a..1734afb 100644 --- a/cmd/internal/grpc/app/pkg_app_main.tmpl +++ b/cmd/internal/grpc/app/pkg_app_main.tmpl @@ -1,17 +1,8 @@ package main -import ( - "{{.Module}}/pkg/{{.AppPkg}}/cmd" - "{{.Module}}/pkg/{{.AppPkg}}/ent" -) +import "{{.Module}}/pkg/{{.AppPkg}}/cmd" func main() { - defer clean() + defer cmd.CleanUp() cmd.Init() } - -// clean 清理资源 -func clean() { - // 关闭数据库连接 - ent.Close() -} diff --git a/cmd/internal/grpc/app/pkg_app_server_http.tmpl b/cmd/internal/grpc/app/pkg_app_server_http.tmpl index d1d0099..7fc972b 100644 --- a/cmd/internal/grpc/app/pkg_app_server_http.tmpl +++ b/cmd/internal/grpc/app/pkg_app_server_http.tmpl @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "time" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -27,7 +28,21 @@ func ServeHttp(ctx context.Context) { } defer conn.Close() // Create http mux with health check - mux := runtime.NewServeMux(runtime.WithHealthzEndpoint(grpc_health_v1.NewHealthClient(conn))) + mux := runtime.NewServeMux( + runtime.WithHealthzEndpoint(grpc_health_v1.NewHealthClient(conn)), + runtime.WithIncomingHeaderMatcher(func(s string) (string, bool) { + if v, ok := runtime.DefaultHeaderMatcher(s); ok { + return v, true + } + return strings.ToLower(s), true + }), + runtime.WithOutgoingHeaderMatcher(func(s string) (string, bool) { + if s == log.TraceId { + return s, true + } + return runtime.MetadataHeaderPrefix + s, true + }), + ) // Register http handler if err = registerHttp(ctx, mux, conn); err != nil { log.Fatal(ctx, "Error register http", zap.Error(err)) @@ -40,7 +55,7 @@ func ServeHttp(ctx context.Context) { }, AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete, http.MethodOptions}, AllowedHeaders: []string{"Authorization", "Content-Type", "withCredentials"}, - ExposedHeaders: []string{}, // 服务器暴露一些自定义的头信息,允许客户端访问 + ExposedHeaders: []string{log.TraceId}, // 服务器暴露一些自定义的头信息,允许客户端访问 AllowCredentials: true, }).Handler(mux) // Serve HTTP server diff --git a/cmd/internal/grpc/pkg_lib_identity_identity.tmpl b/cmd/internal/grpc/pkg_lib_identity_identity.tmpl index 2bf13c9..a9271e3 100644 --- a/cmd/internal/grpc/pkg_lib_identity_identity.tmpl +++ b/cmd/internal/grpc/pkg_lib_identity_identity.tmpl @@ -14,10 +14,6 @@ import ( "{{.Module}}/pkg/lib/log" ) -type CtxKeyAuth int - -const IdentityKey CtxKeyAuth = 0 - // Identity 授权身份 type Identity interface { // ID 授权ID @@ -77,17 +73,24 @@ func New(id int64, token string) Identity { } } +type identityKey struct{} + +// NewContext 将Identity注入context +func NewContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, identityKey{}, FromAuthToken(ctx, token)) +} + // FromContext 获取授权信息 func FromContext(ctx context.Context) Identity { if ctx == nil { return NewEmpty() } - _identity, ok := ctx.Value(IdentityKey).(Identity) + id, ok := ctx.Value(identityKey{}).(Identity) if !ok { return NewEmpty() } - return _identity + return id } // FromAuthToken 解析授权Token @@ -105,10 +108,10 @@ func FromAuthToken(ctx context.Context, token string) Identity { return NewEmpty() } - _identity := NewEmpty() - if err = json.Unmarshal(plainText, _identity); err != nil { + id := NewEmpty() + if err = json.Unmarshal(plainText, id); err != nil { log.Error(ctx, "Error json.Unmarshal AuthToken", zap.Error(err)) return NewEmpty() } - return _identity + return id } diff --git a/cmd/internal/grpc/pkg_lib_log_log.tmpl b/cmd/internal/grpc/pkg_lib_log_log.tmpl index b263ef5..087077b 100644 --- a/cmd/internal/grpc/pkg_lib_log_log.tmpl +++ b/cmd/internal/grpc/pkg_lib_log_log.tmpl @@ -4,24 +4,48 @@ import ( "context" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) func Info(ctx context.Context, msg string, fields ...zap.Field) { - logger.Info(msg, append(fields, zap.String("trace_id", GetTraceId(ctx)), zap.String("method", GetFullMethod(ctx)))...) + Log(ctx, zapcore.InfoLevel, msg, fields...) } func Warn(ctx context.Context, msg string, fields ...zap.Field) { - logger.Warn(msg, append(fields, zap.String("trace_id", GetTraceId(ctx)), zap.String("method", GetFullMethod(ctx)))...) + Log(ctx, zapcore.WarnLevel, msg, fields...) } func Error(ctx context.Context, msg string, fields ...zap.Field) { - logger.Error(msg, append(fields, zap.String("trace_id", GetTraceId(ctx)), zap.String("method", GetFullMethod(ctx)))...) + Log(ctx, zapcore.ErrorLevel, msg, fields...) } func Panic(ctx context.Context, msg string, fields ...zap.Field) { - logger.Panic(msg, append(fields, zap.String("trace_id", GetTraceId(ctx)), zap.String("method", GetFullMethod(ctx)))...) + Log(ctx, zapcore.PanicLevel, msg, fields...) } func Fatal(ctx context.Context, msg string, fields ...zap.Field) { - logger.Fatal(msg, append(fields, zap.String("trace_id", GetTraceId(ctx)), zap.String("method", GetFullMethod(ctx)))...) + Log(ctx, zapcore.FatalLevel, msg, fields...) +} + +func Log(ctx context.Context, level zapcore.Level, msg string, fields ...zap.Field) { + traceId, fullMethod := GetTraceInfo(ctx) + fields = append(fields, + zap.String("hostname", hostname), + zap.String("trace_id", traceId), + zap.String("method", fullMethod), + ) + switch level { + case zapcore.InfoLevel: + logger.Info(msg, fields...) + case zapcore.WarnLevel: + logger.Warn(msg, fields...) + case zapcore.ErrorLevel: + logger.Error(msg, fields...) + case zapcore.PanicLevel: + logger.Panic(msg, fields...) + case zapcore.FatalLevel: + logger.Fatal(msg, fields...) + default: + logger.Debug(msg, fields...) + } } diff --git a/cmd/internal/grpc/pkg_lib_log_traceid.tmpl b/cmd/internal/grpc/pkg_lib_log_traceid.tmpl index 9da66bd..406aa3d 100644 --- a/cmd/internal/grpc/pkg_lib_log_traceid.tmpl +++ b/cmd/internal/grpc/pkg_lib_log_traceid.tmpl @@ -8,41 +8,25 @@ import ( "os" "strings" "sync/atomic" -) -// Key to use when setting the trace ID. -type CtxKeyTraceId int + "github.com/shenghui0779/yiigo/xhash" + "google.golang.org/grpc/metadata" +) -// TraceIdKey is the key that holds the unique trace ID in a trace context. const ( - TraceIdKey CtxKeyTraceId = 0 - FullMethodKey CtxKeyTraceId = 1 + TraceId = "x-trace-id" + TraceMethod = "x-trace-method" ) var ( - traceId uint64 - TracePrefix string + hostname string + tracePrefix string + traceSeq uint64 ) -// A quick note on the statistics here: we're trying to calculate the chance that -// two randomly generated base62 prefixes will collide. We use the formula from -// http://en.wikipedia.org/wiki/Birthday_problem -// -// P[m, n] \approx 1 - e^{-m^2/2n} -// -// We ballpark an upper bound for $m$ by imagining (for whatever reason) a server -// that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ -// -// For a $k$ character base-62 identifier, we have $n(k) = 62^k$ -// -// Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for -// our purposes, and is surely more than anyone would ever need in practice -- a -// process that is rebooted a handful of times a day for a hundred years has less -// than a millionth of a percent chance of generating two colliding IDs. - func init() { - hostname, err := os.Hostname() - if hostname == "" || err != nil { + hostname, _ = os.Hostname() + if len(hostname) == 0 { hostname = "localhost" } @@ -51,38 +35,31 @@ func init() { b64 string ) for len(b64) < 10 { - rand.Read(buf[:]) + _, _ = rand.Read(buf[:]) b64 = base64.StdEncoding.EncodeToString(buf[:]) b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) } - TracePrefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) + tracePrefix = fmt.Sprintf("%s/%s", hostname, b64) } -// GetTraceId returns a trace ID from the given context if one is present. -// Returns the empty string if a trace ID cannot be found. -func GetTraceId(ctx context.Context) string { - if ctx == nil { - return "-" - } - if v, ok := ctx.Value(TraceIdKey).(string); ok { - return v - } - return "-" +// NewTraceId generates a new trace ID in the sequence. +func NewTraceId() string { + seq := atomic.AddUint64(&traceSeq, 1) + return xhash.MD5(fmt.Sprintf("%s-%d", tracePrefix, seq)) } -// GetFullMethod returns a full method from the given context if one is present. -// Returns the empty string if a full method cannot be found. -func GetFullMethod(ctx context.Context) string { - if ctx == nil { - return "-" +func GetTraceInfo(ctx context.Context) (traceId, fullMethod string) { + traceId = "-" + fullMethod = "-" + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return } - if method, ok := ctx.Value(FullMethodKey).(string); ok { - return method + if v := md.Get(TraceId); len(v) != 0 { + traceId = v[0] } - return "-" -} - -// NextTraceId generates the next trace ID in the sequence. -func NextTraceId() uint64 { - return atomic.AddUint64(&traceId, 1) + if v := md.Get(TraceMethod); len(v) != 0 { + fullMethod = v[0] + } + return } diff --git a/cmd/internal/grpc/pkg_lib_middleware_log.tmpl b/cmd/internal/grpc/pkg_lib_middleware_log.tmpl index efc8ff6..83276aa 100644 --- a/cmd/internal/grpc/pkg_lib_middleware_log.tmpl +++ b/cmd/internal/grpc/pkg_lib_middleware_log.tmpl @@ -10,9 +10,14 @@ import ( "{{.Module}}/pkg/lib/log" ) -func Log(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { +const HealthCheckMethod = "/grpc.health.v1.Health/Check" + +func Log(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { now := time.Now().Local() defer func() { + if info.FullMethod == HealthCheckMethod { + return + } log.Info(ctx, "Request info", zap.Any("request", req), zap.Any("response", resp), diff --git a/cmd/internal/grpc/pkg_lib_middleware_traceid.tmpl b/cmd/internal/grpc/pkg_lib_middleware_traceid.tmpl index 14da90c..d3fba72 100644 --- a/cmd/internal/grpc/pkg_lib_middleware_traceid.tmpl +++ b/cmd/internal/grpc/pkg_lib_middleware_traceid.tmpl @@ -2,18 +2,24 @@ package middleware import ( "context" - "fmt" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "{{.Module}}/pkg/lib/log" ) // TraceId is a middleware that injects a trace ID into the context of each request. func TraceId(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if v, ok := ctx.Value(log.TraceIdKey).(string); !ok || len(v) == 0 { - ctx = context.WithValue(ctx, log.TraceIdKey, fmt.Sprintf("%s-%06d", log.TracePrefix, log.NextTraceId())) + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.Pairs() } - ctx = context.WithValue(ctx, log.FullMethodKey, info.FullMethod) - return handler(ctx, req) + if v := md.Get(log.TraceId); len(v) == 0 { + md.Set(log.TraceId, log.NewTraceId()) + } + md.Set(log.TraceMethod, info.FullMethod) + // set the response header + _ = grpc.SetHeader(ctx, metadata.Pairs(log.TraceId, md.Get(log.TraceId)[0])) + return handler(metadata.NewIncomingContext(ctx, md), req) } diff --git a/cmd/internal/grpc/pkg_lib_result_code.tmpl b/cmd/internal/grpc/pkg_lib_result_code.tmpl index 146776b..1b950b8 100644 --- a/cmd/internal/grpc/pkg_lib_result_code.tmpl +++ b/cmd/internal/grpc/pkg_lib_result_code.tmpl @@ -1,49 +1,45 @@ package result -import ( - "errors" - - "google.golang.org/grpc/codes" -) +import "errors" func ErrParams(err ...error) Status { if len(err) != 0 { - return New(codes.InvalidArgument, err[0]) + return New(10000, err[0]) } - return New(codes.InvalidArgument, errors.New("params error")) + return New(10000, errors.New("params error")) } func ErrAuth(err ...error) Status { if len(err) != 0 { - return New(codes.Unauthenticated, err[0]) + return New(20000, err[0]) } - return New(codes.Unauthenticated, errors.New("unauthenticated")) + return New(20000, errors.New("unauthenticated")) } func ErrPerm(err ...error) Status { if len(err) != 0 { - return New(codes.PermissionDenied, err[0]) + return New(30000, err[0]) } - return New(codes.PermissionDenied, errors.New("permission denied")) + return New(30000, errors.New("permission denied")) } func ErrNotFound(err ...error) Status { if len(err) != 0 { - return New(codes.NotFound, err[0]) + return New(40000, err[0]) } - return New(codes.NotFound, errors.New("entity not found")) + return New(40000, errors.New("entity not found")) } -func ErrPrecondition(err ...error) Status { +func ErrService(err ...error) Status { if len(err) != 0 { - return New(codes.FailedPrecondition, err[0]) + return New(60000, err[0]) } - return New(codes.FailedPrecondition, errors.New("failed precondition")) + return New(60000, errors.New("service error")) } func ErrSystem(err ...error) Status { if len(err) != 0 { - return New(codes.Internal, err[0]) + return New(50000, err[0]) } - return New(codes.Internal, errors.New("internal server error")) + return New(50000, errors.New("internal server error")) } diff --git a/cmd/internal/grpc/pkg_lib_result_status.tmpl b/cmd/internal/grpc/pkg_lib_result_status.tmpl index 5007f20..07eb8aa 100644 --- a/cmd/internal/grpc/pkg_lib_result_status.tmpl +++ b/cmd/internal/grpc/pkg_lib_result_status.tmpl @@ -20,12 +20,13 @@ type errstatus struct { } func (s *errstatus) Error(ctx context.Context) error { - return status.Error(s.code, fmt.Sprintf("[%s] %+v", log.GetTraceId(ctx), s.err)) + traceId, _ := log.GetTraceInfo(ctx) + return status.Error(s.code, fmt.Sprintf("[%s] %+v", traceId, s.err)) } -func New(code codes.Code, err error) Status { +func New(code int, err error) Status { return &errstatus{ - code: code, + code: codes.Code(code), err: err, } } diff --git a/cmd/internal/http/README.tmpl b/cmd/internal/http/README.tmpl index 6d071ce..f51df7f 100644 --- a/cmd/internal/http/README.tmpl +++ b/cmd/internal/http/README.tmpl @@ -9,5 +9,5 @@ Go Web项目快速开发脚手架 - 配置使用 [viper](https://github.com/spf13/viper) - 命令行使用 [cobra](https://github.com/spf13/cobra) - 工具包使用 [yiigo](https://github.com/shenghui0779/yiigo) -- 包含 认证、请求日志、跨域 中间价 +- 包含 TraceId、请求日志、跨域 中间件 - 简单好用的 API Result 统一输出方式 diff --git a/cmd/internal/http/app/README.tmpl b/cmd/internal/http/app/README.tmpl index 1eabea5..0b8b4da 100644 --- a/cmd/internal/http/app/README.tmpl +++ b/cmd/internal/http/app/README.tmpl @@ -8,8 +8,8 @@ go install entgo.io/ent/cmd/ent@latest ### 配置运行 -1. 配置文件: `config.toml` -2. 表 `t_demo` 对应 `ent/schema/demo.go` +1. 配置文件 `config.toml` +2. 执行 `go mod tidy` 下载依赖 3. 执行 `ent/generate.go` 生成ORM代码 (只要 `ent/schema` 目录下有变动都需要执行) 4. 执行 `go run main.go` 运行 5. 执行 `go run main.go -h` 查看命令 diff --git a/cmd/internal/http/app/pkg_app_api_controller_demo.tmpl b/cmd/internal/http/app/pkg_app_api_greeter.tmpl similarity index 65% rename from cmd/internal/http/app/pkg_app_api_controller_demo.tmpl rename to cmd/internal/http/app/pkg_app_api_greeter.tmpl index b289154..9f5dca0 100644 --- a/cmd/internal/http/app/pkg_app_api_controller_demo.tmpl +++ b/cmd/internal/http/app/pkg_app_api_greeter.tmpl @@ -1,4 +1,4 @@ -package controller +package api import ( "net/http" @@ -6,20 +6,20 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" - "{{.Module}}/pkg/{{.AppPkg}}/api/service/demo" + "{{.Module}}/pkg/{{.AppPkg}}/service/greeter" "{{.Module}}/pkg/lib" "{{.Module}}/pkg/lib/log" "{{.Module}}/pkg/lib/result" ) -func DemoCreate(w http.ResponseWriter, r *http.Request) { +func Hello(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - req := new(demo.ReqCreate) + req := new(greeter.ReqHello) if err := lib.BindJSON(r, req); err != nil { log.Error(ctx, "Error params", zap.Error(err)) result.ErrParams(result.E(errors.WithMessage(err, "参数错误"))).JSON(w, r) return } - demo.Create(ctx, req).JSON(w, r) + greeter.Hello(ctx, req).JSON(w, r) } diff --git a/cmd/internal/http/app/pkg_app_api_service_demo_create.tmpl b/cmd/internal/http/app/pkg_app_api_service_demo_create.tmpl deleted file mode 100644 index fcaa14c..0000000 --- a/cmd/internal/http/app/pkg_app_api_service_demo_create.tmpl +++ /dev/null @@ -1,26 +0,0 @@ -package demo - -import ( - "context" - - "go.uber.org/zap" - - "{{.Module}}/pkg/{{.AppPkg}}/ent" - "{{.Module}}/pkg/lib/log" - "{{.Module}}/pkg/lib/result" -) - -type ReqCreate struct { - Title string `json:"title" valid:"required"` -} - -func Create(ctx context.Context, req *ReqCreate) result.Result { - _, err := ent.DB.Demo.Create(). - SetTitle(req.Title). - Save(ctx) - if err != nil { - log.Error(ctx, "Error create demo", zap.Error(err)) - return result.ErrSystem(result.E(err)) - } - return result.OK() -} diff --git a/cmd/internal/http/app/pkg_app_cmd_root.tmpl b/cmd/internal/http/app/pkg_app_cmd_root.tmpl index 24c5ec4..6eed3be 100644 --- a/cmd/internal/http/app/pkg_app_cmd_root.tmpl +++ b/cmd/internal/http/app/pkg_app_cmd_root.tmpl @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "os" + "os/signal" + "syscall" "time" "github.com/go-chi/chi/v5" @@ -14,8 +16,8 @@ import ( "github.com/spf13/viper" "go.uber.org/zap" - "{{.Module}}/pkg/{{.AppPkg}}/api/router" "{{.Module}}/pkg/{{.AppPkg}}/ent" + "{{.Module}}/pkg/{{.AppPkg}}/router" "{{.Module}}/pkg/lib/log" "{{.Module}}/pkg/lib/middleware" ) @@ -54,8 +56,14 @@ func preInit(ctx context.Context) { ent.Init(ctx) } +// CleanUp 清理资源 +func CleanUp() { + // 关闭数据库连接 + ent.Close() +} + func serving() { - r := chi.NewRouter() + go watchExit() withCors := cors.New(cors.Options{ // AllowedOrigins: []string{"*"}, @@ -64,13 +72,12 @@ func serving() { }, AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete, http.MethodOptions}, AllowedHeaders: []string{"Authorization", "Content-Type", "withCredentials"}, - ExposedHeaders: []string{}, // 服务器暴露一些自定义的头信息,允许客户端访问 + ExposedHeaders: []string{log.TraceId}, // 服务器暴露一些自定义的头信息,允许客户端访问 AllowCredentials: true, }) - - r.Use(withCors.Handler, chi_middleware.RequestID, middleware.Recovery) + r := chi.NewRouter() + r.Use(withCors.Handler, middleware.TraceId, middleware.Recovery) r.Mount("/debug", chi_middleware.Profiler()) - router.App(r) srv := &http.Server{ @@ -80,10 +87,22 @@ func serving() { WriteTimeout: 10 * time.Second, IdleTimeout: 10 * time.Second, } - fmt.Println("listening on", srv.Addr) - if err := srv.ListenAndServe(); err != nil { log.Fatal(context.Background(), "serving error", zap.Error(err)) } } + +func watchExit() { + // 创建一个通道来监听信号 + ch := make(chan os.Signal, 1) + // 监听特定的系统信号 + signal.Notify(ch, syscall.SIGINT, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM) + // 处理信号 + sig := <-ch + log.Warn(context.TODO(), "Received Signal", zap.String("signal", sig.String())) + // 执行清理操作 + CleanUp() + // 退出程序 + os.Exit(0) +} diff --git a/cmd/internal/http/app/pkg_app_main.tmpl b/cmd/internal/http/app/pkg_app_main.tmpl index 6c5615a..1734afb 100644 --- a/cmd/internal/http/app/pkg_app_main.tmpl +++ b/cmd/internal/http/app/pkg_app_main.tmpl @@ -1,17 +1,8 @@ package main -import ( - "{{.Module}}/pkg/{{.AppPkg}}/cmd" - "{{.Module}}/pkg/{{.AppPkg}}/ent" -) +import "{{.Module}}/pkg/{{.AppPkg}}/cmd" func main() { - defer clean() + defer cmd.CleanUp() cmd.Init() } - -// clean 清理资源 -func clean() { - // 关闭数据库连接 - ent.Close() -} diff --git a/cmd/internal/http/app/pkg_app_api_middleware_auth.tmpl b/cmd/internal/http/app/pkg_app_middleware_auth.tmpl similarity index 51% rename from cmd/internal/http/app/pkg_app_api_middleware_auth.tmpl rename to cmd/internal/http/app/pkg_app_middleware_auth.tmpl index 5ada944..0a55abd 100644 --- a/cmd/internal/http/app/pkg_app_api_middleware_auth.tmpl +++ b/cmd/internal/http/app/pkg_app_middleware_auth.tmpl @@ -1,10 +1,20 @@ package middleware -import "net/http" +import ( + "net/http" + + "{{.Module}}/pkg/lib/identity" + "{{.Module}}/pkg/lib/result" +) // Auth App授权中间件 func Auth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := identity.FromContext(r.Context()) + if id.ID() == 0 { + result.ErrAuth().JSON(w, r) + return + } // TODO ... next.ServeHTTP(w, r) }) diff --git a/cmd/internal/http/app/pkg_app_api_router_app.tmpl b/cmd/internal/http/app/pkg_app_router_app.tmpl similarity index 84% rename from cmd/internal/http/app/pkg_app_api_router_app.tmpl rename to cmd/internal/http/app/pkg_app_router_app.tmpl index 29c7568..bf82100 100644 --- a/cmd/internal/http/app/pkg_app_api_router_app.tmpl +++ b/cmd/internal/http/app/pkg_app_router_app.tmpl @@ -5,8 +5,8 @@ import ( "github.com/go-chi/chi/v5" - "{{.Module}}/pkg/{{.AppPkg}}/api/controller" - "{{.Module}}/pkg/{{.AppPkg}}/api/middleware" + "{{.Module}}/pkg/{{.AppPkg}}/api" + "{{.Module}}/pkg/{{.AppPkg}}/middleware" "{{.Module}}/pkg/{{.AppPkg}}/web" "{{.Module}}/pkg/lib" lib_middleware "{{.Module}}/pkg/lib/middleware" @@ -29,7 +29,6 @@ func App(r chi.Router) { // r.Method(http.MethodGet, "/metrics", promhttp.Handler()) r.With(lib_middleware.Log).Group(func(r chi.Router) { - // v1 r.Route("/v1", func(r chi.Router) { v1(r) }) @@ -39,7 +38,7 @@ func App(r chi.Router) { func v1(r chi.Router) { r.With(middleware.Auth).Group(func(r chi.Router) { { - r.Post("/user/create", controller.DemoCreate) + r.Post("/greeter", api.Hello) } }) } diff --git a/cmd/internal/http/app/pkg_app_service_greeter.tmpl b/cmd/internal/http/app/pkg_app_service_greeter.tmpl new file mode 100644 index 0000000..8660bac --- /dev/null +++ b/cmd/internal/http/app/pkg_app_service_greeter.tmpl @@ -0,0 +1,22 @@ +package greeter + +import ( + "context" + + "{{.Module}}/pkg/lib/result" +) + +type ReqHello struct { + Name string `json:"name" valid:"required"` +} + +type RespHello struct { + Message string `json:"message"` +} + +func Hello(ctx context.Context, req *ReqHello) result.Result { + resp := RespHello{ + Message: "Hello " + req.Name, + } + return result.OK(result.V(resp)) +} diff --git a/cmd/internal/http/app/pkg_app_api_service_demo_test.tmpl b/cmd/internal/http/app/pkg_app_service_test.tmpl similarity index 80% rename from cmd/internal/http/app/pkg_app_api_service_demo_test.tmpl rename to cmd/internal/http/app/pkg_app_service_test.tmpl index 8ff64bf..5f1e5a8 100644 --- a/cmd/internal/http/app/pkg_app_api_service_demo_test.tmpl +++ b/cmd/internal/http/app/pkg_app_service_test.tmpl @@ -1,4 +1,4 @@ -package demo +package greeter import ( "context" @@ -27,10 +27,10 @@ func TestMain(m *testing.M) { ent.Close() } -func TestCreate(t *testing.T) { - req := &ReqCreate{ - Title: "demo", +func TestHello(t *testing.T) { + req := &ReqHello{ + Name: "world", } - ret := Create(context.Background(), req) + ret := Hello(context.Background(), req) fmt.Println("[Result] ---", ret) } diff --git a/cmd/internal/http/embed.go b/cmd/internal/http/embed.go index 93a6587..b5f2c7b 100644 --- a/cmd/internal/http/embed.go +++ b/cmd/internal/http/embed.go @@ -76,6 +76,11 @@ var Project = []map[string]string{ "path": "pkg_lib_log_log.tmpl", "output": "pkg/lib/log/log.go", }, + { + "name": "pkg_lib_log_traceid.tmpl", + "path": "pkg_lib_log_traceid.tmpl", + "output": "pkg/lib/log/trace_id.go", + }, { "name": "pkg_lib_middleware_log.tmpl", "path": "pkg_lib_middleware_log.tmpl", @@ -91,6 +96,11 @@ var Project = []map[string]string{ "path": "pkg_lib_middleware_recovery.tmpl", "output": "pkg/lib/middleware/recovery.go", }, + { + "name": "pkg_lib_middleware_traceid.tmpl", + "path": "pkg_lib_middleware_traceid.tmpl", + "output": "pkg/lib/middleware/trace_id.go", + }, { "name": "pkg_lib_result_code.tmpl", "path": "pkg_lib_result_code.tmpl", @@ -110,29 +120,29 @@ var Project = []map[string]string{ var App = []map[string]string{ { - "name": "pkg_app_api_controller_demo.tmpl", - "path": "app/pkg_app_api_controller_demo.tmpl", - "output": "api/controller/demo.go", + "name": "pkg_app_api_greeter.tmpl", + "path": "app/pkg_app_api_greeter.tmpl", + "output": "api/greeter.go", }, { - "name": "pkg_app_api_middleware_auth.tmpl", - "path": "app/pkg_app_api_middleware_auth.tmpl", - "output": "api/middleware/auth.go", + "name": "pkg_app_middleware_auth.tmpl", + "path": "app/pkg_app_middleware_auth.tmpl", + "output": "middleware/auth.go", }, { - "name": "pkg_app_api_router_app.tmpl", - "path": "app/pkg_app_api_router_app.tmpl", - "output": "api/router/app.go", + "name": "pkg_app_router_app.tmpl", + "path": "app/pkg_app_router_app.tmpl", + "output": "router/app.go", }, { - "name": "pkg_app_api_service_demo_create.tmpl", - "path": "app/pkg_app_api_service_demo_create.tmpl", - "output": "api/service/demo/create.go", + "name": "pkg_app_service_greeter.tmpl", + "path": "app/pkg_app_service_greeter.tmpl", + "output": "service/greeter/hello.go", }, { - "name": "pkg_app_api_service_demo_test.tmpl", - "path": "app/pkg_app_api_service_demo_test.tmpl", - "output": "api/service/demo/demo_test.go", + "name": "pkg_app_service_test.tmpl", + "path": "app/pkg_app_service_test.tmpl", + "output": "service/greeter/greeter_test.go", }, { "name": "pkg_app_cmd_hello.tmpl", diff --git a/cmd/internal/http/pkg_lib_identity_identity.tmpl b/cmd/internal/http/pkg_lib_identity_identity.tmpl index 2bf13c9..a9271e3 100644 --- a/cmd/internal/http/pkg_lib_identity_identity.tmpl +++ b/cmd/internal/http/pkg_lib_identity_identity.tmpl @@ -14,10 +14,6 @@ import ( "{{.Module}}/pkg/lib/log" ) -type CtxKeyAuth int - -const IdentityKey CtxKeyAuth = 0 - // Identity 授权身份 type Identity interface { // ID 授权ID @@ -77,17 +73,24 @@ func New(id int64, token string) Identity { } } +type identityKey struct{} + +// NewContext 将Identity注入context +func NewContext(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, identityKey{}, FromAuthToken(ctx, token)) +} + // FromContext 获取授权信息 func FromContext(ctx context.Context) Identity { if ctx == nil { return NewEmpty() } - _identity, ok := ctx.Value(IdentityKey).(Identity) + id, ok := ctx.Value(identityKey{}).(Identity) if !ok { return NewEmpty() } - return _identity + return id } // FromAuthToken 解析授权Token @@ -105,10 +108,10 @@ func FromAuthToken(ctx context.Context, token string) Identity { return NewEmpty() } - _identity := NewEmpty() - if err = json.Unmarshal(plainText, _identity); err != nil { + id := NewEmpty() + if err = json.Unmarshal(plainText, id); err != nil { log.Error(ctx, "Error json.Unmarshal AuthToken", zap.Error(err)) return NewEmpty() } - return _identity + return id } diff --git a/cmd/internal/http/pkg_lib_log_log.tmpl b/cmd/internal/http/pkg_lib_log_log.tmpl index fb0c158..4b5fa27 100644 --- a/cmd/internal/http/pkg_lib_log_log.tmpl +++ b/cmd/internal/http/pkg_lib_log_log.tmpl @@ -3,26 +3,49 @@ package log import ( "context" - "github.com/go-chi/chi/v5/middleware" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) func Info(ctx context.Context, msg string, fields ...zap.Field) { - logger.Info(msg, append(fields, zap.String("request_id", middleware.GetReqID(ctx)))...) + Log(ctx, zapcore.InfoLevel, msg, fields...) } func Warn(ctx context.Context, msg string, fields ...zap.Field) { - logger.Warn(msg, append(fields, zap.String("request_id", middleware.GetReqID(ctx)))...) + Log(ctx, zapcore.WarnLevel, msg, fields...) } func Error(ctx context.Context, msg string, fields ...zap.Field) { - logger.Error(msg, append(fields, zap.String("request_id", middleware.GetReqID(ctx)))...) + Log(ctx, zapcore.ErrorLevel, msg, fields...) } func Panic(ctx context.Context, msg string, fields ...zap.Field) { - logger.Panic(msg, append(fields, zap.String("request_id", middleware.GetReqID(ctx)))...) + Log(ctx, zapcore.PanicLevel, msg, fields...) } func Fatal(ctx context.Context, msg string, fields ...zap.Field) { - logger.Fatal(msg, append(fields, zap.String("request_id", middleware.GetReqID(ctx)))...) + Log(ctx, zapcore.FatalLevel, msg, fields...) +} + +func Log(ctx context.Context, level zapcore.Level, msg string, fields ...zap.Field) { + traceId, path := GetTraceInfo(ctx) + fields = append(fields, + zap.String("hostname", hostname), + zap.String("trace_id", traceId), + zap.String("path", path), + ) + switch level { + case zapcore.InfoLevel: + logger.Info(msg, fields...) + case zapcore.WarnLevel: + logger.Warn(msg, fields...) + case zapcore.ErrorLevel: + logger.Error(msg, fields...) + case zapcore.PanicLevel: + logger.Panic(msg, fields...) + case zapcore.FatalLevel: + logger.Fatal(msg, fields...) + default: + logger.Debug(msg, fields...) + } } diff --git a/cmd/internal/http/pkg_lib_log_traceid.tmpl b/cmd/internal/http/pkg_lib_log_traceid.tmpl new file mode 100644 index 0000000..6669381 --- /dev/null +++ b/cmd/internal/http/pkg_lib_log_traceid.tmpl @@ -0,0 +1,63 @@ +package log + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "os" + "strings" + "sync/atomic" + + "github.com/shenghui0779/yiigo/metadata" + "github.com/shenghui0779/yiigo/xhash" +) + +const TraceId = "x-trace-id" +const TracePath = "x-trace-path" + +var ( + hostname string + tracePrefix string + traceSeq uint64 +) + +func init() { + hostname, _ = os.Hostname() + if len(hostname) == 0 { + hostname = "localhost" + } + + var ( + buf [12]byte + b64 string + ) + for len(b64) < 10 { + _, _ = rand.Read(buf[:]) + b64 = base64.StdEncoding.EncodeToString(buf[:]) + b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) + } + tracePrefix = fmt.Sprintf("%s/%s", hostname, b64) +} + +// NewTraceId generates a new trace ID in the sequence. +func NewTraceId() string { + seq := atomic.AddUint64(&traceSeq, 1) + return xhash.MD5(fmt.Sprintf("%s-%d", tracePrefix, seq)) +} + +func GetTraceInfo(ctx context.Context) (traceId, path string) { + traceId = "-" + path = "-" + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return + } + if v := md.Get(TraceId); len(v) != 0 { + traceId = v[0] + } + if v := md.Get(TracePath); len(v) != 0 { + path = v[0] + } + return +} diff --git a/cmd/internal/http/pkg_lib_middleware_log.tmpl b/cmd/internal/http/pkg_lib_middleware_log.tmpl index 7abed1f..6aea4fe 100644 --- a/cmd/internal/http/pkg_lib_middleware_log.tmpl +++ b/cmd/internal/http/pkg_lib_middleware_log.tmpl @@ -33,7 +33,6 @@ func Log(next http.Handler) http.Handler { result.ErrSystem(result.E(errors.WithMessage(err, "表单解析失败"))).JSON(w, r) return } - body = r.Form.Encode() case xhttp.ContentFormMultipart: if err := r.ParseMultipartForm(xhttp.MaxFormMemory); err != nil { @@ -42,7 +41,6 @@ func Log(next http.Handler) http.Handler { return } } - body = r.Form.Encode() case ContentJSON: b, err := io.ReadAll(r.Body) // 取出Body @@ -50,16 +48,15 @@ func Log(next http.Handler) http.Handler { result.ErrSystem(result.E(errors.WithMessage(err, "请求Body读取失败"))).JSON(w, r) return } - r.Body.Close() // 关闭原Body - + _ = r.Body.Close() // 关闭原Body body = string(pretty.Ugly(b)) - r.Body = io.NopCloser(bytes.NewReader(b)) // 重新赋值Body + // 重新赋值Body + r.Body = io.NopCloser(bytes.NewReader(b)) } } defer func() { log.Info(r.Context(), "Request info", zap.String("method", r.Method), - zap.String("url", r.URL.String()), zap.String("ip", r.RemoteAddr), zap.String("body", body), zap.String("identity", identity.FromContext(r.Context()).String()), diff --git a/cmd/internal/http/pkg_lib_middleware_recovery.tmpl b/cmd/internal/http/pkg_lib_middleware_recovery.tmpl index c744966..ee0d48b 100644 --- a/cmd/internal/http/pkg_lib_middleware_recovery.tmpl +++ b/cmd/internal/http/pkg_lib_middleware_recovery.tmpl @@ -1,7 +1,6 @@ package middleware import ( - "context" "net/http" "runtime/debug" @@ -22,16 +21,12 @@ func Recovery(next http.Handler) http.Handler { result.ErrSystem().JSON(w, r) } }() - + // 注入Identity if token := r.Header.Get("Authorization"); len(token) != 0 { - ctx := r.Context() - id := identity.FromAuthToken(ctx, token) - - next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, identity.IdentityKey, id))) - + ctx := identity.NewContext(r.Context(), token) + next.ServeHTTP(w, r.WithContext(ctx)) return } - next.ServeHTTP(w, r) }) } diff --git a/cmd/internal/http/pkg_lib_middleware_traceid.tmpl b/cmd/internal/http/pkg_lib_middleware_traceid.tmpl new file mode 100644 index 0000000..1a7db3b --- /dev/null +++ b/cmd/internal/http/pkg_lib_middleware_traceid.tmpl @@ -0,0 +1,35 @@ +package middleware + +import ( + "net/http" + + "github.com/shenghui0779/yiigo/metadata" + + "{{.Module}}/pkg/lib/log" +) + +// TraceId is a middleware that injects a trace ID into the context of each request. +func TraceId(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + md, ok := metadata.FromIncomingContext(r.Context()) + if !ok { + md = metadata.Pairs() + } + md.Set(log.TracePath, r.URL.Path) + // traceId已存在,则复用 + if len(md.Get(log.TraceId)) != 0 { + next.ServeHTTP(w, r) + return + } + // 去header取traceId + traceId := r.Header.Get(log.TraceId) + if len(traceId) == 0 { + traceId = log.NewTraceId() + } + // 设置traceId + md.Set(log.TraceId, traceId) + ctx := metadata.NewIncomingContext(r.Context(), md) + w.Header().Set(log.TraceId, traceId) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/cmd/internal/http/pkg_lib_result_result.tmpl b/cmd/internal/http/pkg_lib_result_result.tmpl index 9e9fe50..d6185dd 100644 --- a/cmd/internal/http/pkg_lib_result_result.tmpl +++ b/cmd/internal/http/pkg_lib_result_result.tmpl @@ -4,7 +4,6 @@ import ( "encoding/json" "net/http" - "github.com/go-chi/chi/v5/middleware" "github.com/shenghui0779/yiigo" "github.com/shenghui0779/yiigo/xhttp" "go.uber.org/zap" @@ -24,14 +23,10 @@ type response struct { } func (resp *response) JSON(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - w.Header().Set("Request-ID", middleware.GetReqID(ctx)) w.Header().Set(xhttp.HeaderContentType, xhttp.ContentJSON) w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(resp.x); err != nil { - log.Error(ctx, "Error write response", zap.Error(err)) + log.Error(r.Context(), "Error write response", zap.Error(err)) } } diff --git a/cmd/yiigo/main.go b/cmd/yiigo/main.go index 7185659..4ed1e9a 100644 --- a/cmd/yiigo/main.go +++ b/cmd/yiigo/main.go @@ -46,13 +46,13 @@ func project() *cobra.Command { Use: "new", Short: "创建项目", Example: yiigo.CmdExamples( - "-- HTTP项目", + "-- HTTP --", "yiigo new demo", "yiigo new demo --mod=xxx.yyy.com", "yiigo new demo --apps=foo,bar", "yiigo new demo --apps=foo --apps=bar", "yiigo new demo --mod=xxx.yyy.com --apps=foo --apps=bar", - "-- gRPC项目", + "-- gRPC --", "yiigo new demo --grpc", "yiigo new demo --mod=xxx.yyy.com --grpc", "yiigo new demo --apps=foo,bar --grpc", diff --git a/metadata/metadata.go b/metadata/metadata.go new file mode 100644 index 0000000..da8c626 --- /dev/null +++ b/metadata/metadata.go @@ -0,0 +1,240 @@ +package metadata + +import ( + "context" + "fmt" + "strings" +) + +// MD is a mapping from metadata keys to values. Users should use the following +// two convenience functions New and Pairs to generate MD. +type MD map[string][]string + +// New creates an MD from a given key-value map. +// +// Only the following ASCII characters are allowed in keys: +// - digits: 0-9 +// - uppercase letters: A-Z (normalized to lower) +// - lowercase letters: a-z +// - special characters: -_. +// +// Uppercase letters are automatically converted to lowercase. +func New(m map[string]string) MD { + md := make(MD, len(m)) + for k, val := range m { + key := strings.ToLower(k) + md[key] = append(md[key], val) + } + return md +} + +// Pairs returns an MD formed by the mapping of key, value ... +// Pairs panics if len(kv) is odd. +// +// Only the following ASCII characters are allowed in keys: +// - digits: 0-9 +// - uppercase letters: A-Z (normalized to lower) +// - lowercase letters: a-z +// - special characters: -_. +// +// Uppercase letters are automatically converted to lowercase. +func Pairs(kv ...string) MD { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) + } + md := make(MD, len(kv)/2) + for i := 0; i < len(kv); i += 2 { + key := strings.ToLower(kv[i]) + md[key] = append(md[key], kv[i+1]) + } + return md +} + +// Len returns the number of items in md. +func (md MD) Len() int { + return len(md) +} + +// Copy returns a copy of md. +func (md MD) Copy() MD { + out := make(MD, len(md)) + for k, v := range md { + out[k] = copyOf(v) + } + return out +} + +// Get obtains the values for a given key. +// +// k is converted to lowercase before searching in md. +func (md MD) Get(k string) []string { + k = strings.ToLower(k) + return md[k] +} + +// Set sets the value of a given key with a slice of values. +// +// k is converted to lowercase before storing in md. +func (md MD) Set(k string, vals ...string) { + if len(vals) == 0 { + return + } + k = strings.ToLower(k) + md[k] = vals +} + +// Append adds the values to key k, not overwriting what was already stored at +// that key. +// +// k is converted to lowercase before storing in md. +func (md MD) Append(k string, vals ...string) { + if len(vals) == 0 { + return + } + k = strings.ToLower(k) + md[k] = append(md[k], vals...) +} + +// Delete removes the values for a given key k which is converted to lowercase +// before removing it from md. +func (md MD) Delete(k string) { + k = strings.ToLower(k) + delete(md, k) +} + +// Join joins any number of mds into a single MD. +// +// The order of values for each key is determined by the order in which the mds +// containing those values are presented to Join. +func Join(mds ...MD) MD { + out := MD{} + for _, md := range mds { + for k, v := range md { + out[k] = append(out[k], v...) + } + } + return out +} + +type rawMD struct { + md MD + added [][]string +} + +func copyOf(v []string) []string { + vals := make([]string, len(v)) + copy(vals, v) + return vals +} + +type mdIncomingKey struct{} +type mdOutgoingKey struct{} + +// NewIncomingContext creates a new context with incoming md attached. md must +// not be modified after calling this function. +func NewIncomingContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdIncomingKey{}, md) +} + +// NewOutgoingContext creates a new context with outgoing md attached. If used +// in conjunction with AppendToOutgoingContext, NewOutgoingContext will +// overwrite any previously-appended metadata. md must not be modified after +// calling this function. +func NewOutgoingContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md}) +} + +// AppendToOutgoingContext returns a new context with the provided kv merged +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. +func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv))) + } + md, _ := ctx.Value(mdOutgoingKey{}).(rawMD) + added := make([][]string, len(md.added)+1) + copy(added, md.added) + kvCopy := make([]string, 0, len(kv)) + for i := 0; i < len(kv); i += 2 { + kvCopy = append(kvCopy, strings.ToLower(kv[i]), kv[i+1]) + } + added[len(added)-1] = kvCopy + return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: added}) +} + +// FromIncomingContext returns the incoming metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. +func FromIncomingContext(ctx context.Context) (MD, bool) { + md, ok := ctx.Value(mdIncomingKey{}).(MD) + if !ok { + return nil, false + } + out := make(MD, len(md)) + for k, v := range md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. + key := strings.ToLower(k) + out[key] = copyOf(v) + } + return out, true +} + +// ValueFromIncomingContext returns the metadata value corresponding to the metadata +// key from the incoming metadata if it exists. Keys are matched in a case insensitive +// manner. +func ValueFromIncomingContext(ctx context.Context, key string) []string { + md, ok := ctx.Value(mdIncomingKey{}).(MD) + if !ok { + return nil + } + + if v, ok := md[key]; ok { + return copyOf(v) + } + for k, v := range md { + // Case insensitive comparison: MD is a map, and there's no guarantee + // that the MD attached to the context is created using our helper + // functions. + if strings.EqualFold(k, key) { + return copyOf(v) + } + } + return nil +} + +// FromOutgoingContext returns the outgoing metadata in ctx if it exists. +// +// All keys in the returned MD are lowercase. +func FromOutgoingContext(ctx context.Context) (MD, bool) { + raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD) + if !ok { + return nil, false + } + + mdSize := len(raw.md) + for i := range raw.added { + mdSize += len(raw.added[i]) / 2 + } + + out := make(MD, mdSize) + for k, v := range raw.md { + // We need to manually convert all keys to lower case, because MD is a + // map, and there's no guarantee that the MD attached to the context is + // created using our helper functions. + key := strings.ToLower(k) + out[key] = copyOf(v) + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromOutgoingContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + key := strings.ToLower(added[i]) + out[key] = append(out[key], added[i+1]) + } + } + return out, ok +} diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go new file mode 100644 index 0000000..ec1be09 --- /dev/null +++ b/metadata/metadata_test.go @@ -0,0 +1,389 @@ +package metadata + +import ( + "context" + "reflect" + "strconv" + "testing" + "time" +) + +const defaultTestTimeout = 10 * time.Second + +func TestPairsMD(t *testing.T) { + for _, test := range []struct { + // input + kv []string + // output + md MD + }{ + {[]string{}, MD{}}, + {[]string{"k1", "v1", "k1", "v2"}, MD{"k1": []string{"v1", "v2"}}}, + } { + md := Pairs(test.kv...) + if !reflect.DeepEqual(md, test.md) { + t.Fatalf("Pairs(%v) = %v, want %v", test.kv, md, test.md) + } + } +} + +func TestCopy(t *testing.T) { + const key, val = "key", "val" + orig := Pairs(key, val) + cpy := orig.Copy() + if !reflect.DeepEqual(orig, cpy) { + t.Errorf("copied value not equal to the original, got %v, want %v", cpy, orig) + } + orig[key][0] = "foo" + if v := cpy[key][0]; v != val { + t.Errorf("change in original should not affect copy, got %q, want %q", v, val) + } +} + +func TestJoin(t *testing.T) { + for _, test := range []struct { + mds []MD + want MD + }{ + {[]MD{}, MD{}}, + {[]MD{Pairs("foo", "bar")}, Pairs("foo", "bar")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz")}, Pairs("foo", "bar", "foo", "baz")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz"), Pairs("zip", "zap")}, Pairs("foo", "bar", "foo", "baz", "zip", "zap")}, + } { + md := Join(test.mds...) + if !reflect.DeepEqual(md, test.want) { + t.Errorf("context's metadata is %v, want %v", md, test.want) + } + } +} + +func TestGet(t *testing.T) { + for _, test := range []struct { + md MD + key string + wantVals []string + }{ + {md: Pairs("My-Optional-Header", "42"), key: "My-Optional-Header", wantVals: []string{"42"}}, + {md: Pairs("Header", "42", "Header", "43", "Header", "44", "other", "1"), key: "HEADER", wantVals: []string{"42", "43", "44"}}, + {md: Pairs("HEADER", "10"), key: "HEADER", wantVals: []string{"10"}}, + } { + vals := test.md.Get(test.key) + if !reflect.DeepEqual(vals, test.wantVals) { + t.Errorf("value of metadata %v is %v, want %v", test.key, vals, test.wantVals) + } + } +} + +func TestSet(t *testing.T) { + for _, test := range []struct { + md MD + setKey string + setVals []string + want MD + }{ + { + md: Pairs("My-Optional-Header", "42", "other-key", "999"), + setKey: "Other-Key", + setVals: []string{"1"}, + want: Pairs("my-optional-header", "42", "other-key", "1"), + }, + { + md: Pairs("My-Optional-Header", "42"), + setKey: "Other-Key", + setVals: []string{"1", "2", "3"}, + want: Pairs("my-optional-header", "42", "other-key", "1", "other-key", "2", "other-key", "3"), + }, + { + md: Pairs("My-Optional-Header", "42"), + setKey: "Other-Key", + setVals: []string{}, + want: Pairs("my-optional-header", "42"), + }, + } { + test.md.Set(test.setKey, test.setVals...) + if !reflect.DeepEqual(test.md, test.want) { + t.Errorf("value of metadata is %v, want %v", test.md, test.want) + } + } +} + +func TestAppend(t *testing.T) { + for _, test := range []struct { + md MD + appendKey string + appendVals []string + want MD + }{ + { + md: Pairs("My-Optional-Header", "42"), + appendKey: "Other-Key", + appendVals: []string{"1"}, + want: Pairs("my-optional-header", "42", "other-key", "1"), + }, + { + md: Pairs("My-Optional-Header", "42"), + appendKey: "my-OptIoNal-HeAder", + appendVals: []string{"1", "2", "3"}, + want: Pairs("my-optional-header", "42", "my-optional-header", "1", + "my-optional-header", "2", "my-optional-header", "3"), + }, + { + md: Pairs("My-Optional-Header", "42"), + appendKey: "my-OptIoNal-HeAder", + appendVals: []string{}, + want: Pairs("my-optional-header", "42"), + }, + } { + test.md.Append(test.appendKey, test.appendVals...) + if !reflect.DeepEqual(test.md, test.want) { + t.Errorf("value of metadata is %v, want %v", test.md, test.want) + } + } +} + +func TestDelete(t *testing.T) { + for _, test := range []struct { + md MD + deleteKey string + want MD + }{ + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "My-Optional-Header", + want: Pairs(), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "Other-Key", + want: Pairs("my-optional-header", "42"), + }, + { + md: Pairs("My-Optional-Header", "42"), + deleteKey: "my-OptIoNal-HeAder", + want: Pairs(), + }, + } { + test.md.Delete(test.deleteKey) + if !reflect.DeepEqual(test.md, test.want) { + t.Errorf("value of metadata is %v, want %v", test.md, test.want) + } + } +} + +func TestFromIncomingContext(t *testing.T) { + md := Pairs( + "X-My-Header-1", "42", + ) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Verify that we lowercase if callers directly modify md + md["X-INCORRECT-UPPERCASE"] = []string{"foo"} + ctx = NewIncomingContext(ctx, md) + + result, found := FromIncomingContext(ctx) + if !found { + t.Fatal("FromIncomingContext must return metadata") + } + expected := MD{ + "x-my-header-1": []string{"42"}, + "x-incorrect-uppercase": []string{"foo"}, + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("FromIncomingContext returned %#v, expected %#v", result, expected) + } + + // ensure modifying result does not modify the value in the context + result["new_key"] = []string{"foo"} + result["x-my-header-1"][0] = "mutated" + + result2, found := FromIncomingContext(ctx) + if !found { + t.Fatal("FromIncomingContext must return metadata") + } + if !reflect.DeepEqual(result2, expected) { + t.Errorf("FromIncomingContext after modifications returned %#v, expected %#v", result2, expected) + } +} + +func TestValueFromIncomingContext(t *testing.T) { + md := Pairs( + "X-My-Header-1", "42", + "X-My-Header-2", "43-1", + "X-My-Header-2", "43-2", + "x-my-header-3", "44", + ) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Verify that we lowercase if callers directly modify md + md["X-INCORRECT-UPPERCASE"] = []string{"foo"} + ctx = NewIncomingContext(ctx, md) + + for _, test := range []struct { + key string + want []string + }{ + { + key: "x-my-header-1", + want: []string{"42"}, + }, + { + key: "x-my-header-2", + want: []string{"43-1", "43-2"}, + }, + { + key: "x-my-header-3", + want: []string{"44"}, + }, + { + key: "x-unknown", + want: nil, + }, + { + key: "x-incorrect-uppercase", + want: []string{"foo"}, + }, + } { + v := ValueFromIncomingContext(ctx, test.key) + if !reflect.DeepEqual(v, test.want) { + t.Errorf("value of metadata is %v, want %v", v, test.want) + } + } +} + +func TestAppendToOutgoingContext(t *testing.T) { + // Pre-existing metadata + tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx := NewOutgoingContext(tCtx, Pairs("k1", "v1", "k2", "v2")) + ctx = AppendToOutgoingContext(ctx, "k1", "v3") + ctx = AppendToOutgoingContext(ctx, "k1", "v4") + md, ok := FromOutgoingContext(ctx) + if !ok { + t.Errorf("Expected MD to exist in ctx, but got none") + } + want := Pairs("k1", "v1", "k1", "v3", "k1", "v4", "k2", "v2") + if !reflect.DeepEqual(md, want) { + t.Errorf("context's metadata is %v, want %v", md, want) + } + + // No existing metadata + ctx = AppendToOutgoingContext(tCtx, "k1", "v1") + md, ok = FromOutgoingContext(ctx) + if !ok { + t.Errorf("Expected MD to exist in ctx, but got none") + } + want = Pairs("k1", "v1") + if !reflect.DeepEqual(md, want) { + t.Errorf("context's metadata is %v, want %v", md, want) + } +} + +func TestAppendToOutgoingContext_Repeated(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for i := 0; i < 100; i = i + 2 { + ctx1 := AppendToOutgoingContext(ctx, "k", strconv.Itoa(i)) + ctx2 := AppendToOutgoingContext(ctx, "k", strconv.Itoa(i+1)) + + md1, _ := FromOutgoingContext(ctx1) + md2, _ := FromOutgoingContext(ctx2) + + if reflect.DeepEqual(md1, md2) { + t.Fatalf("md1, md2 = %v, %v; should not be equal", md1, md2) + } + + ctx = ctx1 + } +} + +func TestAppendToOutgoingContext_FromKVSlice(t *testing.T) { + const k, v = "a", "b" + kv := []string{k, v} + tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx := AppendToOutgoingContext(tCtx, kv...) + md, _ := FromOutgoingContext(ctx) + if md[k][0] != v { + t.Fatalf("md[%q] = %q; want %q", k, md[k], v) + } + kv[1] = "xxx" + md, _ = FromOutgoingContext(ctx) + if md[k][0] != v { + t.Fatalf("md[%q] = %q; want %q", k, md[k], v) + } +} + +// Old/slow approach to adding metadata to context +func Benchmark_AddingMetadata_ContextManipulationApproach(b *testing.B) { + // TODO: Add in N=1-100 tests once Go1.6 support is removed. + const num = 10 + for n := 0; n < b.N; n++ { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for i := 0; i < num; i++ { + md, _ := FromOutgoingContext(ctx) + NewOutgoingContext(ctx, Join(Pairs("k1", "v1", "k2", "v2"), md)) + } + } +} + +// Newer/faster approach to adding metadata to context +func BenchmarkAppendToOutgoingContext(b *testing.B) { + const num = 10 + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for n := 0; n < b.N; n++ { + for i := 0; i < num; i++ { + ctx = AppendToOutgoingContext(ctx, "k1", "v1", "k2", "v2") + } + } +} + +func BenchmarkFromOutgoingContext(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = NewOutgoingContext(ctx, MD{"k3": {"v3", "v4"}}) + ctx = AppendToOutgoingContext(ctx, "k1", "v1", "k2", "v2") + + for n := 0; n < b.N; n++ { + FromOutgoingContext(ctx) + } +} + +func BenchmarkFromIncomingContext(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + md := Pairs("X-My-Header-1", "42") + ctx = NewIncomingContext(ctx, md) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + FromIncomingContext(ctx) + } +} + +func BenchmarkValueFromIncomingContext(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + md := Pairs("X-My-Header-1", "42") + ctx = NewIncomingContext(ctx, md) + + b.Run("key-found", func(b *testing.B) { + for n := 0; n < b.N; n++ { + result := ValueFromIncomingContext(ctx, "x-my-header-1") + if len(result) != 1 { + b.Fatal("ensures not optimized away") + } + } + }) + + b.Run("key-not-found", func(b *testing.B) { + for n := 0; n < b.N; n++ { + result := ValueFromIncomingContext(ctx, "key-not-found") + if len(result) != 0 { + b.Fatal("ensures not optimized away") + } + } + }) +} diff --git a/value/value.go b/value/value.go index 2a5a6d9..8fb8ace 100644 --- a/value/value.go +++ b/value/value.go @@ -6,10 +6,10 @@ import ( "strings" ) -// V 用于处理 k-v 需要格式化的场景,如:签名 +// V 用于处理 k/v 需要格式化的场景,如:签名 type V map[string]string -// Set 设置 k-v +// Set 设置 k/v func (v V) Set(key, value string) { v[key] = value } @@ -54,7 +54,6 @@ func (v V) Encode(sym, sep string, opts ...Option) string { sort.Strings(keys) var buf strings.Builder - for _, k := range keys { val := v[k] if len(val) == 0 && o.emptyMode == EmptyIgnore { @@ -83,6 +82,5 @@ func (v V) Encode(sym, sep string, opts ...Option) string { buf.WriteString(sym) } } - return buf.String() } diff --git a/xhttp/client.go b/xhttp/client.go index c9eb26f..42c50f1 100644 --- a/xhttp/client.go +++ b/xhttp/client.go @@ -8,6 +8,8 @@ import ( "net" "net/http" "time" + + "github.com/shenghui0779/yiigo/metadata" ) // Client HTTP客户端 @@ -30,6 +32,14 @@ func (c *client) Do(ctx context.Context, method, reqURL string, body []byte, opt return nil, err } + // context元数据注入header + md, ok := metadata.FromIncomingContext(ctx) + if ok && len(md) != 0 { + for k, v := range md { + opts = append(opts, WithHeader(k, v...)) + } + } + // 处理options o := new(options) if len(opts) != 0 { o.header = http.Header{} @@ -37,7 +47,6 @@ func (c *client) Do(ctx context.Context, method, reqURL string, body []byte, opt f(o) } } - // header if len(o.header) != 0 { req.Header = o.header @@ -63,24 +72,29 @@ func (c *client) Do(ctx context.Context, method, reqURL string, body []byte, opt } return nil, err } - return resp, nil } -func (c *client) Upload(ctx context.Context, reqURL string, form UploadForm, options ...Option) (*http.Response, error) { +func (c *client) Upload(ctx context.Context, reqURL string, form UploadForm, opts ...Option) (*http.Response, error) { buf := bytes.NewBuffer(make([]byte, 0, 20<<10)) // 20kb w := multipart.NewWriter(buf) if err := form.Write(w); err != nil { return nil, err } - - options = append(options, WithHeader("Content-Type", w.FormDataContentType())) + // context元数据注入header + md, ok := metadata.FromIncomingContext(ctx) + if ok && len(md) != 0 { + for k, v := range md { + opts = append(opts, WithHeader(k, v...)) + } + } + opts = append(opts, WithHeader("Content-Type", w.FormDataContentType())) // Don't forget to close the multipart writer. // If you don't close it, your request will be missing the terminating boundary. if err := w.Close(); err != nil { return nil, err } - return c.Do(ctx, http.MethodPost, reqURL, buf.Bytes(), options...) + return c.Do(ctx, http.MethodPost, reqURL, buf.Bytes(), opts...) } // NewDefaultClient 生成一个默认的HTTP客户端