From 3d0c5d47a78a88879f04859ccb1834bf76a38883 Mon Sep 17 00:00:00 2001 From: Yumechi Date: Thu, 16 Apr 2026 16:19:40 +0800 Subject: [PATCH 1/2] enhance(security): add body size and pluginConf/image entity size limits --- api/application.go | 13 ++++++++++++ api/application_test.go | 44 +++++++++++++++++++++++++++++++++++++++-- api/limits.go | 8 ++++++++ api/plugin.go | 8 +++++++- api/plugin_test.go | 28 ++++++++++++++++++++++++++ router/router.go | 5 +++++ 6 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 api/limits.go diff --git a/api/application.go b/api/application.go index 2527842d..680153d5 100644 --- a/api/application.go +++ b/api/application.go @@ -3,6 +3,7 @@ package api import ( "errors" "fmt" + "log" "net/http" "os" "path/filepath" @@ -329,6 +330,18 @@ func (a *ApplicationAPI) UploadApplicationImage(ctx *gin.Context) { return } if app != nil && app.UserID == auth.GetUserID(ctx) { + // https://gin-gonic.com/en/docs/routing/upload-file/limit-bytes/ + ctx.Request.Body = http.MaxBytesReader(ctx.Writer, ctx.Request.Body, MaxUploadSize) + if err := ctx.Request.ParseMultipartForm(MaxUploadSize); err != nil { + log.Println("error parsing multipart form", err) + if _, ok := err.(*http.MaxBytesError); ok { + ctx.AbortWithError(http.StatusRequestEntityTooLarge, fmt.Errorf("file too large (max: %d bytes)", MaxUploadSize)) + return + } + ctx.AbortWithError(http.StatusBadRequest, err) + return + } + file, err := ctx.FormFile("file") if err == http.ErrMissingFile { ctx.AbortWithError(400, errors.New("file with key 'file' must be present")) diff --git a/api/application_test.go b/api/application_test.go index 4070d0bf..4fac5cde 100644 --- a/api/application_test.go +++ b/api/application_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "mime/multipart" + "net/http" "net/http/httptest" "os" "strings" @@ -336,6 +337,46 @@ func (s *ApplicationSuite) Test_UploadAppImage_NoImageProvided_expectBadRequest( assert.Equal(s.T(), s.ctx.Errors[0].Err, errors.New("file with key 'file' must be present")) } +func (s *ApplicationSuite) Test_UploadAppImage_FileTooLarge_expectBadRequest() { + s.db.User(5).App(1) + testImageFileData, err := os.ReadFile("../test/assets/image.png") + assert.Nil(s.T(), err) + + tempFile, err := os.CreateTemp("", "test-image-*.png") + assert.Nil(s.T(), err) + defer os.Remove(tempFile.Name()) + totalSize := 0 + for totalSize <= MaxUploadSize { + _, err := tempFile.Write(testImageFileData) + assert.Nil(s.T(), err) + totalSize += len(testImageFileData) + } + _, err = tempFile.Seek(0, io.SeekStart) + assert.Nil(s.T(), err) + + cType, buffer, err := upload(map[string]*os.File{"file": tempFile}) + assert.Nil(s.T(), err) + s.ctx.Request = httptest.NewRequest("POST", "/irrelevant", &buffer) + s.ctx.Request.Header.Set("Content-Type", cType) + test.WithUser(s.ctx, 5) + s.ctx.Params = gin.Params{{Key: "id", Value: "1"}} + + s.a.UploadApplicationImage(s.ctx) + + if app, err := s.db.GetApplicationByID(1); assert.NoError(s.T(), err) { + imgName := app.Image + + assert.Equal(s.T(), http.StatusRequestEntityTooLarge, s.recorder.Code) + _, err = os.Stat(imgName) + assert.True(s.T(), os.IsNotExist(err)) + + s.a.DeleteApplication(s.ctx) + + _, err = os.Stat(imgName) + assert.True(s.T(), os.IsNotExist(err)) + } +} + func (s *ApplicationSuite) Test_UploadAppImage_OtherErrors_expectServerError() { s.db.User(5).App(1) var b bytes.Buffer @@ -349,8 +390,7 @@ func (s *ApplicationSuite) Test_UploadAppImage_OtherErrors_expectServerError() { s.a.UploadApplicationImage(s.ctx) - assert.Equal(s.T(), 500, s.recorder.Code) - assert.Error(s.T(), s.ctx.Errors[0].Err, "multipart: NextPart: EOF") + assert.Equal(s.T(), 400, s.recorder.Code) } func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_expectSuccess() { diff --git a/api/limits.go b/api/limits.go new file mode 100644 index 00000000..cf6e00a4 --- /dev/null +++ b/api/limits.go @@ -0,0 +1,8 @@ +package api + +const ( + // Maximum upload size of 32MB for blobs and files. + MaxUploadSize = 32 << 20 + // Catch-all request body limit of 64MB enforced at middleware level. + MaxBodySize = 64 << 20 +) diff --git a/api/plugin.go b/api/plugin.go index fcaeeb4c..f62f0634 100644 --- a/api/plugin.go +++ b/api/plugin.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/gin-gonic/gin" "github.com/gotify/location" @@ -385,8 +386,13 @@ func (c *PluginAPI) UpdateConfig(ctx *gin.Context) { } newConf := instance.DefaultConfig() - newconfBytes, err := io.ReadAll(ctx.Request.Body) + bodyReader := http.MaxBytesReader(ctx.Writer, ctx.Request.Body, MaxUploadSize) + newconfBytes, err := io.ReadAll(bodyReader) if err != nil { + if _, ok := err.(*http.MaxBytesError); ok { + ctx.AbortWithError(http.StatusRequestEntityTooLarge, errors.New("file too large")) + return + } ctx.AbortWithError(500, err) return } diff --git a/api/plugin_test.go b/api/plugin_test.go index f38e7b20..f8318247 100644 --- a/api/plugin_test.go +++ b/api/plugin_test.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -494,6 +496,32 @@ func (s *PluginSuite) Test_UpdateConfig() { } } +func (s *PluginSuite) Test_UpdateConfig_tooBig_expect413() { + conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) + assert.NoError(s.T(), err) + inst, err := s.manager.Instance(conf.ID) + assert.Nil(s.T(), err) + mockInst := inst.(*mock.PluginInstance) + + newConfig := &mock.PluginConfig{ + TestKey: strings.Repeat("a", MaxUploadSize+1), + } + newConfigYAML, err := yaml.Marshal(newConfig) + assert.Nil(s.T(), err) + + { + test.WithUser(s.ctx, 1) + + s.ctx.Request = httptest.NewRequest("POST", fmt.Sprintf("/plugin/%d/config", conf.ID), bytes.NewReader(newConfigYAML)) + s.ctx.Header("Content-Type", "application/x-yaml") + s.ctx.Params = gin.Params{{Key: "id", Value: fmt.Sprint(conf.ID)}} + s.a.UpdateConfig(s.ctx) + + assert.Equal(s.T(), http.StatusRequestEntityTooLarge, s.recorder.Code) + assert.NotEqual(s.T(), newConfig, mockInst.Config, "config should not be received by plugin") + } +} + func (s *PluginSuite) Test_UpdateConfig_invalidConfig_expect400() { conf, err := s.db.GetPluginConfByUserAndPath(1, mock.ModulePath) assert.NoError(s.T(), err) diff --git a/router/router.go b/router/router.go index f1be10ba..ad142883 100644 --- a/router/router.go +++ b/router/router.go @@ -66,6 +66,11 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co } streamHandler := stream.New( time.Duration(conf.Server.Stream.PingPeriodSeconds)*time.Second, 15*time.Second, conf.Server.Stream.AllowedOrigins) + // a global catch-all truncation for all requests with 1M headroom + g.Use(func(ctx *gin.Context) { + ctx.Request.Body = http.MaxBytesReader(ctx.Writer, ctx.Request.Body, api.MaxBodySize) + ctx.Next() + }) go func() { ticker := time.NewTicker(5 * time.Minute) for range ticker.C { From f1b6e4f16a1063b322dac3e4a452111096136127 Mon Sep 17 00:00:00 2001 From: Yumechi Date: Sat, 18 Apr 2026 14:18:25 +0800 Subject: [PATCH 2/2] enhance(security): add quota for clients and applications --- api/application.go | 30 +++++++++++++++++++++++++---- api/application_test.go | 31 ++++++++++++++++++++++++++---- api/client.go | 31 ++++++++++++++++++++++++++++-- api/client_test.go | 37 +++++++++++++++++++++++++++++++++++- api/limits.go | 2 ++ api/oidc.go | 2 +- api/session.go | 4 ++-- config.example.yml | 3 +++ config/config.go | 4 ++++ database/application.go | 13 ++++++++++++- database/application_test.go | 6 +++--- database/client.go | 35 ++++++++++++++++++++++++++++++++-- database/client_test.go | 2 +- database/database_test.go | 8 ++++---- database/errors.go | 5 +++++ database/message_test.go | 8 ++++---- database/migration_test.go | 2 +- database/user_test.go | 8 ++++---- docs/spec.json | 18 ++++++++++++++++++ plugin/manager.go | 4 ++-- plugin/manager_test.go | 6 +++--- router/router.go | 2 ++ test/testdb/database.go | 8 ++++---- 23 files changed, 226 insertions(+), 43 deletions(-) create mode 100644 database/errors.go diff --git a/api/application.go b/api/application.go index 680153d5..4f0650e8 100644 --- a/api/application.go +++ b/api/application.go @@ -3,7 +3,6 @@ package api import ( "errors" "fmt" - "log" "net/http" "os" "path/filepath" @@ -11,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gotify/server/v2/auth" + "github.com/gotify/server/v2/database" "github.com/gotify/server/v2/model" "github.com/h2non/filetype" "gorm.io/gorm" @@ -18,7 +18,7 @@ import ( // The ApplicationDatabase interface for encapsulating database access. type ApplicationDatabase interface { - CreateApplication(application *model.Application) error + CreateApplication(application *model.Application, quota uint32) error GetApplicationByToken(token string) (*model.Application, error) GetApplicationByID(id uint) (*model.Application, error) GetApplicationsByUser(userID uint) ([]*model.Application, error) @@ -30,6 +30,7 @@ type ApplicationDatabase interface { type ApplicationAPI struct { DB ApplicationDatabase ImageDir string + Quota uint32 } // Application Params Model @@ -57,6 +58,10 @@ type ApplicationParams struct { SortKey string `form:"sortKey" query:"sortKey" json:"sortKey"` } +func (p *ApplicationParams) EffectiveSize() uint64 { + return uint64(len(p.Name)) + uint64(len(p.Description)) + uint64(len(p.SortKey)) +} + // CreateApplication creates an application and returns the access token. // swagger:operation POST /application application createApp // @@ -90,9 +95,17 @@ type ApplicationParams struct { // description: Forbidden // schema: // $ref: "#/definitions/Error" +// 422: +// description: Unprocessable Entity +// schema: +// $ref: "#/definitions/Error" func (a *ApplicationAPI) CreateApplication(ctx *gin.Context) { applicationParams := ApplicationParams{} if err := ctx.Bind(&applicationParams); err == nil { + if applicationParams.EffectiveSize() > MaxApplicationClientEntrySize { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("application entry too large (max: %d bytes)", MaxApplicationClientEntrySize)) + return + } app := model.Application{ Name: applicationParams.Name, Description: applicationParams.Description, @@ -103,7 +116,7 @@ func (a *ApplicationAPI) CreateApplication(ctx *gin.Context) { Internal: false, } - if err := a.DB.CreateApplication(&app); err != nil { + if err := a.DB.CreateApplication(&app, a.Quota); err != nil { handleApplicationError(ctx, err) return } @@ -248,6 +261,10 @@ func (a *ApplicationAPI) DeleteApplication(ctx *gin.Context) { // description: Not Found // schema: // $ref: "#/definitions/Error" +// 422: +// description: Unprocessable Entity +// schema: +// $ref: "#/definitions/Error" func (a *ApplicationAPI) UpdateApplication(ctx *gin.Context) { withID(ctx, "id", func(id uint) { app, err := a.DB.GetApplicationByID(id) @@ -257,6 +274,10 @@ func (a *ApplicationAPI) UpdateApplication(ctx *gin.Context) { if app != nil && app.UserID == auth.GetUserID(ctx) { applicationParams := ApplicationParams{} if err := ctx.Bind(&applicationParams); err == nil { + if applicationParams.EffectiveSize() > MaxApplicationClientEntrySize { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("application entry too large (max: %d bytes)", MaxApplicationClientEntrySize)) + return + } app.Description = applicationParams.Description app.Name = applicationParams.Name app.DefaultPriority = applicationParams.DefaultPriority @@ -333,7 +354,6 @@ func (a *ApplicationAPI) UploadApplicationImage(ctx *gin.Context) { // https://gin-gonic.com/en/docs/routing/upload-file/limit-bytes/ ctx.Request.Body = http.MaxBytesReader(ctx.Writer, ctx.Request.Body, MaxUploadSize) if err := ctx.Request.ParseMultipartForm(MaxUploadSize); err != nil { - log.Println("error parsing multipart form", err) if _, ok := err.(*http.MaxBytesError); ok { ctx.AbortWithError(http.StatusRequestEntityTooLarge, fmt.Errorf("file too large (max: %d bytes)", MaxUploadSize)) return @@ -496,6 +516,8 @@ func ValidApplicationImageExt(ext string) bool { func handleApplicationError(ctx *gin.Context, err error) { if errors.Is(err, gorm.ErrDuplicatedKey) { ctx.AbortWithError(400, errors.New("sort key is not unique")) + } else if errors.Is(err, database.ErrQuotaExceeded) { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("quota exceeded")) } else { ctx.AbortWithError(500, err) } diff --git a/api/application_test.go b/api/application_test.go index 4fac5cde..ef40bbb7 100644 --- a/api/application_test.go +++ b/api/application_test.go @@ -56,7 +56,7 @@ func (s *ApplicationSuite) BeforeTest(suiteName, testName string) { s.db = testdb.NewDB(s.T()) s.ctx, _ = gin.CreateTestContext(s.recorder) withURL(s.ctx, "http", "example.com") - s.a = &ApplicationAPI{DB: s.db} + s.a = &ApplicationAPI{DB: s.db, Quota: 128} } func (s *ApplicationSuite) AfterTest(suiteName, testName string) { @@ -173,6 +173,29 @@ func (s *ApplicationSuite) Test_CreateApplication_onlyRequiredParameters() { } } +func (s *ApplicationSuite) Test_CreateApplication_tooBigApplicationEntry_expectBadRequest() { + s.db.User(5) + + test.WithUser(s.ctx, 5) + s.withFormData(fmt.Sprintf("name=%s", strings.Repeat("a", MaxApplicationClientEntrySize+1))) + s.a.CreateApplication(s.ctx) + + assert.Equal(s.T(), http.StatusUnprocessableEntity, s.recorder.Code) +} + +func (s *ApplicationSuite) Test_CreateApplication_quotaExceeded_expectBadRequest() { + user := s.db.User(5) + for i := uint(0); i < uint(s.a.Quota); i++ { + user.AppWithToken(100+i, fmt.Sprintf("app%d", i)) + } + + test.WithUser(s.ctx, 5) + s.withFormData("name=custom_name") + s.a.CreateApplication(s.ctx) + + assert.Equal(s.T(), http.StatusUnprocessableEntity, s.recorder.Code) +} + func (s *ApplicationSuite) Test_CreateApplication_returnsApplicationWithID() { s.db.User(5) @@ -424,7 +447,7 @@ func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_DeleteExstingImageA firstGeneratedImageName := firstApplicationToken[1:] + ".png" secondGeneratedImageName := secondApplicationToken[1:] + ".png" s.db.User(5) - s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: existingImageName}) + s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: existingImageName}, 0) cType, buffer, err := upload(map[string]*os.File{"file": mustOpen("../test/assets/image.png")}) assert.Nil(s.T(), err) @@ -450,7 +473,7 @@ func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_DeleteExstingImageA func (s *ApplicationSuite) Test_UploadAppImage_WithImageFile_DeleteExistingImage() { s.db.User(5) - s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: "existing.png"}) + s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: "existing.png"}, 0) fakeImage(s.T(), "existing.png") cType, buffer, err := upload(map[string]*os.File{"file": mustOpen("../test/assets/image.png")}) @@ -541,7 +564,7 @@ func (s *ApplicationSuite) Test_RemoveAppImage_expectSuccess() { s.db.User(5) imageFile := "existing.png" - s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: imageFile}) + s.db.CreateApplication(&model.Application{UserID: 5, ID: 1, Image: imageFile}, 0) fakeImage(s.T(), imageFile) test.WithUser(s.ctx, 5) diff --git a/api/client.go b/api/client.go index 865d74dc..b05a7787 100644 --- a/api/client.go +++ b/api/client.go @@ -1,16 +1,19 @@ package api import ( + "errors" "fmt" + "net/http" "github.com/gin-gonic/gin" "github.com/gotify/server/v2/auth" + "github.com/gotify/server/v2/database" "github.com/gotify/server/v2/model" ) // The ClientDatabase interface for encapsulating database access. type ClientDatabase interface { - CreateClient(client *model.Client) error + CreateClient(client *model.Client, quota uint32) error GetClientByToken(token string) (*model.Client, error) GetClientByID(id uint) (*model.Client, error) GetClientsByUser(userID uint) ([]*model.Client, error) @@ -23,6 +26,7 @@ type ClientAPI struct { DB ClientDatabase ImageDir string NotifyDeleted func(uint, string) + Quota uint32 } // Client Params Model @@ -38,6 +42,10 @@ type ClientParams struct { Name string `form:"name" query:"name" json:"name" binding:"required"` } +func (p *ClientParams) EffectiveSize() uint64 { + return uint64(len(p.Name)) +} + // UpdateClient updates a client by its id. // swagger:operation PUT /client/{id} client updateClient // @@ -90,6 +98,10 @@ func (a *ClientAPI) UpdateClient(ctx *gin.Context) { if client != nil && client.UserID == auth.GetUserID(ctx) { newValues := ClientParams{} if err := ctx.Bind(&newValues); err == nil { + if newValues.EffectiveSize() > MaxApplicationClientEntrySize { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("client entry too large (max: %d bytes)", MaxApplicationClientEntrySize)) + return + } client.Name = newValues.Name if success := successOrAbort(ctx, 500, a.DB.UpdateClient(client)); !success { @@ -136,18 +148,33 @@ func (a *ClientAPI) UpdateClient(ctx *gin.Context) { // description: Forbidden // schema: // $ref: "#/definitions/Error" +// 422: +// description: Unprocessable Entity +// schema: +// $ref: "#/definitions/Error" func (a *ClientAPI) CreateClient(ctx *gin.Context) { clientParams := ClientParams{} if err := ctx.Bind(&clientParams); err == nil { + if clientParams.EffectiveSize() > MaxApplicationClientEntrySize { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("client entry too large (max: %d bytes)", MaxApplicationClientEntrySize)) + return + } client := model.Client{ Name: clientParams.Name, Token: auth.GenerateNotExistingToken(generateClientToken, a.clientExists), UserID: auth.GetUserID(ctx), } - if success := successOrAbort(ctx, 500, a.DB.CreateClient(&client)); !success { + dbErr := a.DB.CreateClient(&client, a.Quota) + if dbErr != nil { + if errors.Is(dbErr, database.ErrQuotaExceeded) { + ctx.AbortWithError(http.StatusUnprocessableEntity, fmt.Errorf("quota exceeded")) + return + } + ctx.AbortWithError(500, dbErr) return } + ctx.JSON(200, client) } } diff --git a/api/client_test.go b/api/client_test.go index e2b3f28a..fa8c122e 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,8 @@ package api import ( + "fmt" + "net/http" "net/http/httptest" "net/url" "strings" @@ -44,7 +46,7 @@ func (s *ClientSuite) BeforeTest(suiteName, testName string) { s.ctx, _ = gin.CreateTestContext(s.recorder) withURL(s.ctx, "http", "example.com") s.notified = false - s.a = &ClientAPI{DB: s.db, NotifyDeleted: s.notify} + s.a = &ClientAPI{DB: s.db, NotifyDeleted: s.notify, Quota: 128} } func (s *ClientSuite) notify(uint, string) { @@ -145,6 +147,39 @@ func (s *ClientSuite) Test_CreateClient_withExistingToken() { test.BodyEquals(s.T(), expected, s.recorder) } +func (s *ClientSuite) Test_CreateClient_quotaRotationWorks() { + user := s.db.User(5) + for i := uint(0); i < uint(s.a.Quota); i++ { + user.ClientWithToken(100+i, fmt.Sprintf("client%d", i)) + } + + test.WithUser(s.ctx, 5) + s.withFormData("name=custom_name") + + baseCount, err := s.db.CountClientsByUserID(s.db.DB, 5) + assert.NoError(s.T(), err) + + s.a.CreateClient(s.ctx) + + userCount, err := s.db.CountClientsByUserID(s.db.DB, 5) + assert.NoError(s.T(), err) + + assert.Equal(s.T(), baseCount, userCount) + + assert.Equal(s.T(), 200, s.recorder.Code) +} + +func (s *ClientSuite) Test_CreateClient_rejectTooLargeClient() { + s.db.User(5) + + test.WithUser(s.ctx, 5) + s.withFormData(fmt.Sprintf("name=%s", strings.Repeat("a", MaxApplicationClientEntrySize+1))) + + s.a.CreateClient(s.ctx) + + assert.Equal(s.T(), http.StatusUnprocessableEntity, s.recorder.Code) +} + func (s *ClientSuite) Test_GetClients() { userBuilder := s.db.User(5) first := userBuilder.NewClientWithToken(1, "perfper") diff --git a/api/limits.go b/api/limits.go index cf6e00a4..0c3780b6 100644 --- a/api/limits.go +++ b/api/limits.go @@ -1,6 +1,8 @@ package api const ( + // Max size for a client or application entry. + MaxApplicationClientEntrySize = 64 << 10 // Maximum upload size of 32MB for blobs and files. MaxUploadSize = 32 << 20 // Catch-all request body limit of 64MB enforced at middleware level. diff --git a/api/oidc.go b/api/oidc.go index 002fcb16..50461053 100644 --- a/api/oidc.go +++ b/api/oidc.go @@ -337,7 +337,7 @@ func (a *OIDCAPI) createClient(name string, userID uint) (*model.Client, error) Token: auth.GenerateNotExistingToken(generateClientToken, func(t string) bool { c, _ := a.DB.GetClientByToken(t); return c != nil }), UserID: userID, } - return client, a.DB.CreateClient(client) + return client, a.DB.CreateClient(client, 0) } func (a *OIDCAPI) popPendingSession(key string) (*pendingOIDCSession, bool) { diff --git a/api/session.go b/api/session.go index 5e581484..444b078c 100644 --- a/api/session.go +++ b/api/session.go @@ -12,7 +12,7 @@ import ( // SessionDatabase is the interface for session-related database access. type SessionDatabase interface { GetUserByName(name string) (*model.User, error) - CreateClient(client *model.Client) error + CreateClient(client *model.Client, quota uint32) error GetClientByToken(token string) (*model.Client, error) DeleteClientByID(id uint) error } @@ -52,7 +52,7 @@ func (a *SessionAPI) Login(ctx *gin.Context) { Token: auth.GenerateNotExistingToken(generateClientToken, a.clientExists), UserID: user.ID, } - if success := successOrAbort(ctx, 500, a.DB.CreateClient(&client)); !success { + if success := successOrAbort(ctx, 500, a.DB.CreateClient(&client, 0)); !success { return } diff --git a/config.example.yml b/config.example.yml index b1bdfed0..5e7b8c5b 100644 --- a/config.example.yml +++ b/config.example.yml @@ -66,3 +66,6 @@ passstrength: 10 # the bcrypt password strength (higher = better but also slower uploadedimagesdir: data/images # the directory for storing uploaded images pluginsdir: data/plugins # the directory where plugin resides (leave empty to disable plugins) registration: false # enable registrations +quota: + clients: 1024 # the maximum number of clients a non-admin user can have + applications: 1024 # the maximum number of applications a non-admin user can have \ No newline at end of file diff --git a/config/config.go b/config/config.go index c299d8ee..8f2d9c87 100644 --- a/config/config.go +++ b/config/config.go @@ -66,6 +66,10 @@ type Configuration struct { AutoRegister bool `default:"true"` Scopes []string } + Quota struct { + Clients uint32 `default:"1024"` + Applications uint32 `default:"1024"` + } } func configFiles() []string { diff --git a/database/application.go b/database/application.go index eb4ad097..5bc8e13a 100644 --- a/database/application.go +++ b/database/application.go @@ -36,7 +36,7 @@ func (d *GormDatabase) GetApplicationByID(id uint) (*model.Application, error) { } // CreateApplication creates an application. -func (d *GormDatabase) CreateApplication(application *model.Application) error { +func (d *GormDatabase) CreateApplication(application *model.Application, quota uint32) error { return d.DB.Transaction(func(tx *gorm.DB) error { if application.SortKey == "" { sortKey := "" @@ -50,6 +50,17 @@ func (d *GormDatabase) CreateApplication(application *model.Application) error { } } + if quota > 0 { + var count int64 + err := tx.Model(&model.Application{}).Where("user_id = ?", application.UserID).Count(&count).Error + if err != nil { + return err + } + if uint64(count) >= uint64(quota) { + return ErrQuotaExceeded + } + } + return tx.Create(application).Error }, &sql.TxOptions{Isolation: sql.LevelSerializable}) } diff --git a/database/application_test.go b/database/application_test.go index 4c389be6..9fe94b59 100644 --- a/database/application_test.go +++ b/database/application_test.go @@ -25,7 +25,7 @@ func (s *DatabaseSuite) TestApplication() { } app := &model.Application{UserID: user.ID, Token: "C0000000000", Name: "backupserver"} - s.db.CreateApplication(app) + s.db.CreateApplication(app, 0) if apps, err := s.db.GetApplicationsByUser(user.ID); assert.NoError(s.T(), err) { assert.Len(s.T(), apps, 1) @@ -70,8 +70,8 @@ func (s *DatabaseSuite) TestApplication() { } func (s *DatabaseSuite) TestDeleteAppDeletesMessages() { - assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 55, Token: "token"})) - assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 66, Token: "token2"})) + assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 55, Token: "token"}, 0)) + assert.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 66, Token: "token2"}, 0)) assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 12, ApplicationID: 55})) assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 13, ApplicationID: 66})) assert.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 14, ApplicationID: 55})) diff --git a/database/client.go b/database/client.go index f85c4948..eef56398 100644 --- a/database/client.go +++ b/database/client.go @@ -33,9 +33,40 @@ func (d *GormDatabase) GetClientByToken(token string) (*model.Client, error) { return nil, err } +func (d *GormDatabase) CountClientsByUserID(tx *gorm.DB, userID uint) (int64, error) { + var count int64 + err := tx.Model(&model.Client{}).Where("user_id = ?", userID).Count(&count).Error + return count, err +} + // CreateClient creates a client. -func (d *GormDatabase) CreateClient(client *model.Client) error { - return d.DB.Create(client).Error +func (d *GormDatabase) CreateClient(client *model.Client, quota uint32) error { + txn := d.DB.Begin() + defer txn.Rollback() + res := txn.Create(client) + if res.Error != nil { + return res.Error + } + if quota > 0 { + count, err := d.CountClientsByUserID(txn, client.UserID) + if err != nil { + return err + } + if uint64(count) > uint64(quota) { + // quota exceeded, delete the oldest client + var oldestClient model.Client + err := txn.Where("user_id = ?", client.UserID).Order("last_used ASC").First(&oldestClient).Error + qe := ErrQuotaExceeded + if err != nil { + return qe + } + err = txn.Delete(&oldestClient).Error + if err != nil { + return qe + } + } + } + return txn.Commit().Error } // GetClientsByUser returns all clients from a user. diff --git a/database/client_test.go b/database/client_test.go index b1dda969..2205b2bf 100644 --- a/database/client_test.go +++ b/database/client_test.go @@ -24,7 +24,7 @@ func (s *DatabaseSuite) TestClient() { } client := &model.Client{UserID: user.ID, Token: "C0000000000", Name: "android"} - assert.NoError(s.T(), s.db.CreateClient(client)) + assert.NoError(s.T(), s.db.CreateClient(client, 0)) if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) { assert.Len(s.T(), clients, 1) diff --git a/database/database_test.go b/database/database_test.go index cc962e11..60e9dbb8 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -79,13 +79,13 @@ func TestMigrateSortKey(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, db) - err = db.CreateApplication(&model.Application{Name: "one", Token: "one", UserID: 1}) + err = db.CreateApplication(&model.Application{Name: "one", Token: "one", UserID: 1}, 0) assert.NoError(t, err) - err = db.CreateApplication(&model.Application{Name: "two", Token: "two", UserID: 1}) + err = db.CreateApplication(&model.Application{Name: "two", Token: "two", UserID: 1}, 0) assert.NoError(t, err) - err = db.CreateApplication(&model.Application{Name: "three", Token: "three", UserID: 1}) + err = db.CreateApplication(&model.Application{Name: "three", Token: "three", UserID: 1}, 0) assert.NoError(t, err) - err = db.CreateApplication(&model.Application{Name: "one-other", Token: "one-other", UserID: 2}) + err = db.CreateApplication(&model.Application{Name: "one-other", Token: "one-other", UserID: 2}, 0) assert.NoError(t, err) err = db.DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(new(model.Application)).UpdateColumn("sort_key", nil).Error diff --git a/database/errors.go b/database/errors.go new file mode 100644 index 00000000..376e78cb --- /dev/null +++ b/database/errors.go @@ -0,0 +1,5 @@ +package database + +import "errors" + +var ErrQuotaExceeded = errors.New("quota exceeded") diff --git a/database/message_test.go b/database/message_test.go index 950567f7..14a46716 100644 --- a/database/message_test.go +++ b/database/message_test.go @@ -19,7 +19,7 @@ func (s *DatabaseSuite) TestMessage() { assert.NotEqual(s.T(), 0, user.ID) backupServer := &model.Application{UserID: user.ID, Token: "A0000000000", Name: "backupserver"} - s.db.CreateApplication(backupServer) + s.db.CreateApplication(backupServer, 0) assert.NotEqual(s.T(), 0, backupServer.ID) msgs, err := s.db.GetMessagesByUser(user.ID) @@ -49,7 +49,7 @@ func (s *DatabaseSuite) TestMessage() { assertEquals(s.T(), msgs[0], backupdone) loginServer := &model.Application{UserID: user.ID, Token: "A0000000001", Name: "loginserver"} - require.NoError(s.T(), s.db.CreateApplication(loginServer)) + require.NoError(s.T(), s.db.CreateApplication(loginServer, 0)) assert.NotEqual(s.T(), 0, loginServer.ID) logindone := &model.Message{ApplicationID: loginServer.ID, Message: "login done", Title: "login", Priority: 1, Date: time.Now()} @@ -153,8 +153,8 @@ func (s *DatabaseSuite) TestGetMessagesSince() { app := &model.Application{UserID: user.ID, Token: "A0000000000"} app2 := &model.Application{UserID: user.ID, Token: "A0000000001"} - require.NoError(s.T(), s.db.CreateApplication(app)) - require.NoError(s.T(), s.db.CreateApplication(app2)) + require.NoError(s.T(), s.db.CreateApplication(app, 0)) + require.NoError(s.T(), s.db.CreateApplication(app2, 0)) curDate := time.Now() for i := 1; i <= 500; i++ { diff --git a/database/migration_test.go b/database/migration_test.go index dc3d10e4..998819f2 100644 --- a/database/migration_test.go +++ b/database/migration_test.go @@ -66,7 +66,7 @@ func (s *MigrationSuite) TestMigration() { UserID: user.ID, Description: "this is a test application", Name: "test application", - })) + }, 0)) } if app, err := db.GetApplicationByToken("A1234"); assert.NoError(s.T(), err) { assert.Equal(s.T(), "test application", app.Name) diff --git a/database/user_test.go b/database/user_test.go index 78a5f737..f328b360 100644 --- a/database/user_test.go +++ b/database/user_test.go @@ -103,15 +103,15 @@ func (s *DatabaseSuite) TestUserPlugins() { func (s *DatabaseSuite) TestDeleteUserDeletesApplicationsAndClientsAndPluginConfs() { require.NoError(s.T(), s.db.CreateUser(&model.User{Name: "nicories", ID: 10})) - require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 100, Token: "apptoken", UserID: 10})) + require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 100, Token: "apptoken", UserID: 10}, 0)) require.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 1000, ApplicationID: 100})) - require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 10000, Token: "clienttoken", UserID: 10})) + require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 10000, Token: "clienttoken", UserID: 10}, 0)) require.NoError(s.T(), s.db.CreatePluginConf(&model.PluginConf{ID: 1000, Token: "plugintoken", UserID: 10})) require.NoError(s.T(), s.db.CreateUser(&model.User{Name: "nicories2", ID: 20})) - require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 200, Token: "apptoken2", UserID: 20})) + require.NoError(s.T(), s.db.CreateApplication(&model.Application{ID: 200, Token: "apptoken2", UserID: 20}, 0)) require.NoError(s.T(), s.db.CreateMessage(&model.Message{ID: 2000, ApplicationID: 200})) - require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 20000, Token: "clienttoken2", UserID: 20})) + require.NoError(s.T(), s.db.CreateClient(&model.Client{ID: 20000, Token: "clienttoken2", UserID: 20}, 0)) require.NoError(s.T(), s.db.CreatePluginConf(&model.PluginConf{ID: 2000, Token: "plugintoken2", UserID: 20})) require.NoError(s.T(), s.db.DeleteUserByID(10)) diff --git a/docs/spec.json b/docs/spec.json index 090204f2..0d649e16 100644 --- a/docs/spec.json +++ b/docs/spec.json @@ -133,6 +133,12 @@ "schema": { "$ref": "#/definitions/Error" } + }, + "422": { + "description": "Unprocessable Entity", + "schema": { + "$ref": "#/definitions/Error" + } } } } @@ -213,6 +219,12 @@ "schema": { "$ref": "#/definitions/Error" } + }, + "422": { + "description": "Unprocessable Entity", + "schema": { + "$ref": "#/definitions/Error" + } } } }, @@ -846,6 +858,12 @@ "schema": { "$ref": "#/definitions/Error" } + }, + "422": { + "description": "Unprocessable Entity", + "schema": { + "$ref": "#/definitions/Error" + } } } } diff --git a/plugin/manager.go b/plugin/manager.go index 2bb4041e..c827bc4f 100644 --- a/plugin/manager.go +++ b/plugin/manager.go @@ -32,7 +32,7 @@ type Database interface { GetPluginConfByID(id uint) (*model.PluginConf, error) GetPluginConfByToken(token string) (*model.PluginConf, error) GetUserByID(id uint) (*model.User, error) - CreateApplication(application *model.Application) error + CreateApplication(application *model.Application, quota uint32) error UpdateApplication(app *model.Application) error GetApplicationsByUser(userID uint) ([]*model.Application, error) GetApplicationByToken(token string) (*model.Application, error) @@ -413,7 +413,7 @@ func (m *Manager) createPluginConf(instance compat.PluginInstance, info compat.I Internal: true, Description: fmt.Sprintf("auto generated application for %s", info.ModulePath), } - if err := m.db.CreateApplication(app); err != nil { + if err := m.db.CreateApplication(app, 0); err != nil { return nil, err } pluginConf.ApplicationID = app.ID diff --git a/plugin/manager_test.go b/plugin/manager_test.go index 8d698966..555f1d97 100644 --- a/plugin/manager_test.go +++ b/plugin/manager_test.go @@ -376,7 +376,7 @@ func TestNewManager_InternalApplicationManagement(t *testing.T) { Internal: true, Name: "obsolete plugin application", UserID: 1, - }) + }, 0) if app, err := db.GetApplicationByToken("Ainternal_obsolete"); assert.NoError(t, err) { assert.True(t, app.Internal) @@ -394,7 +394,7 @@ func TestNewManager_InternalApplicationManagement(t *testing.T) { Internal: true, Name: "not loaded plugin application", UserID: 1, - })) + }, 0)) if app, err := db.GetApplicationByToken("Ainternal_not_loaded"); assert.NoError(t, err) { assert.NoError(t, db.CreatePluginConf(&model.PluginConf{ ApplicationID: app.ID, @@ -420,7 +420,7 @@ func TestNewManager_InternalApplicationManagement(t *testing.T) { Internal: false, Name: "not loaded plugin application", UserID: 1, - })) + }, 0)) if app, err := db.GetApplicationByToken("Ainternal_loaded"); assert.NoError(t, err) { assert.NoError(t, db.CreatePluginConf(&model.PluginConf{ ApplicationID: app.ID, diff --git a/router/router.go b/router/router.go index ad142883..040cb4d4 100644 --- a/router/router.go +++ b/router/router.go @@ -86,10 +86,12 @@ func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Co DB: db, ImageDir: conf.UploadedImagesDir, NotifyDeleted: streamHandler.NotifyDeletedClient, + Quota: conf.Quota.Clients, } applicationHandler := api.ApplicationAPI{ DB: db, ImageDir: conf.UploadedImagesDir, + Quota: conf.Quota.Applications, } sessionHandler := api.SessionAPI{DB: db, NotifyDeleted: streamHandler.NotifyDeletedClient, SecureCookie: conf.Server.SecureCookie} userChangeNotifier := new(api.UserChangeNotifier) diff --git a/test/testdb/database.go b/test/testdb/database.go index 29f8cfdd..104bb588 100644 --- a/test/testdb/database.go +++ b/test/testdb/database.go @@ -103,7 +103,7 @@ func (ab *AppClientBuilder) NewInternalAppWithToken(id uint, token string) *mode func (ab *AppClientBuilder) newAppWithToken(id uint, token string, internal bool) *model.Application { application := &model.Application{ID: id, UserID: ab.userID, Token: token, Internal: internal} - ab.db.CreateApplication(application) + ab.db.CreateApplication(application, 0) return application } @@ -134,14 +134,14 @@ func (ab *AppClientBuilder) NewInternalAppWithTokenAndName(id uint, token, name func (ab *AppClientBuilder) newAppWithTokenAndName(id uint, token, name string, internal bool) *model.Application { application := &model.Application{ID: id, UserID: ab.userID, Token: token, Name: name, Internal: internal} - ab.db.CreateApplication(application) + ab.db.CreateApplication(application, 0) return application } // AppWithTokenAndDefaultPriority creates an application with a token and defaultPriority and returns a message builder. func (ab *AppClientBuilder) AppWithTokenAndDefaultPriority(id uint, token string, defaultPriority int) *MessageBuilder { application := &model.Application{ID: id, UserID: ab.userID, Token: token, DefaultPriority: defaultPriority} - ab.db.CreateApplication(application) + ab.db.CreateApplication(application, 0) return &MessageBuilder{db: ab.db, appID: id} } @@ -159,7 +159,7 @@ func (ab *AppClientBuilder) ClientWithToken(id uint, token string) *AppClientBui // NewClientWithToken creates a client with a token and returns the client. func (ab *AppClientBuilder) NewClientWithToken(id uint, token string) *model.Client { client := &model.Client{ID: id, Token: token, UserID: ab.userID} - ab.db.CreateClient(client) + ab.db.CreateClient(client, 0) return client }