From d4521f8c1ee0fe3d4074521639e0061c236fbea0 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 13 May 2026 05:59:27 +0000 Subject: [PATCH] feat: add memory template customization feature --- .../biz/setting/handler/v1/memory_template.go | 147 ++++++++++ backend/biz/setting/register.go | 2 + backend/biz/user/repo/user.go | 7 + backend/biz/user/usecase/user.go | 20 +- backend/db/migrate/schema.go | 1 + backend/db/mutation.go | 75 +++++- backend/db/runtime/runtime.go | 4 +- backend/db/user.go | 16 +- backend/db/user/user.go | 8 + backend/db/user/where.go | 80 ++++++ backend/db/user_create.go | 78 ++++++ backend/db/user_update.go | 52 ++++ backend/domain/user.go | 30 ++- backend/ent/schema/user.go | 1 + .../console/settings/memory-template.tsx | 251 ++++++++++++++++++ .../console/settings/settings-dialog.tsx | 9 +- 16 files changed, 758 insertions(+), 23 deletions(-) create mode 100644 backend/biz/setting/handler/v1/memory_template.go create mode 100644 frontend/src/components/console/settings/memory-template.tsx diff --git a/backend/biz/setting/handler/v1/memory_template.go b/backend/biz/setting/handler/v1/memory_template.go new file mode 100644 index 00000000..75ba9d04 --- /dev/null +++ b/backend/biz/setting/handler/v1/memory_template.go @@ -0,0 +1,147 @@ +package v1 + +import ( + "log/slog" + "net/http" + + "github.com/GoYoko/web" + "github.com/samber/do" + + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/errcode" + "github.com/chaitin/MonkeyCode/backend/middleware" +) + +// MemoryTemplateHandler Memory模板处理器 +type MemoryTemplateHandler struct { + usecase domain.UserUsecase + logger *slog.Logger +} + +// NewMemoryTemplateHandler 创建Memory模板处理器 +func NewMemoryTemplateHandler(i *do.Injector) (*MemoryTemplateHandler, error) { + w := do.MustInvoke[*web.Web](i) + logger := do.MustInvoke[*slog.Logger](i) + usecase := do.MustInvoke[domain.UserUsecase](i) + auth := do.MustInvoke[*middleware.AuthMiddleware](i) + targetActive := do.MustInvoke[*middleware.TargetActiveMiddleware](i) + + h := &MemoryTemplateHandler{ + logger: logger.With("component", "handler.memory-template"), + usecase: usecase, + } + + v1 := w.Group("/api/v1/users/settings") + + v1.Use(auth.Auth(), targetActive.TargetActive()) + v1.GET("/memory-template", web.BindHandler(h.Get)) + v1.PUT("/memory-template", web.BindHandler(h.Update)) + v1.DELETE("/memory-template", web.BindHandler(h.Delete)) + + return h, nil +} + +// GetMemoryTemplateReq 获取Memory模板请求 +type GetMemoryTemplateReq struct{} + +// Get 获取用户Memory模板 +// +// @Summary 获取用户Memory模板 +// @Description 获取当前用户的Memory模板设置 +// @Tags 【用户】Memory模板 +// @Accept json +// @Produce json +// @Security MonkeyCodeAIAuth +// @Success 200 {object} web.Resp{data=string} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/users/settings/memory-template [get] +func (h *MemoryTemplateHandler) Get(c *web.Context, req GetMemoryTemplateReq) error { + user := middleware.GetUser(c) + + u, err := h.usecase.Get(c.Request().Context(), user.ID) + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "failed to get user memory template", "error", err, "user_id", user.ID) + return errcode.ErrDatabaseQuery.Wrap(err) + } + + // 如果用户没有设置模板,返回空字符串 + if u == nil || u.MemoryTemplate == nil { + return c.Success("") + } + + return c.Success(*u.MemoryTemplate) +} + +// UpdateMemoryTemplateReq 更新Memory模板请求 +type UpdateMemoryTemplateReq struct { + MemoryTemplate string `json:"memory_template"` +} + +// Update 更新用户Memory模板 +// +// @Summary 更新用户Memory模板 +// @Description 更新当前用户的Memory模板设置 +// @Tags 【用户】Memory模板 +// @Accept json +// @Produce json +// @Security MonkeyCodeAIAuth +// @Param req body UpdateMemoryTemplateReq true "更新Memory模板请求" +// @Success 200 {object} web.Resp{} "成功" +// @Failure 400 {object} web.Resp "请求参数错误" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/users/settings/memory-template [put] +func (h *MemoryTemplateHandler) Update(c *web.Context, req UpdateMemoryTemplateReq) error { + user := middleware.GetUser(c) + + // 验证模板大小(最大 500KB) + if len(req.MemoryTemplate) > 500*1024 { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "code": 400, + "message": "模板大小超过限制(最大500KB)", + }) + } + + // 更新用户Memory模板 + _, err := h.usecase.Update(c.Request().Context(), user.ID, "", domain.UpdateUserReq{ + MemoryTemplate: &req.MemoryTemplate, + }) + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "failed to update user memory template", "error", err, "user_id", user.ID) + return errcode.ErrDatabaseOperation.Wrap(err) + } + + return c.Success(nil) +} + +// DeleteMemoryTemplateReq 删除Memory模板请求 +type DeleteMemoryTemplateReq struct{} + +// Delete 删除用户Memory模板(恢复默认) +// +// @Summary 删除用户Memory模板 +// @Description 删除当前用户的Memory模板设置,恢复为默认 +// @Tags 【用户】Memory模板 +// @Accept json +// @Produce json +// @Security MonkeyCodeAIAuth +// @Success 200 {object} web.Resp{} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/users/settings/memory-template [delete] +func (h *MemoryTemplateHandler) Delete(c *web.Context, req DeleteMemoryTemplateReq) error { + user := middleware.GetUser(c) + + // 将模板设为空字符串 + emptyTemplate := "" + _, err := h.usecase.Update(c.Request().Context(), user.ID, "", domain.UpdateUserReq{ + MemoryTemplate: &emptyTemplate, + }) + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "failed to delete user memory template", "error", err, "user_id", user.ID) + return errcode.ErrDatabaseOperation.Wrap(err) + } + + return c.Success(nil) +} diff --git a/backend/biz/setting/register.go b/backend/biz/setting/register.go index 5dbf8c26..ad22b031 100644 --- a/backend/biz/setting/register.go +++ b/backend/biz/setting/register.go @@ -20,6 +20,7 @@ func ProvideSetting(i *do.Injector) { do.Provide(i, v1.NewModelHandler) do.Provide(i, v1.NewImageHandler) do.Provide(i, v1.NewMCPHandler) + do.Provide(i, v1.NewMemoryTemplateHandler) } // InvokeSetting 触发 setting 模块的 handler 初始化 @@ -27,4 +28,5 @@ func InvokeSetting(i *do.Injector) { do.MustInvoke[*v1.ModelHandler](i) do.MustInvoke[*v1.ImageHandler](i) do.MustInvoke[*v1.MCPHandler](i) + do.MustInvoke[*v1.MemoryTemplateHandler](i) } diff --git a/backend/biz/user/repo/user.go b/backend/biz/user/repo/user.go index a73a073a..8f91d2fb 100644 --- a/backend/biz/user/repo/user.go +++ b/backend/biz/user/repo/user.go @@ -114,6 +114,13 @@ func (u *userRepo) GetUserByEmail(ctx context.Context, emails []string) ([]*db.U return u.db.User.Query().WithTeams().Where(user.EmailIn(emails...)).All(ctx) } +// UpdateMemoryTemplate implements domain.UserRepo. +func (u *userRepo) UpdateMemoryTemplate(ctx context.Context, uid uuid.UUID, memoryTemplate string) error { + return u.db.User.UpdateOneID(uid). + SetMemoryTemplate(memoryTemplate). + Exec(ctx) +} + // SetEmail implements domain.UserRepo. func (u *userRepo) SetEmail(ctx context.Context, userID uuid.UUID, email string) error { return u.db.User.UpdateOneID(userID).SetEmail(email).Exec(ctx) diff --git a/backend/biz/user/usecase/user.go b/backend/biz/user/usecase/user.go index 7c0c76af..12ffd3e3 100644 --- a/backend/biz/user/usecase/user.go +++ b/backend/biz/user/usecase/user.go @@ -48,10 +48,22 @@ func (u *UserUsecase) Get(ctx context.Context, uid uuid.UUID) (*domain.User, err // Update implements domain.UserUsecase. func (u *UserUsecase) Update(ctx context.Context, uid uuid.UUID, avatarURL string, req domain.UpdateUserReq) (*domain.User, error) { - err := u.repo.Update(ctx, uid, req.Name, avatarURL) - if err != nil { - u.logger.ErrorContext(ctx, "update user failed", "error", err, "user_id", uid) - return nil, err + // 如果有 memory_template,更新它 + if req.MemoryTemplate != nil { + err := u.repo.UpdateMemoryTemplate(ctx, uid, *req.MemoryTemplate) + if err != nil { + u.logger.ErrorContext(ctx, "update memory template failed", "error", err, "user_id", uid) + return nil, err + } + } + + // 更新其他字段 + if req.Name != "" || avatarURL != "" { + err := u.repo.Update(ctx, uid, req.Name, avatarURL) + if err != nil { + u.logger.ErrorContext(ctx, "update user failed", "error", err, "user_id", uid) + return nil, err + } } user, err := u.Get(ctx, uid) diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index c8c31987..9de253cd 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -1202,6 +1202,7 @@ var ( {Name: "status", Type: field.TypeString}, {Name: "is_blocked", Type: field.TypeBool, Default: false}, {Name: "default_configs", Type: field.TypeJSON, Nullable: true}, + {Name: "memory_template", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, } diff --git a/backend/db/mutation.go b/backend/db/mutation.go index 595314d6..01297d57 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -34598,6 +34598,7 @@ type UserMutation struct { status *consts.UserStatus is_blocked *bool default_configs *map[consts.DefaultConfigType]uuid.UUID + memory_template *string created_at *time.Time updated_at *time.Time clearedFields map[string]struct{} @@ -35162,6 +35163,55 @@ func (m *UserMutation) ResetDefaultConfigs() { delete(m.clearedFields, user.FieldDefaultConfigs) } +// SetMemoryTemplate sets the "memory_template" field. +func (m *UserMutation) SetMemoryTemplate(s string) { + m.memory_template = &s +} + +// MemoryTemplate returns the value of the "memory_template" field in the mutation. +func (m *UserMutation) MemoryTemplate() (r string, exists bool) { + v := m.memory_template + if v == nil { + return + } + return *v, true +} + +// OldMemoryTemplate returns the old "memory_template" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldMemoryTemplate(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMemoryTemplate is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMemoryTemplate requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMemoryTemplate: %w", err) + } + return oldValue.MemoryTemplate, nil +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (m *UserMutation) ClearMemoryTemplate() { + m.memory_template = nil + m.clearedFields[user.FieldMemoryTemplate] = struct{}{} +} + +// MemoryTemplateCleared returns if the "memory_template" field was cleared in this mutation. +func (m *UserMutation) MemoryTemplateCleared() bool { + _, ok := m.clearedFields[user.FieldMemoryTemplate] + return ok +} + +// ResetMemoryTemplate resets all changes to the "memory_template" field. +func (m *UserMutation) ResetMemoryTemplate() { + m.memory_template = nil + delete(m.clearedFields, user.FieldMemoryTemplate) +} + // SetCreatedAt sets the "created_at" field. func (m *UserMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -36402,7 +36452,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.deleted_at != nil { fields = append(fields, user.FieldDeletedAt) } @@ -36430,6 +36480,9 @@ func (m *UserMutation) Fields() []string { if m.default_configs != nil { fields = append(fields, user.FieldDefaultConfigs) } + if m.memory_template != nil { + fields = append(fields, user.FieldMemoryTemplate) + } if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -36462,6 +36515,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.IsBlocked() case user.FieldDefaultConfigs: return m.DefaultConfigs() + case user.FieldMemoryTemplate: + return m.MemoryTemplate() case user.FieldCreatedAt: return m.CreatedAt() case user.FieldUpdatedAt: @@ -36493,6 +36548,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldIsBlocked(ctx) case user.FieldDefaultConfigs: return m.OldDefaultConfigs(ctx) + case user.FieldMemoryTemplate: + return m.OldMemoryTemplate(ctx) case user.FieldCreatedAt: return m.OldCreatedAt(ctx) case user.FieldUpdatedAt: @@ -36569,6 +36626,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetDefaultConfigs(v) return nil + case user.FieldMemoryTemplate: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMemoryTemplate(v) + return nil case user.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -36628,6 +36692,9 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldDefaultConfigs) { fields = append(fields, user.FieldDefaultConfigs) } + if m.FieldCleared(user.FieldMemoryTemplate) { + fields = append(fields, user.FieldMemoryTemplate) + } return fields } @@ -36657,6 +36724,9 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldDefaultConfigs: m.ClearDefaultConfigs() return nil + case user.FieldMemoryTemplate: + m.ClearMemoryTemplate() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -36692,6 +36762,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldDefaultConfigs: m.ResetDefaultConfigs() return nil + case user.FieldMemoryTemplate: + m.ResetMemoryTemplate() + return nil case user.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index feba1a3c..28953ce2 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -826,11 +826,11 @@ func init() { // user.DefaultIsBlocked holds the default value on creation for the is_blocked field. user.DefaultIsBlocked = userDescIsBlocked.Default.(bool) // userDescCreatedAt is the schema descriptor for created_at field. - userDescCreatedAt := userFields[9].Descriptor() + userDescCreatedAt := userFields[10].Descriptor() // user.DefaultCreatedAt holds the default value on creation for the created_at field. user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) // userDescUpdatedAt is the schema descriptor for updated_at field. - userDescUpdatedAt := userFields[10].Descriptor() + userDescUpdatedAt := userFields[11].Descriptor() // user.DefaultUpdatedAt holds the default value on creation for the updated_at field. user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time) // user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/db/user.go b/backend/db/user.go index 36dc2187..ef6ab23b 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -38,6 +38,8 @@ type User struct { IsBlocked bool `json:"is_blocked,omitempty"` // DefaultConfigs holds the value of the "default_configs" field. DefaultConfigs map[consts.DefaultConfigType]uuid.UUID `json:"default_configs,omitempty"` + // MemoryTemplate holds the value of the "memory_template" field. + MemoryTemplate *string `json:"memory_template,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -295,7 +297,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case user.FieldIsBlocked: values[i] = new(sql.NullBool) - case user.FieldName, user.FieldEmail, user.FieldAvatarURL, user.FieldPassword, user.FieldRole, user.FieldStatus: + case user.FieldName, user.FieldEmail, user.FieldAvatarURL, user.FieldPassword, user.FieldRole, user.FieldStatus, user.FieldMemoryTemplate: values[i] = new(sql.NullString) case user.FieldDeletedAt, user.FieldCreatedAt, user.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -378,6 +380,13 @@ func (_m *User) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field default_configs: %w", err) } } + case user.FieldMemoryTemplate: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field memory_template", values[i]) + } else if value.Valid { + _m.MemoryTemplate = new(string) + *_m.MemoryTemplate = value.String + } case user.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -558,6 +567,11 @@ func (_m *User) String() string { builder.WriteString("default_configs=") builder.WriteString(fmt.Sprintf("%v", _m.DefaultConfigs)) builder.WriteString(", ") + if v := _m.MemoryTemplate; v != nil { + builder.WriteString("memory_template=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/db/user/user.go b/backend/db/user/user.go index 0af0a2fc..df452d6b 100644 --- a/backend/db/user/user.go +++ b/backend/db/user/user.go @@ -33,6 +33,8 @@ const ( FieldIsBlocked = "is_blocked" // FieldDefaultConfigs holds the string denoting the default_configs field in the database. FieldDefaultConfigs = "default_configs" + // FieldMemoryTemplate holds the string denoting the memory_template field in the database. + FieldMemoryTemplate = "memory_template" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -236,6 +238,7 @@ var Columns = []string{ FieldStatus, FieldIsBlocked, FieldDefaultConfigs, + FieldMemoryTemplate, FieldCreatedAt, FieldUpdatedAt, } @@ -330,6 +333,11 @@ func ByIsBlocked(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIsBlocked, opts...).ToFunc() } +// ByMemoryTemplate orders the results by the memory_template field. +func ByMemoryTemplate(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMemoryTemplate, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/db/user/where.go b/backend/db/user/where.go index 984ece12..d20b3cfd 100644 --- a/backend/db/user/where.go +++ b/backend/db/user/where.go @@ -99,6 +99,11 @@ func IsBlocked(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldIsBlocked, v)) } +// MemoryTemplate applies equality check predicate on the "memory_template" field. It's identical to MemoryTemplateEQ. +func MemoryTemplate(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldMemoryTemplate, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -637,6 +642,81 @@ func DefaultConfigsNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldDefaultConfigs)) } +// MemoryTemplateEQ applies the EQ predicate on the "memory_template" field. +func MemoryTemplateEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldMemoryTemplate, v)) +} + +// MemoryTemplateNEQ applies the NEQ predicate on the "memory_template" field. +func MemoryTemplateNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldMemoryTemplate, v)) +} + +// MemoryTemplateIn applies the In predicate on the "memory_template" field. +func MemoryTemplateIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldMemoryTemplate, vs...)) +} + +// MemoryTemplateNotIn applies the NotIn predicate on the "memory_template" field. +func MemoryTemplateNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldMemoryTemplate, vs...)) +} + +// MemoryTemplateGT applies the GT predicate on the "memory_template" field. +func MemoryTemplateGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldMemoryTemplate, v)) +} + +// MemoryTemplateGTE applies the GTE predicate on the "memory_template" field. +func MemoryTemplateGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldMemoryTemplate, v)) +} + +// MemoryTemplateLT applies the LT predicate on the "memory_template" field. +func MemoryTemplateLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldMemoryTemplate, v)) +} + +// MemoryTemplateLTE applies the LTE predicate on the "memory_template" field. +func MemoryTemplateLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldMemoryTemplate, v)) +} + +// MemoryTemplateContains applies the Contains predicate on the "memory_template" field. +func MemoryTemplateContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldMemoryTemplate, v)) +} + +// MemoryTemplateHasPrefix applies the HasPrefix predicate on the "memory_template" field. +func MemoryTemplateHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldMemoryTemplate, v)) +} + +// MemoryTemplateHasSuffix applies the HasSuffix predicate on the "memory_template" field. +func MemoryTemplateHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldMemoryTemplate, v)) +} + +// MemoryTemplateIsNil applies the IsNil predicate on the "memory_template" field. +func MemoryTemplateIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldMemoryTemplate)) +} + +// MemoryTemplateNotNil applies the NotNil predicate on the "memory_template" field. +func MemoryTemplateNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldMemoryTemplate)) +} + +// MemoryTemplateEqualFold applies the EqualFold predicate on the "memory_template" field. +func MemoryTemplateEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldMemoryTemplate, v)) +} + +// MemoryTemplateContainsFold applies the ContainsFold predicate on the "memory_template" field. +func MemoryTemplateContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldMemoryTemplate, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/db/user_create.go b/backend/db/user_create.go index f4a502c9..70c561e6 100644 --- a/backend/db/user_create.go +++ b/backend/db/user_create.go @@ -139,6 +139,20 @@ func (_c *UserCreate) SetDefaultConfigs(v map[consts.DefaultConfigType]uuid.UUID return _c } +// SetMemoryTemplate sets the "memory_template" field. +func (_c *UserCreate) SetMemoryTemplate(v string) *UserCreate { + _c.mutation.SetMemoryTemplate(v) + return _c +} + +// SetNillableMemoryTemplate sets the "memory_template" field if the given value is not nil. +func (_c *UserCreate) SetNillableMemoryTemplate(v *string) *UserCreate { + if v != nil { + _c.SetMemoryTemplate(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *UserCreate) SetCreatedAt(v time.Time) *UserCreate { _c.mutation.SetCreatedAt(v) @@ -643,6 +657,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldDefaultConfigs, field.TypeJSON, value) _node.DefaultConfigs = value } + if value, ok := _c.mutation.MemoryTemplate(); ok { + _spec.SetField(user.FieldMemoryTemplate, field.TypeString, value) + _node.MemoryTemplate = &value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1192,6 +1210,24 @@ func (u *UserUpsert) ClearDefaultConfigs() *UserUpsert { return u } +// SetMemoryTemplate sets the "memory_template" field. +func (u *UserUpsert) SetMemoryTemplate(v string) *UserUpsert { + u.Set(user.FieldMemoryTemplate, v) + return u +} + +// UpdateMemoryTemplate sets the "memory_template" field to the value that was provided on create. +func (u *UserUpsert) UpdateMemoryTemplate() *UserUpsert { + u.SetExcluded(user.FieldMemoryTemplate) + return u +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (u *UserUpsert) ClearMemoryTemplate() *UserUpsert { + u.SetNull(user.FieldMemoryTemplate) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *UserUpsert) SetCreatedAt(v time.Time) *UserUpsert { u.Set(user.FieldCreatedAt, v) @@ -1425,6 +1461,27 @@ func (u *UserUpsertOne) ClearDefaultConfigs() *UserUpsertOne { }) } +// SetMemoryTemplate sets the "memory_template" field. +func (u *UserUpsertOne) SetMemoryTemplate(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetMemoryTemplate(v) + }) +} + +// UpdateMemoryTemplate sets the "memory_template" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateMemoryTemplate() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateMemoryTemplate() + }) +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (u *UserUpsertOne) ClearMemoryTemplate() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearMemoryTemplate() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *UserUpsertOne) SetCreatedAt(v time.Time) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1829,6 +1886,27 @@ func (u *UserUpsertBulk) ClearDefaultConfigs() *UserUpsertBulk { }) } +// SetMemoryTemplate sets the "memory_template" field. +func (u *UserUpsertBulk) SetMemoryTemplate(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetMemoryTemplate(v) + }) +} + +// UpdateMemoryTemplate sets the "memory_template" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateMemoryTemplate() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateMemoryTemplate() + }) +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (u *UserUpsertBulk) ClearMemoryTemplate() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearMemoryTemplate() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *UserUpsertBulk) SetCreatedAt(v time.Time) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/db/user_update.go b/backend/db/user_update.go index 3cc5121f..ec869bc4 100644 --- a/backend/db/user_update.go +++ b/backend/db/user_update.go @@ -199,6 +199,26 @@ func (_u *UserUpdate) ClearDefaultConfigs() *UserUpdate { return _u } +// SetMemoryTemplate sets the "memory_template" field. +func (_u *UserUpdate) SetMemoryTemplate(v string) *UserUpdate { + _u.mutation.SetMemoryTemplate(v) + return _u +} + +// SetNillableMemoryTemplate sets the "memory_template" field if the given value is not nil. +func (_u *UserUpdate) SetNillableMemoryTemplate(v *string) *UserUpdate { + if v != nil { + _u.SetMemoryTemplate(*v) + } + return _u +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (_u *UserUpdate) ClearMemoryTemplate() *UserUpdate { + _u.mutation.ClearMemoryTemplate() + return _u +} + // SetCreatedAt sets the "created_at" field. func (_u *UserUpdate) SetCreatedAt(v time.Time) *UserUpdate { _u.mutation.SetCreatedAt(v) @@ -1092,6 +1112,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.DefaultConfigsCleared() { _spec.ClearField(user.FieldDefaultConfigs, field.TypeJSON) } + if value, ok := _u.mutation.MemoryTemplate(); ok { + _spec.SetField(user.FieldMemoryTemplate, field.TypeString, value) + } + if _u.mutation.MemoryTemplateCleared() { + _spec.ClearField(user.FieldMemoryTemplate, field.TypeString) + } if value, ok := _u.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) } @@ -2258,6 +2284,26 @@ func (_u *UserUpdateOne) ClearDefaultConfigs() *UserUpdateOne { return _u } +// SetMemoryTemplate sets the "memory_template" field. +func (_u *UserUpdateOne) SetMemoryTemplate(v string) *UserUpdateOne { + _u.mutation.SetMemoryTemplate(v) + return _u +} + +// SetNillableMemoryTemplate sets the "memory_template" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableMemoryTemplate(v *string) *UserUpdateOne { + if v != nil { + _u.SetMemoryTemplate(*v) + } + return _u +} + +// ClearMemoryTemplate clears the value of the "memory_template" field. +func (_u *UserUpdateOne) ClearMemoryTemplate() *UserUpdateOne { + _u.mutation.ClearMemoryTemplate() + return _u +} + // SetCreatedAt sets the "created_at" field. func (_u *UserUpdateOne) SetCreatedAt(v time.Time) *UserUpdateOne { _u.mutation.SetCreatedAt(v) @@ -3181,6 +3227,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.DefaultConfigsCleared() { _spec.ClearField(user.FieldDefaultConfigs, field.TypeJSON) } + if value, ok := _u.mutation.MemoryTemplate(); ok { + _spec.SetField(user.FieldMemoryTemplate, field.TypeString, value) + } + if _u.mutation.MemoryTemplateCleared() { + _spec.ClearField(user.FieldMemoryTemplate, field.TypeString) + } if value, ok := _u.mutation.CreatedAt(); ok { _spec.SetField(user.FieldCreatedAt, field.TypeTime, value) } diff --git a/backend/domain/user.go b/backend/domain/user.go index 15df2e52..92a7811a 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -28,6 +28,7 @@ type UserUsecase interface { type UserRepo interface { Get(ctx context.Context, uid uuid.UUID) (*db.User, error) Update(ctx context.Context, uid uuid.UUID, name, avatarURL string) error + UpdateMemoryTemplate(ctx context.Context, uid uuid.UUID, memoryTemplate string) error GetUserWithTeams(ctx context.Context, uid uuid.UUID) (*db.User, error) PasswordLogin(ctx context.Context, req *TeamLoginReq) (*db.User, error) ChangePassword(ctx context.Context, uid uuid.UUID, currentPassword, newPassword string, isReset bool) error @@ -43,17 +44,18 @@ type UserActiveRepo interface { } type User struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - AvatarURL string `json:"avatar_url"` - Email string `json:"email"` - Role consts.UserRole `json:"role"` - Status consts.UserStatus `json:"status"` - IsBlocked bool `json:"is_blocked"` - Token string `json:"token,omitempty"` - Identities []*UserIdentity `json:"identities"` - Team *Team `json:"team,omitempty"` - HasPassword bool `json:"has_password"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` + Email string `json:"email"` + Role consts.UserRole `json:"role"` + Status consts.UserStatus `json:"status"` + IsBlocked bool `json:"is_blocked"` + Token string `json:"token,omitempty"` + Identities []*UserIdentity `json:"identities"` + Team *Team `json:"team,omitempty"` + HasPassword bool `json:"has_password"` + MemoryTemplate *string `json:"memory_template,omitempty"` } func (u *User) From(src *db.User) *User { @@ -69,6 +71,7 @@ func (u *User) From(src *db.User) *User { u.Status = src.Status u.IsBlocked = src.IsBlocked u.HasPassword = src.Password != "" + u.MemoryTemplate = src.MemoryTemplate u.Identities = cvt.Iter(src.Edges.Identities, func(_ int, i *db.UserIdentity) *UserIdentity { return cvt.From(i, &UserIdentity{}) }) @@ -127,8 +130,9 @@ type TeamUserLoginResp struct { // UpdateUserReq 更新用户信息请求 type UpdateUserReq struct { - Name string `json:"name,omitempty" form:"name"` - AvatarURL string `json:"avatar_url,omitempty" form:"avatar_url"` + Name string `json:"name,omitempty" form:"name"` + AvatarURL string `json:"avatar_url,omitempty" form:"avatar_url"` + MemoryTemplate *string `json:"memory_template,omitempty" form:"memory_template"` } // UpdateUserResp 更新用户信息响应 diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index be079778..65cbb976 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -43,6 +43,7 @@ func (User) Fields() []ent.Field { field.String("status").GoType(consts.UserStatus("")), field.Bool("is_blocked").Default(false), field.JSON("default_configs", map[consts.DefaultConfigType]uuid.UUID{}).Optional(), + field.String("memory_template").Optional().Nillable(), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), } diff --git a/frontend/src/components/console/settings/memory-template.tsx b/frontend/src/components/console/settings/memory-template.tsx new file mode 100644 index 00000000..3761634d --- /dev/null +++ b/frontend/src/components/console/settings/memory-template.tsx @@ -0,0 +1,251 @@ +import { useState, useEffect } from "react" +import { FileText, Save, RotateCcw, AlertTriangle } from "lucide-react" +import { Button } from "@/components/ui/button" +import { Textarea } from "@/components/ui/textarea" +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog" +import { toast } from "sonner" + +const DEFAULT_TEMPLATE = `# 用户指令记忆 + +本文件记录了用户的指令、偏好和教导,用于在未来的交互中提供参考。 + +## 格式 + +### 用户指令条目 +用户指令条目应遵循以下格式: + +[用户指令摘要] +- Date: {{date}} +- Context: [提及的场景或时间] +- Instructions: + - [用户教导或指示的内容,逐行描述] + +### 项目知识条目 +Agent 在任务执行过程中发现的条目应遵循以下格式: + +[项目知识摘要] +- Date: {{date}} +- Context: Agent 在执行 [具体任务描述] 时发现 +- Category: [代码结构|代码模式|代码生成|构建方法|测试方法|依赖关系|环境配置] +- Instructions: + - [具体的知识点,逐行描述] + +## 去重策略 +- 添加新条目前,检查是否存在相似或相同的指令 +- 若发现重复,跳过新条目或与已有条目合并 +- 合并时,更新上下文或日期信息 +- 这有助于避免冗余条目,保持记忆文件整洁 + +## 条目 + +[按上述格式记录的记忆条目] +` + +const VARIABLES = [ + { name: "date", desc: "当前日期" }, + { name: "datetime", desc: "当前日期时间" }, + { name: "project_name", desc: "项目名称" }, + { name: "user_name", desc: "用户名" }, + { name: "workspace_path", desc: "工作空间路径" }, +] + +export default function MemoryTemplate() { + const [template, setTemplate] = useState("") + const [originalTemplate, setOriginalTemplate] = useState("") + const [loading, setLoading] = useState(false) + const [showResetConfirm, setShowResetConfirm] = useState(false) + const [hasCustomTemplate, setHasCustomTemplate] = useState(false) + + useEffect(() => { + fetchTemplate() + }, []) + + const fetchTemplate = async () => { + try { + const response = await fetch("/api/v1/users/settings/memory-template", { + headers: { + "Content-Type": "application/json", + }, + credentials: "include", + }) + + if (response.ok) { + const data = await response.json() + if (data.data) { + setTemplate(data.data) + setOriginalTemplate(data.data) + setHasCustomTemplate(true) + } else { + setTemplate(DEFAULT_TEMPLATE) + setOriginalTemplate(DEFAULT_TEMPLATE) + setHasCustomTemplate(false) + } + } else { + setTemplate(DEFAULT_TEMPLATE) + setOriginalTemplate(DEFAULT_TEMPLATE) + } + } catch (error) { + console.error("Error fetching template:", error) + setTemplate(DEFAULT_TEMPLATE) + setOriginalTemplate(DEFAULT_TEMPLATE) + } + } + + const handleSave = async () => { + setLoading(true) + try { + const response = await fetch("/api/v1/users/settings/memory-template", { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + credentials: "include", + body: JSON.stringify({ memory_template: template }), + }) + + if (response.ok) { + setOriginalTemplate(template) + setHasCustomTemplate(true) + toast.success("模板保存成功") + } else { + const error = await response.json() + toast.error(error.message || "保存失败") + } + } catch (error) { + console.error("Error saving template:", error) + toast.error("保存失败,请重试") + } finally { + setLoading(false) + } + } + + const handleReset = async () => { + setLoading(true) + try { + const response = await fetch("/api/v1/users/settings/memory-template", { + method: "DELETE", + credentials: "include", + }) + + if (response.ok) { + setTemplate(DEFAULT_TEMPLATE) + setOriginalTemplate(DEFAULT_TEMPLATE) + setHasCustomTemplate(false) + toast.success("已恢复为默认模板") + } else { + toast.error("恢复失败") + } + } catch (error) { + console.error("Error resetting template:", error) + toast.error("恢复失败") + } finally { + setLoading(false) + setShowResetConfirm(false) + } + } + + const hasChanges = template !== originalTemplate + + return ( +
+
+

Memory 模板

+

+ 自定义 MEMORY.md 初始化模板,新项目创建时将使用此模板 +

+
+ +
+
+
+ +

+ 支持以下变量: + {VARIABLES.map((v) => ( + + {`{{${v.name}}}`} + + ))} +

+
+
+ {hasCustomTemplate && ( + + + 已自定义 + + )} +
+
+ +
+