背景
在使用Gin框架进行服务开发时,我们遇到了一个日志记录的问题。由于Gin的上下文(*gin.Context
)实现了context.Context
接口,在调用日志记录器的Info
、Warn
、Error
等方法时,直接传递Gin的上下文通常不会导致编译错误。会导致我们在《golang:微服务架构下的日志追踪系统》一文中定义的日志统计维度信息无法正确串联。
为了解决这一问题,我们之前的解决方案是让业务方在使用日志记录器时,将Gin的上下文手动转换为context.Context
。
但这种方案带来了一个明显的弊端:它依赖于开发人员的主动转换,而开发人员往往会忘记进行这一步操作,直接传递Gin的上下文,从而导致日志并未按照预期格式打印。这不仅增加了开发人员的负担,也降低了日志系统的可靠性和准确性。
解决方案
针对上述问题,我们提出了一种更为优雅的解决方案:在日志记录器的方法内部封装Gin上下文到context.Context
的转换逻辑。这样一来,开发人员在使用日志记录器时无需再关心上下文的转换问题,只需直接传递Gin的上下文即可。
接下来,我们将这一转换逻辑集成到日志记录器的各个方法中,确保无论传入的是Gin的上下文还是context.Context
,都能正确记录日志信息。
在封装的方法中实现gin的context的转换即可。
// gin context to context
func GinContextToContext(ctx context.Context) context.Context {
if v, ok := ctx.(*gin.Context); ok {
if v == nil || v.Request == nil {
return context.Background()
}
return v.Request.Context()
}
return ctx
}
完整的逻辑。
package logger
import (
"context"
"errors"
"fmt"
"time"
"{{your module}}}/go-core/utils/constant"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
)
// Logger logger for gorm2
type Logger struct {
log *zap.Logger
logger.Config
customFields []func(ctx context.Context) zap.Field
}
// Option logger/recover option
type Option func(l *Logger)
// WithCustomFields optional custom field
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
return func(l *Logger) {
l.customFields = fields
}
}
// WithConfig optional custom logger.Config
func WithConfig(cfg logger.Config) Option {
return func(l *Logger) {
l.Config = cfg
}
}
// SetGormDBLogger set db logger
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
db.Logger = l
}
// New logger form gorm2
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {
l := &Logger{
log: zapLogger,
Config: logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Warn,
},
}
for _, opt := range opts {
opt(l)
}
return l
}
// NewDefault new default logger
// 初始化一个默认的 logger
func NewDefault(zapLogger *zap.Logger) logger.Interface {
return New(zapLogger, WithCustomFields(
func(ctx context.Context) zap.Field {
v := ctx.Value("Request-Id")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("trace", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("method")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("method", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("path")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("path", vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value("version")
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String("version", vv)
}
return zap.Skip()
}),
WithConfig(logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Info,
}))
}
// 用于支持微服务架构下的链路追踪
func NewTracingLogger(zapLogger *zap.Logger) logger.Interface {
return New(zapLogger, WithCustomFields(
// trace是链路追踪的唯一标识
// span是当前请求的唯一标识
// parent_span是父请求的唯一标识
func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_TRACE)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_TRACE, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_SPAN)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_SPAN, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_PARENT_SPAN)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_PARENT_SPAN, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_METHOD)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_METHOD, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_PATH)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_PATH, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
v := ctx.Value(constant.CONTEXT_KEY_VERSION)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_VERSION, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
// 用于标识调用方服务名
v := ctx.Value(constant.CONTEXT_KEY_CALLER_SERVICE_NAME)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_CALLER_SERVICE_NAME, vv)
}
return zap.Skip()
}, func(ctx context.Context) zap.Field {
// 用于标识调用方ip
v := ctx.Value(constant.CONTEXT_KEY_CALLER_IP)
if v == nil {
return zap.Skip()
}
if vv, ok := v.(string); ok {
return zap.String(constant.CONTEXT_KEY_CALLER_IP, vv)
}
return zap.Skip()
}),
WithConfig(logger.Config{
SlowThreshold: 200 * time.Millisecond,
Colorful: false,
IgnoreRecordNotFoundError: false,
LogLevel: logger.Info,
}))
}
// gin context to context
func GinContextToContext(ctx context.Context) context.Context {
if v, ok := ctx.(*gin.Context); ok {
if v == nil || v.Request == nil {
return context.Background()
}
return v.Request.Context()
}
return ctx
}
// LogMode log mode
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
newLogger := *l
newLogger.LogLevel = level
return &newLogger
}
// Info print info
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Info {
ctx = GinContextToContext(ctx)
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
now := time.Now().UnixMilli()
// 从ctx中获取操作的开始时间
if v := ctx.Value(constant.CONTEXT_KEY_EXECUTE_START_TIME); v != nil {
if vv, ok := v.(int64); ok {
// 计算操作的执行时间,以毫秒为单位
duration := now - vv
// 将操作的执行时间放入ctx
fields = append(fields, zap.Int64(constant.CONTEXT_KEY_EXECUTE_DURATION, duration))
}
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Info(msg, fields...)
}
}
// Warn print warn messages
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Warn {
ctx = GinContextToContext(ctx)
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Warn(msg, fields...)
}
}
// Error print error messages
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {
if l.LogLevel >= logger.Error {
ctx = GinContextToContext(ctx)
//预留10个字段位置
fields := make([]zap.Field, 0, 10+len(l.customFields))
fields = append(fields, zap.String("file", utils.FileWithLineNum()))
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
for _, arg := range args {
if vv, ok := arg.(zapcore.Field); ok {
if len(vv.String) > 0 {
fields = append(fields, zap.String(vv.Key, vv.String))
} else if vv.Integer > 0 {
fields = append(fields, zap.Int64(vv.Key, vv.Integer))
} else {
fields = append(fields, zap.Any(vv.Key, vv.Interface))
}
}
}
l.log.Error(msg, fields...)
}
}
// Trace print sql message
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= logger.Silent {
return
}
ctx = GinContextToContext(ctx)
fields := make([]zap.Field, 0, 6+len(l.customFields))
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Error("", fields...)
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Warn("", fields...)
case l.LogLevel == logger.Info:
for _, customField := range l.customFields {
fields = append(fields, customField(ctx))
}
fields = append(fields,
zap.Error(err),
zap.String("file", utils.FileWithLineNum()),
zap.Duration("latency", elapsed),
)
sql, rows := fc()
if rows == -1 {
fields = append(fields, zap.String("rows", "-"))
} else {
fields = append(fields, zap.Int64("rows", rows))
}
fields = append(fields, zap.String("sql", sql))
l.log.Info("", fields...)
}
}
// Immutable custom immutable field
// Deprecated: use Any instead
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
return Any(key, value)
}
// Any custom immutable any field
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {
field := zap.Any(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// String custom immutable string field
func String(key string, value string) func(ctx context.Context) zap.Field {
field := zap.String(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Int64 custom immutable int64 field
func Int64(key string, value int64) func(ctx context.Context) zap.Field {
field := zap.Int64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Uint64 custom immutable uint64 field
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {
field := zap.Uint64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
// Float64 custom immutable float32 field
func Float64(key string, value float64) func(ctx context.Context) zap.Field {
field := zap.Float64(key, value)
return func(ctx context.Context) zap.Field { return field }
}
测试用例
package logger
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"{{your module}}/go-core/utils/constant"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Mock function to create a gin context
func createGinContext() *gin.Context {
w := httptest.NewRecorder()
// Create a mock request with an attached context
req, _ := http.NewRequest(http.MethodPost, "/api/v1/users", nil)
ctx := context.Background()
// Here you can set values in the context
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_SPAN, "123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PARENT_SPAN, "parent_span_123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_TRACE, "trace_id_123456")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_METHOD, "POST")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_PATH, "/api/v1/users")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_VERSION, "v1.0.0")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_SERVICE_NAME, "user-service")
ctx = context.WithValue(ctx, constant.CONTEXT_KEY_CALLER_IP, "172.0.0.3")
req = req.WithContext(ctx)
// Create the Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
return c
}
// 测试ginContextWithLogger
func TestGinContextWithLogger(t *testing.T) {
// 创建一个 zap logger 实例
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync() // 确保日志被刷新
// 创建一个带有自定义字段和配置的 Logger 实例
customLogger := NewTracingLogger(zapLogger)
// 创建一个 Gin context
c := createGinContext()
// 测试 Info 方法
customLogger.Info(c, "This is an info message")
// 测试 Warn 方法
customLogger.Warn(c, "This is a warning message")
// 测试 Error 方法
customLogger.Error(c, "This is an error message")
// 测试 Trace 方法,模拟一个慢查询
slowQueryBegin := time.Now()
slowQueryFunc := func() (string, int64) {
return "SELECT * FROM users", 100
}
time.Sleep(2 * time.Second) // 模拟一个慢查询
customLogger.Trace(c, slowQueryBegin, slowQueryFunc, nil)
// 测试 Trace 方法,模拟一个错误查询
errorQueryBegin := time.Now()
errorQueryFunc := func() (string, int64) {
return "SELECT * FROM non_existent_table", 0
}
customLogger.Trace(c, errorQueryBegin, errorQueryFunc, fmt.Errorf("table not found"))
// 由于日志是异步的,我们需要在测试结束时等待一段时间以确保所有日志都被输出
time.Sleep(500 * time.Millisecond)
}