diff --git a/README.md b/README.md index f75d1214..fcc4c1b6 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,10 @@ Architecture diagrams are available in: - `docs/architecture_overview.md` +Deployment/runbook documents: + +- `docs/docker_multinode_manual_run.md` (manual `docker run`, 4-5 node cluster on multiple VMs, no docker compose) + ## Example Usage diff --git a/adapter/add_voter_join_test.go b/adapter/add_voter_join_test.go new file mode 100644 index 00000000..6ca2f5f2 --- /dev/null +++ b/adapter/add_voter_join_test.go @@ -0,0 +1,330 @@ +package adapter + +import ( + "bytes" + "context" + "errors" + "net" + "strconv" + "sync" + "testing" + "time" + + "github.com/Jille/raft-grpc-leader-rpc/leaderhealth" + "github.com/Jille/raftadmin" + raftadminpb "github.com/Jille/raftadmin/proto" + "github.com/bootjp/elastickv/kv" + pb "github.com/bootjp/elastickv/proto" + "github.com/bootjp/elastickv/store" + "github.com/hashicorp/raft" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +func TestAddVoterJoinPath_RegistersMemberAndServesAdapterTraffic(t *testing.T) { + t.Parallel() + + const ( + waitTimeout = 12 * time.Second + waitInterval = 100 * time.Millisecond + rpcTimeout = 2 * time.Second + ) + + ctx := context.Background() + nodes, servers := setupAddVoterJoinPathNodes(t, ctx) + t.Cleanup(func() { + shutdown(nodes) + servers.AwaitNoError(t, waitTimeout) + }) + + waitForNodeListeners(t, ctx, nodes, waitTimeout, waitInterval) + require.Eventually(t, func() bool { + return nodes[0].raft.State() == raft.Leader + }, waitTimeout, waitInterval) + + adminConn, err := grpc.NewClient(nodes[0].grpcAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { _ = adminConn.Close() }) + admin := raftadminpb.NewRaftAdminClient(adminConn) + + addVotersAndAwait(t, ctx, rpcTimeout, admin, nodes, []int{1, 2}) + + expectedCfg := expectedVoterConfig(nodes) + waitForConfigReplication(t, expectedCfg, nodes, waitTimeout, waitInterval) + waitForRaftReadiness(t, nodes, waitTimeout, waitInterval) + + followerConn, err := grpc.NewClient(nodes[1].grpcAddress, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = followerConn.Close() }) + followerRaw := pb.NewRawKVClient(followerConn) + + leaderConn, err := grpc.NewClient(nodes[0].grpcAddress, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = leaderConn.Close() }) + leaderRaw := pb.NewRawKVClient(leaderConn) + + putAndWaitForRead(t, ctx, rpcTimeout, followerRaw, leaderRaw, []byte("addvoter-key"), []byte("ok"), waitTimeout, waitInterval) + + // Simulate a partition-like failure by isolating node2's raft transport. + require.NoError(t, nodes[2].tm.Close()) + nodes[2].tm = nil + + putAndWaitForRead(t, ctx, rpcTimeout, followerRaw, leaderRaw, []byte("partition-survive-key"), []byte("ok2"), waitTimeout, waitInterval) + + // Force leader change while one node is isolated, then confirm write/read path. 
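+	// Nodes 0 and 1 still form a quorum (2 of 3 voters) with node 2's transport closed,
+	// so the transfer and the follow-up write/read checks are expected to succeed.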
+ require.NoError(t, nodes[0].raft.LeadershipTransferToServer(raft.ServerID("1"), raft.ServerAddress(nodes[1].raftAddress)).Error()) + require.Eventually(t, func() bool { + return nodes[1].raft.State() == raft.Leader + }, waitTimeout, waitInterval) + + putAndWaitForRead(t, ctx, rpcTimeout, leaderRaw, followerRaw, []byte("leader-transfer-key"), []byte("ok3"), waitTimeout, waitInterval) +} + +func setupAddVoterJoinPathNodes(t *testing.T, ctx context.Context) ([]Node, *serverWorkers) { + t.Helper() + + ports, lis := reserveAddVoterJoinListeners(t, ctx, 3) + + // AddVoter address must point to the node's shared gRPC endpoint where + // raft transport and adapter services are served. + require.Equal(t, ports[1].raftAddress, ports[1].grpcAddress) + require.Equal(t, ports[2].raftAddress, ports[2].grpcAddress) + + leaderRedisMap := map[raft.ServerAddress]string{ + raft.ServerAddress(ports[0].raftAddress): ports[0].redisAddress, + raft.ServerAddress(ports[1].raftAddress): ports[1].redisAddress, + raft.ServerAddress(ports[2].raftAddress): ports[2].redisAddress, + } + bootstrapCfg := raft.Configuration{ + Servers: []raft.Server{ + { + Suffrage: raft.Voter, + ID: raft.ServerID("0"), + Address: raft.ServerAddress(ports[0].raftAddress), + }, + }, + } + + workers := newServerWorkers(len(ports) * 3) + nodes := make([]Node, 0, len(ports)) + for i := range ports { + nodes = append(nodes, startAddVoterJoinNode(t, workers, i, ports[i], lis[i], bootstrapCfg, leaderRedisMap)) + } + return nodes, workers +} + +func reserveAddVoterJoinListeners(t *testing.T, ctx context.Context, n int) ([]portsAdress, []listeners) { + t.Helper() + + var lc net.ListenConfig + ports := assignPorts(n) + lis := make([]listeners, 0, n) + for i := range ports { + for { + bound, ls, retry, err := bindListeners(ctx, &lc, ports[i]) + require.NoError(t, err) + if !retry { + ports[i] = bound + lis = append(lis, ls) + break + } + ports[i] = assignPorts(1)[0] + } + } + return ports, lis +} + +func startAddVoterJoinNode( + t *testing.T, + workers *serverWorkers, + idx int, + port portsAdress, + lis listeners, + bootstrapCfg raft.Configuration, + leaderRedisMap map[raft.ServerAddress]string, +) Node { + t.Helper() + + st := store.NewMVCCStore() + fsm := kv.NewKvFSM(st) + + electionTimeout := leaderElectionTimeout + if idx != 0 { + electionTimeout = followerElectionTimeout + } + + r, tm, err := newRaft(strconv.Itoa(idx), port.raftAddress, fsm, idx == 0, bootstrapCfg, electionTimeout) + require.NoError(t, err) + + s := grpc.NewServer() + trx := kv.NewTransaction(r) + coordinator := kv.NewCoordinator(trx, r) + routedStore := kv.NewLeaderRoutedStore(st, coordinator) + gs := NewGRPCServer(routedStore, coordinator, WithCloseStore()) + tm.Register(s) + pb.RegisterRawKVServer(s, gs) + pb.RegisterTransactionalKVServer(s, gs) + pb.RegisterInternalServer(s, NewInternal(trx, r, coordinator.Clock())) + leaderhealth.Setup(r, s, []string{"RawKV"}) + raftadmin.Register(s, r) + + workers.Go(func() error { + err := s.Serve(lis.grpc) + if errors.Is(err, grpc.ErrServerStopped) || errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + + rd := NewRedisServer(lis.redis, st, coordinator, leaderRedisMap) + workers.Go(func() error { + err := rd.Run() + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + + ds := NewDynamoDBServer(lis.dynamo, st, coordinator) + workers.Go(func() error { + err := ds.Run() + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + + return newNode( + port.grpcAddress, + port.raftAddress, + 
port.redisAddress, + port.dynamoAddress, + r, + tm, + s, + gs, + rd, + ds, + ) +} + +func addVotersAndAwait( + t *testing.T, + ctx context.Context, + rpcTimeout time.Duration, + admin raftadminpb.RaftAdminClient, + nodes []Node, + targets []int, +) { + t.Helper() + + for _, target := range targets { + addCtx, cancelAdd := context.WithTimeout(ctx, rpcTimeout) + future, err := admin.AddVoter(addCtx, &raftadminpb.AddVoterRequest{ + Id: strconv.Itoa(target), + Address: nodes[target].grpcAddress, + PreviousIndex: 0, + }) + cancelAdd() + require.NoError(t, err) + + awaitCtx, cancelAwait := context.WithTimeout(ctx, rpcTimeout) + await, err := admin.Await(awaitCtx, future) + cancelAwait() + require.NoError(t, err) + require.Empty(t, await.GetError()) + require.Greater(t, await.GetIndex(), uint64(0)) + } +} + +func expectedVoterConfig(nodes []Node) raft.Configuration { + servers := make([]raft.Server, 0, len(nodes)) + for i, n := range nodes { + servers = append(servers, raft.Server{ + Suffrage: raft.Voter, + ID: raft.ServerID(strconv.Itoa(i)), + Address: raft.ServerAddress(n.raftAddress), + }) + } + return raft.Configuration{Servers: servers} +} + +func putAndWaitForRead( + t *testing.T, + ctx context.Context, + rpcTimeout time.Duration, + writer pb.RawKVClient, + reader pb.RawKVClient, + key []byte, + value []byte, + waitTimeout time.Duration, + waitInterval time.Duration, +) { + t.Helper() + + putCtx, cancelPut := context.WithTimeout(ctx, rpcTimeout) + _, err := writer.RawPut(putCtx, &pb.RawPutRequest{Key: key, Value: value}) + cancelPut() + require.NoError(t, err) + + require.Eventually(t, func() bool { + getCtx, cancelGet := context.WithTimeout(ctx, rpcTimeout) + resp, getErr := reader.RawGet(getCtx, &pb.RawGetRequest{Key: key}) + cancelGet() + if getErr != nil { + return false + } + return resp.Exists && bytes.Equal(resp.Value, value) + }, waitTimeout, waitInterval) +} + +type serverWorkers struct { + wg sync.WaitGroup + errCh chan error +} + +func newServerWorkers(buffer int) *serverWorkers { + return &serverWorkers{errCh: make(chan error, buffer)} +} + +func (w *serverWorkers) Go(run func() error) { + if w == nil || run == nil { + return + } + w.wg.Add(1) + go func() { + defer w.wg.Done() + if err := run(); err != nil { + w.errCh <- err + } + }() +} + +func (w *serverWorkers) AwaitNoError(t *testing.T, timeout time.Duration) { + t.Helper() + if w == nil { + return + } + + done := make(chan struct{}) + go func() { + w.wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + require.FailNow(t, "server goroutines did not finish in time") + } + + close(w.errCh) + for err := range w.errCh { + require.NoError(t, err) + } +} diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index b359867a..a892dcad 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -1067,7 +1067,11 @@ func resolveQueryCondition(in queryInput, schema *dynamoTableSchema) (dynamoKeyS if err != nil { return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) } - parsed, err := parseKeyConditionExpression(replaceNames(in.KeyConditionExpression, in.ExpressionAttributeNames)) + keyExpr, err := replaceNames(in.KeyConditionExpression, in.ExpressionAttributeNames) + if err != nil { + return dynamoKeySchema{}, queryCondition{}, newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) + } + parsed, err := parseKeyConditionExpression(keyExpr) if err != nil { return dynamoKeySchema{}, queryCondition{}, 
newDynamoAPIError(http.StatusBadRequest, dynamoErrValidation, err.Error()) } @@ -1570,7 +1574,10 @@ func (d *DynamoDBServer) runTransactWriteAttempt( func (d *DynamoDBServer) buildTransactWriteItemsRequest(ctx context.Context, in transactWriteItemsInput) (*kv.OperationGroup[kv.OP], map[string]uint64, [][]byte, error) { readTS := d.nextTxnReadTS() reqs := &kv.OperationGroup[kv.OP]{ - IsTxn: true, + IsTxn: true, + // Keep transaction start aligned with the snapshot used to evaluate + // ConditionCheck/ConditionExpression so concurrent writes after readTS + // are detected as write conflicts at commit time. StartTS: readTS, } schemaCache := make(map[string]*dynamoTableSchema) @@ -1993,9 +2000,12 @@ func writeDynamoJSON(w http.ResponseWriter, payload any) { _ = json.NewEncoder(w).Encode(payload) } -func replaceNames(expr string, names map[string]string) string { +func replaceNames(expr string, names map[string]string) (string, error) { if expr == "" || len(names) == 0 { - return expr + return expr, nil + } + if err := validateExpressionAttributeNames(names); err != nil { + return "", err } keys := make([]string, 0, len(names)) for k := range names { @@ -2013,11 +2023,68 @@ func replaceNames(expr string, names map[string]string) string { for _, key := range keys { args = append(args, key, names[key]) } - return strings.NewReplacer(args...).Replace(expr) + return strings.NewReplacer(args...).Replace(expr), nil +} + +func validateExpressionAttributeNames(names map[string]string) error { + for placeholder, name := range names { + if !isExpressionAttributePlaceholder(placeholder) { + return errors.Errorf("invalid expression attribute placeholder %q", placeholder) + } + if !isExpressionAttributeName(name) { + return errors.Errorf("invalid expression attribute name %q for placeholder %q", name, placeholder) + } + } + return nil +} + +func isExpressionAttributePlaceholder(s string) bool { + if len(s) <= 1 || s[0] != '#' { + return false + } + return isExpressionPlaceholderIdentifier(s[1:]) +} + +func isExpressionPlaceholderIdentifier(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if isExpressionPlaceholderIdentByte(s[i]) { + continue + } + return false + } + return true +} + +func isExpressionPlaceholderIdentByte(b byte) bool { + return b == '_' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +} + +func isExpressionAttributeName(s string) bool { + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + if isExpressionAttributeNameByte(s[i]) { + continue + } + return false + } + return true +} + +func isExpressionAttributeNameByte(b byte) bool { + return b == '_' || b == '.' 
|| b == '-' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } func applyUpdateExpression(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { - updExpr := strings.TrimSpace(replaceNames(expr, names)) + updExpr, err := replaceNames(expr, names) + if err != nil { + return err + } + updExpr = strings.TrimSpace(updExpr) sections, err := parseUpdateExpressionSections(updExpr) if err != nil { return err @@ -2293,7 +2360,11 @@ func splitTopLevelByComma(expr string) ([]string, error) { } func validateConditionOnItem(expr string, names map[string]string, values map[string]attributeValue, item map[string]attributeValue) error { - cond := strings.TrimSpace(replaceNames(expr, names)) + cond, err := replaceNames(expr, names) + if err != nil { + return err + } + cond = strings.TrimSpace(cond) if cond == "" { return nil } @@ -2484,9 +2555,9 @@ func isLogicalKeywordBoundary(s string, pos int) bool { return true } ch := s[pos] - // Keep ASCII letters/digits/underscore as identifier token characters so - // expressions like "MY_AND_VAR" are not split at the "AND" substring. - if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' { + // Keep identifier-style characters as token characters so expressions like + // "MY_AND_VAR" or "a-OR-b" are not split at logical keyword substrings. + if isExpressionAttributeNameByte(ch) { return false } return true diff --git a/adapter/dynamodb_table_compat_test.go b/adapter/dynamodb_table_compat_test.go index 479dbc13..d72cd9c1 100644 --- a/adapter/dynamodb_table_compat_test.go +++ b/adapter/dynamodb_table_compat_test.go @@ -32,292 +32,272 @@ func TestDynamoDB_TableAPICompatibility(t *testing.T) { threadsTable := "threads" messagesTable := "messages" - _, err = client.CreateTable(ctx, &dynamodb.CreateTableInput{ - TableName: aws.String(threadsTable), - BillingMode: ddbTypes.BillingModePayPerRequest, - AttributeDefinitions: []ddbTypes.AttributeDefinition{ - {AttributeName: aws.String("threadId"), AttributeType: ddbTypes.ScalarAttributeTypeS}, - {AttributeName: aws.String("status"), AttributeType: ddbTypes.ScalarAttributeTypeS}, - {AttributeName: aws.String("createdAt"), AttributeType: ddbTypes.ScalarAttributeTypeS}, - }, - KeySchema: []ddbTypes.KeySchemaElement{ - {AttributeName: aws.String("threadId"), KeyType: ddbTypes.KeyTypeHash}, - }, - GlobalSecondaryIndexes: []ddbTypes.GlobalSecondaryIndex{ - { - IndexName: aws.String("statusIndex"), - KeySchema: []ddbTypes.KeySchemaElement{ - {AttributeName: aws.String("status"), KeyType: ddbTypes.KeyTypeHash}, - {AttributeName: aws.String("createdAt"), KeyType: ddbTypes.KeyTypeRange}, + createThreadsTable := func(tb testing.TB) { + tb.Helper() + _, createErr := client.CreateTable(ctx, &dynamodb.CreateTableInput{ + TableName: aws.String(threadsTable), + BillingMode: ddbTypes.BillingModePayPerRequest, + AttributeDefinitions: []ddbTypes.AttributeDefinition{ + {AttributeName: aws.String("threadId"), AttributeType: ddbTypes.ScalarAttributeTypeS}, + {AttributeName: aws.String("status"), AttributeType: ddbTypes.ScalarAttributeTypeS}, + {AttributeName: aws.String("createdAt"), AttributeType: ddbTypes.ScalarAttributeTypeS}, + }, + KeySchema: []ddbTypes.KeySchemaElement{ + {AttributeName: aws.String("threadId"), KeyType: ddbTypes.KeyTypeHash}, + }, + GlobalSecondaryIndexes: []ddbTypes.GlobalSecondaryIndex{ + { + IndexName: aws.String("statusIndex"), + KeySchema: []ddbTypes.KeySchemaElement{ + {AttributeName: 
aws.String("status"), KeyType: ddbTypes.KeyTypeHash}, + {AttributeName: aws.String("createdAt"), KeyType: ddbTypes.KeyTypeRange}, + }, + Projection: &ddbTypes.Projection{ProjectionType: ddbTypes.ProjectionTypeAll}, }, - Projection: &ddbTypes.Projection{ProjectionType: ddbTypes.ProjectionTypeAll}, }, - }, - }) - require.NoError(t, err) + }) + require.NoError(tb, createErr) + } - _, err = client.CreateTable(ctx, &dynamodb.CreateTableInput{ - TableName: aws.String(messagesTable), - BillingMode: ddbTypes.BillingModePayPerRequest, - AttributeDefinitions: []ddbTypes.AttributeDefinition{ - {AttributeName: aws.String("threadId"), AttributeType: ddbTypes.ScalarAttributeTypeS}, - {AttributeName: aws.String("createdAt"), AttributeType: ddbTypes.ScalarAttributeTypeS}, - }, - KeySchema: []ddbTypes.KeySchemaElement{ - {AttributeName: aws.String("threadId"), KeyType: ddbTypes.KeyTypeHash}, - {AttributeName: aws.String("createdAt"), KeyType: ddbTypes.KeyTypeRange}, - }, - }) - require.NoError(t, err) + createMessagesTable := func(tb testing.TB) { + tb.Helper() + _, createErr := client.CreateTable(ctx, &dynamodb.CreateTableInput{ + TableName: aws.String(messagesTable), + BillingMode: ddbTypes.BillingModePayPerRequest, + AttributeDefinitions: []ddbTypes.AttributeDefinition{ + {AttributeName: aws.String("threadId"), AttributeType: ddbTypes.ScalarAttributeTypeS}, + {AttributeName: aws.String("createdAt"), AttributeType: ddbTypes.ScalarAttributeTypeS}, + }, + KeySchema: []ddbTypes.KeySchemaElement{ + {AttributeName: aws.String("threadId"), KeyType: ddbTypes.KeyTypeHash}, + {AttributeName: aws.String("createdAt"), KeyType: ddbTypes.KeyTypeRange}, + }, + }) + require.NoError(tb, createErr) + } - listAllOut, err := client.ListTables(ctx, &dynamodb.ListTablesInput{}) - require.NoError(t, err) - require.ElementsMatch(t, []string{threadsTable, messagesTable}, listAllOut.TableNames) - listPageOut, err := client.ListTables(ctx, &dynamodb.ListTablesInput{Limit: aws.Int32(1)}) - require.NoError(t, err) - require.Len(t, listPageOut.TableNames, 1) - require.NotEmpty(t, aws.ToString(listPageOut.LastEvaluatedTableName)) - desc, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{TableName: aws.String(threadsTable)}) - require.NoError(t, err) - require.NotNil(t, desc.Table) - require.Equal(t, threadsTable, aws.ToString(desc.Table.TableName)) + putThread := func(tb testing.TB, threadID string, title string, createdAt string, status string, accessToken string) { + tb.Helper() + _, putErr := client.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(threadsTable), + Item: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: threadID}, + "title": &ddbTypes.AttributeValueMemberS{Value: title}, + "createdAt": &ddbTypes.AttributeValueMemberS{Value: createdAt}, + "status": &ddbTypes.AttributeValueMemberS{Value: status}, + "accessToken": &ddbTypes.AttributeValueMemberS{Value: accessToken}, + }, + }) + require.NoError(tb, putErr) + } - _, err = client.PutItem(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(threadsTable), - Item: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - "title": &ddbTypes.AttributeValueMemberS{Value: "title1"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-01T00:00:00Z"}, - "status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - "accessToken": &ddbTypes.AttributeValueMemberS{Value: ""}, - }, - }) - require.NoError(t, err) - _, err = client.PutItem(ctx, &dynamodb.PutItemInput{ - 
TableName: aws.String(threadsTable), - Item: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t2"}, - "title": &ddbTypes.AttributeValueMemberS{Value: "title2"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-02T00:00:00Z"}, - "status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - "accessToken": &ddbTypes.AttributeValueMemberS{Value: ""}, - }, - }) - require.NoError(t, err) - _, err = client.PutItem(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(threadsTable), - Item: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t3"}, - "title": &ddbTypes.AttributeValueMemberS{Value: "title3"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-03T00:00:00Z"}, - "status": &ddbTypes.AttributeValueMemberS{Value: "answered"}, - "accessToken": &ddbTypes.AttributeValueMemberS{Value: ""}, - }, - }) - require.NoError(t, err) + queryThreadsByStatus := func(tb testing.TB, status string, scanIndexForward bool) *dynamodb.QueryOutput { + tb.Helper() + out, queryErr := client.Query(ctx, &dynamodb.QueryInput{ + TableName: aws.String(threadsTable), + IndexName: aws.String("statusIndex"), + KeyConditionExpression: aws.String("#status = :status"), + ExpressionAttributeNames: map[string]string{ + "#status": "status", + }, + ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ + ":status": &ddbTypes.AttributeValueMemberS{Value: status}, + }, + ScanIndexForward: aws.Bool(scanIndexForward), + }) + require.NoError(tb, queryErr) + return out + } - getOut, err := client.GetItem(ctx, &dynamodb.GetItemInput{ - TableName: aws.String(threadsTable), - Key: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - }, - }) - require.NoError(t, err) - require.NotNil(t, getOut.Item) - threadID, ok := getOut.Item["threadId"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - require.Equal(t, "t1", threadID.Value) - status, ok := getOut.Item["status"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - require.Equal(t, "pending", status.Value) - title, ok := getOut.Item["title"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - require.Equal(t, "title1", title.Value) + // Shared setup for all sub-tests in this end-to-end scenario. 
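+	// Sub-tests run sequentially against this shared fixture, so later sub-tests
+	// observe tables and items created by earlier ones; DeleteTable runs last.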
+ createThreadsTable(t) + createMessagesTable(t) - queryThreadsOut, err := client.Query(ctx, &dynamodb.QueryInput{ - TableName: aws.String(threadsTable), - IndexName: aws.String("statusIndex"), - KeyConditionExpression: aws.String("#status = :status"), - ExpressionAttributeNames: map[string]string{ - "#status": "status", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - }, - ScanIndexForward: aws.Bool(false), - }) - require.NoError(t, err) - require.Len(t, queryThreadsOut.Items, 2) - created0, ok := queryThreadsOut.Items[0]["createdAt"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - created1, ok := queryThreadsOut.Items[1]["createdAt"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - require.Equal(t, "2026-01-02T00:00:00Z", created0.Value) - require.Equal(t, "2026-01-01T00:00:00Z", created1.Value) + t.Run("TableLifecycle", func(t *testing.T) { + listAllOut, listErr := client.ListTables(ctx, &dynamodb.ListTablesInput{}) + require.NoError(t, listErr) + require.ElementsMatch(t, []string{threadsTable, messagesTable}, listAllOut.TableNames) - _, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ - TableName: aws.String(threadsTable), - Key: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - }, - UpdateExpression: aws.String("SET #status = :status"), - ExpressionAttributeNames: map[string]string{ - "#status": "status", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":status": &ddbTypes.AttributeValueMemberS{Value: "answered"}, - }, - }) - require.NoError(t, err) - queryPendingAfterUpdate, err := client.Query(ctx, &dynamodb.QueryInput{ - TableName: aws.String(threadsTable), - IndexName: aws.String("statusIndex"), - KeyConditionExpression: aws.String("#status = :status"), - ExpressionAttributeNames: map[string]string{ - "#status": "status", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - }, - ScanIndexForward: aws.Bool(false), - }) - require.NoError(t, err) - require.Len(t, queryPendingAfterUpdate.Items, 1) - queryAnsweredAfterUpdate, err := client.Query(ctx, &dynamodb.QueryInput{ - TableName: aws.String(threadsTable), - IndexName: aws.String("statusIndex"), - KeyConditionExpression: aws.String("#status = :status"), - ExpressionAttributeNames: map[string]string{ - "#status": "status", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":status": &ddbTypes.AttributeValueMemberS{Value: "answered"}, - }, - ScanIndexForward: aws.Bool(false), - }) - require.NoError(t, err) - require.Len(t, queryAnsweredAfterUpdate.Items, 2) + listPageOut, listPageErr := client.ListTables(ctx, &dynamodb.ListTablesInput{Limit: aws.Int32(1)}) + require.NoError(t, listPageErr) + require.Len(t, listPageOut.TableNames, 1) + require.NotEmpty(t, aws.ToString(listPageOut.LastEvaluatedTableName)) - _, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ - TableName: aws.String(threadsTable), - Key: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - }, - UpdateExpression: aws.String("SET #accessToken = :accessToken"), - ConditionExpression: aws.String("attribute_exists(#threadId) AND (attribute_not_exists(#accessToken) OR #accessToken = :empty)"), - ExpressionAttributeNames: map[string]string{ - "#threadId": "threadId", - "#accessToken": "accessToken", - }, - ExpressionAttributeValues: 
map[string]ddbTypes.AttributeValue{ - ":accessToken": &ddbTypes.AttributeValueMemberS{Value: "token1"}, - ":empty": &ddbTypes.AttributeValueMemberS{Value: ""}, - }, + desc, descErr := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{TableName: aws.String(threadsTable)}) + require.NoError(t, descErr) + require.NotNil(t, desc.Table) + require.Equal(t, threadsTable, aws.ToString(desc.Table.TableName)) }) - require.NoError(t, err) - _, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ - TableName: aws.String(threadsTable), - Key: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - }, - UpdateExpression: aws.String("SET #accessToken = :accessToken"), - ConditionExpression: aws.String("attribute_exists(#threadId) AND (attribute_not_exists(#accessToken) OR #accessToken = :empty)"), - ExpressionAttributeNames: map[string]string{ - "#threadId": "threadId", - "#accessToken": "accessToken", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":accessToken": &ddbTypes.AttributeValueMemberS{Value: "token2"}, - ":empty": &ddbTypes.AttributeValueMemberS{Value: ""}, - }, - }) - require.Error(t, err) - var condErr *ddbTypes.ConditionalCheckFailedException - require.ErrorAs(t, err, &condErr) - getOut, err = client.GetItem(ctx, &dynamodb.GetItemInput{ - TableName: aws.String(threadsTable), - Key: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "does-not-exist"}, - }, - }) - require.NoError(t, err) - require.Empty(t, getOut.Item) + t.Run("ThreadItemOperations", func(t *testing.T) { + putThread(t, "t1", "title1", "2026-01-01T00:00:00Z", "pending", "") + putThread(t, "t2", "title2", "2026-01-02T00:00:00Z", "pending", "") + putThread(t, "t3", "title3", "2026-01-03T00:00:00Z", "answered", "") - _, err = client.PutItem(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(messagesTable), - Item: map[string]ddbTypes.AttributeValue{ - "messageId": &ddbTypes.AttributeValueMemberS{Value: "m1"}, - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - "content": &ddbTypes.AttributeValueMemberS{Value: "hello"}, - "sender": &ddbTypes.AttributeValueMemberS{Value: "user"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-01T00:00:01Z"}, - }, - }) - require.NoError(t, err) - _, err = client.PutItem(ctx, &dynamodb.PutItemInput{ - TableName: aws.String(messagesTable), - Item: map[string]ddbTypes.AttributeValue{ - "messageId": &ddbTypes.AttributeValueMemberS{Value: "m2"}, - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - "content": &ddbTypes.AttributeValueMemberS{Value: "world"}, - "sender": &ddbTypes.AttributeValueMemberS{Value: "admin"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-01T00:00:02Z"}, - }, - }) - require.NoError(t, err) + getOut, getErr := client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(threadsTable), + Key: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + }, + }) + require.NoError(t, getErr) + require.NotNil(t, getOut.Item) + threadID, ok := getOut.Item["threadId"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "t1", threadID.Value) + threadStatus, ok := getOut.Item["status"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "pending", threadStatus.Value) + title, ok := getOut.Item["title"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "title1", title.Value) - queryMessagesOut, err := 
client.Query(ctx, &dynamodb.QueryInput{ - TableName: aws.String(messagesTable), - KeyConditionExpression: aws.String("threadId = :threadId"), - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, - }, - ScanIndexForward: aws.Bool(true), + queryThreadsOut := queryThreadsByStatus(t, "pending", false) + require.Len(t, queryThreadsOut.Items, 2) + created0, ok := queryThreadsOut.Items[0]["createdAt"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + created1, ok := queryThreadsOut.Items[1]["createdAt"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "2026-01-02T00:00:00Z", created0.Value) + require.Equal(t, "2026-01-01T00:00:00Z", created1.Value) + + _, updateErr := client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ + TableName: aws.String(threadsTable), + Key: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + }, + UpdateExpression: aws.String("SET #status = :status"), + ExpressionAttributeNames: map[string]string{ + "#status": "status", + }, + ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ + ":status": &ddbTypes.AttributeValueMemberS{Value: "answered"}, + }, + }) + require.NoError(t, updateErr) + + queryPendingAfterUpdate := queryThreadsByStatus(t, "pending", false) + require.Len(t, queryPendingAfterUpdate.Items, 1) + queryAnsweredAfterUpdate := queryThreadsByStatus(t, "answered", false) + require.Len(t, queryAnsweredAfterUpdate.Items, 2) + + _, updateErr = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ + TableName: aws.String(threadsTable), + Key: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + }, + UpdateExpression: aws.String("SET #accessToken = :accessToken"), + ConditionExpression: aws.String("attribute_exists(#threadId) AND (attribute_not_exists(#accessToken) OR #accessToken = :empty)"), + ExpressionAttributeNames: map[string]string{ + "#threadId": "threadId", + "#accessToken": "accessToken", + }, + ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ + ":accessToken": &ddbTypes.AttributeValueMemberS{Value: "token1"}, + ":empty": &ddbTypes.AttributeValueMemberS{Value: ""}, + }, + }) + require.NoError(t, updateErr) + + _, updateErr = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ + TableName: aws.String(threadsTable), + Key: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + }, + UpdateExpression: aws.String("SET #accessToken = :accessToken"), + ConditionExpression: aws.String("attribute_exists(#threadId) AND (attribute_not_exists(#accessToken) OR #accessToken = :empty)"), + ExpressionAttributeNames: map[string]string{ + "#threadId": "threadId", + "#accessToken": "accessToken", + }, + ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ + ":accessToken": &ddbTypes.AttributeValueMemberS{Value: "token2"}, + ":empty": &ddbTypes.AttributeValueMemberS{Value: ""}, + }, + }) + require.Error(t, updateErr) + var condErr *ddbTypes.ConditionalCheckFailedException + require.ErrorAs(t, updateErr, &condErr) + + missingOut, missingErr := client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(threadsTable), + Key: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "does-not-exist"}, + }, + }) + require.NoError(t, missingErr) + require.Empty(t, missingOut.Item) }) - require.NoError(t, err) - require.Len(t, queryMessagesOut.Items, 2) - mc0, ok := 
queryMessagesOut.Items[0]["createdAt"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - mc1, ok := queryMessagesOut.Items[1]["createdAt"].(*ddbTypes.AttributeValueMemberS) - require.True(t, ok) - require.Equal(t, "2026-01-01T00:00:01Z", mc0.Value) - require.Equal(t, "2026-01-01T00:00:02Z", mc1.Value) - _, err = client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{ - TransactItems: []ddbTypes.TransactWriteItem{ - { - Put: &ddbTypes.Put{ - TableName: aws.String(threadsTable), - Item: map[string]ddbTypes.AttributeValue{ - "threadId": &ddbTypes.AttributeValueMemberS{Value: "t4"}, - "title": &ddbTypes.AttributeValueMemberS{Value: "title4"}, - "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-04T00:00:00Z"}, - "status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - "accessToken": &ddbTypes.AttributeValueMemberS{Value: ""}, + t.Run("MessagesQuery", func(t *testing.T) { + _, putErr := client.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(messagesTable), + Item: map[string]ddbTypes.AttributeValue{ + "messageId": &ddbTypes.AttributeValueMemberS{Value: "m1"}, + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + "content": &ddbTypes.AttributeValueMemberS{Value: "hello"}, + "sender": &ddbTypes.AttributeValueMemberS{Value: "user"}, + "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-01T00:00:01Z"}, + }, + }) + require.NoError(t, putErr) + _, putErr = client.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(messagesTable), + Item: map[string]ddbTypes.AttributeValue{ + "messageId": &ddbTypes.AttributeValueMemberS{Value: "m2"}, + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + "content": &ddbTypes.AttributeValueMemberS{Value: "world"}, + "sender": &ddbTypes.AttributeValueMemberS{Value: "admin"}, + "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-01T00:00:02Z"}, + }, + }) + require.NoError(t, putErr) + + queryMessagesOut, queryErr := client.Query(ctx, &dynamodb.QueryInput{ + TableName: aws.String(messagesTable), + KeyConditionExpression: aws.String("threadId = :threadId"), + ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ + ":threadId": &ddbTypes.AttributeValueMemberS{Value: "t1"}, + }, + ScanIndexForward: aws.Bool(true), + }) + require.NoError(t, queryErr) + require.Len(t, queryMessagesOut.Items, 2) + mc0, ok := queryMessagesOut.Items[0]["createdAt"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + mc1, ok := queryMessagesOut.Items[1]["createdAt"].(*ddbTypes.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "2026-01-01T00:00:01Z", mc0.Value) + require.Equal(t, "2026-01-01T00:00:02Z", mc1.Value) + }) + + t.Run("TransactWriteItems", func(t *testing.T) { + _, txErr := client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{ + TransactItems: []ddbTypes.TransactWriteItem{ + { + Put: &ddbTypes.Put{ + TableName: aws.String(threadsTable), + Item: map[string]ddbTypes.AttributeValue{ + "threadId": &ddbTypes.AttributeValueMemberS{Value: "t4"}, + "title": &ddbTypes.AttributeValueMemberS{Value: "title4"}, + "createdAt": &ddbTypes.AttributeValueMemberS{Value: "2026-01-04T00:00:00Z"}, + "status": &ddbTypes.AttributeValueMemberS{Value: "pending_txn"}, + "accessToken": &ddbTypes.AttributeValueMemberS{Value: ""}, + }, }, }, }, - }, - }) - require.NoError(t, err) - queryPendingAfterTransact, err := client.Query(ctx, &dynamodb.QueryInput{ - TableName: aws.String(threadsTable), - IndexName: aws.String("statusIndex"), - KeyConditionExpression: aws.String("#status = :status"), 
- ExpressionAttributeNames: map[string]string{ - "#status": "status", - }, - ExpressionAttributeValues: map[string]ddbTypes.AttributeValue{ - ":status": &ddbTypes.AttributeValueMemberS{Value: "pending"}, - }, - ScanIndexForward: aws.Bool(false), + }) + require.NoError(t, txErr) + + queryPendingAfterTransact := queryThreadsByStatus(t, "pending_txn", false) + require.Len(t, queryPendingAfterTransact.Items, 1) }) - require.NoError(t, err) - require.Len(t, queryPendingAfterTransact.Items, 2) - _, err = client.DeleteTable(ctx, &dynamodb.DeleteTableInput{TableName: aws.String(messagesTable)}) - require.NoError(t, err) + t.Run("DeleteTable", func(t *testing.T) { + _, deleteErr := client.DeleteTable(ctx, &dynamodb.DeleteTableInput{TableName: aws.String(messagesTable)}) + require.NoError(t, deleteErr) + }) } func TestDynamoDB_UpdateItem_ConditionOnMissingItemFails(t *testing.T) { diff --git a/adapter/dynamodb_test.go b/adapter/dynamodb_test.go index f4ac8e90..1b1cf9cc 100644 --- a/adapter/dynamodb_test.go +++ b/adapter/dynamodb_test.go @@ -210,6 +210,60 @@ func TestDynamoDB_UpdateItem_Condition(t *testing.T) { assert.Error(t, err) } +func TestDynamoDB_UpdateItem_RejectsExpressionAttributeNameInjection(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion("us-west-2"), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("dummy", "dummy", "")), + ) + require.NoError(t, err) + + client := dynamodb.NewFromConfig(cfg, func(o *dynamodb.Options) { + o.BaseEndpoint = aws.String("http://" + nodes[0].dynamoAddress) + }) + createSimpleKeyTable(t, context.Background(), client) + + _, err = client.PutItem(context.Background(), &dynamodb.PutItemInput{ + TableName: aws.String("t"), + Item: map[string]types.AttributeValue{ + "key": &types.AttributeValueMemberS{Value: "test"}, + "value": &types.AttributeValueMemberS{Value: "v1"}, + }, + }) + require.NoError(t, err) + + _, err = client.UpdateItem(context.Background(), &dynamodb.UpdateItemInput{ + TableName: aws.String("t"), + Key: map[string]types.AttributeValue{ + "key": &types.AttributeValueMemberS{Value: "test"}, + }, + UpdateExpression: aws.String("SET #v = :val"), + ConditionExpression: aws.String("attribute_exists(#guard)"), + ExpressionAttributeNames: map[string]string{ + "#v": "value", + "#guard": "missing) OR attribute_exists(key", + }, + ExpressionAttributeValues: map[string]types.AttributeValue{ + ":val": &types.AttributeValueMemberS{Value: "v2"}, + }, + }) + require.ErrorContains(t, err, "invalid expression attribute name") + + out, err := client.GetItem(context.Background(), &dynamodb.GetItemInput{ + TableName: aws.String("t"), + Key: map[string]types.AttributeValue{ + "key": &types.AttributeValueMemberS{Value: "test"}, + }, + }) + require.NoError(t, err) + valueAttr, ok := out.Item["value"].(*types.AttributeValueMemberS) + require.True(t, ok) + require.Equal(t, "v1", valueAttr.Value) +} + func TestDynamoDB_TransactWriteItems_Concurrent(t *testing.T) { t.Parallel() nodes, _, _ := createNode(t, 1) @@ -418,6 +472,70 @@ func TestEvalConditionExpression_LogicalKeywordWithoutSpaces(t *testing.T) { require.True(t, ok) } +func TestReplaceNames_ValidatesExpressionAttributeNames(t *testing.T) { + t.Parallel() + + t.Run("invalid placeholder", func(t *testing.T) { + t.Parallel() + _, err := replaceNames("attribute_exists(#name)", map[string]string{ + "name": "value", + }) + require.ErrorContains(t, err, `invalid expression 
attribute placeholder "name"`) + }) + + t.Run("invalid placeholder character", func(t *testing.T) { + t.Parallel() + _, err := replaceNames("attribute_exists(#na-me)", map[string]string{ + "#na-me": "value", + }) + require.ErrorContains(t, err, `invalid expression attribute placeholder "#na-me"`) + }) + + t.Run("invalid attribute name", func(t *testing.T) { + t.Parallel() + _, err := replaceNames("attribute_exists(#name)", map[string]string{ + "#name": "value OR attribute_exists(key)", + }) + require.ErrorContains(t, err, `invalid expression attribute name "value OR attribute_exists(key)"`) + }) + + t.Run("valid replacement", func(t *testing.T) { + t.Parallel() + expr, err := replaceNames("attribute_exists(#name)", map[string]string{ + "#name": "value_1", + }) + require.NoError(t, err) + require.Equal(t, "attribute_exists(value_1)", expr) + }) + + t.Run("valid replacement with dot and hyphen", func(t *testing.T) { + t.Parallel() + expr, err := replaceNames("#left = :l AND #right = :r", map[string]string{ + "#left": "data.field", + "#right": "my-attribute", + }) + require.NoError(t, err) + require.Equal(t, "data.field = :l AND my-attribute = :r", expr) + }) +} + +func TestValidateConditionOnItem_AttributeNameContainsLogicalKeywordSubstring(t *testing.T) { + t.Parallel() + + item := map[string]attributeValue{ + "a-OR-b": newStringAttributeValue("ok"), + } + values := map[string]attributeValue{ + ":v": newStringAttributeValue("ok"), + } + names := map[string]string{ + "#k": "a-OR-b", + } + + err := validateConditionOnItem("#k = :v", names, values, item) + require.NoError(t, err) +} + func TestQueryExclusiveStartKey_AppliesAfterOrdering(t *testing.T) { schema := &dynamoTableSchema{ PrimaryKey: dynamoKeySchema{ diff --git a/cmd/server/demo.go b/cmd/server/demo.go index de1ad516..439a076c 100644 --- a/cmd/server/demo.go +++ b/cmd/server/demo.go @@ -47,6 +47,7 @@ const ( joinRetries = 20 joinWait = 3 * time.Second joinRetryInterval = 1 * time.Second + joinRPCTimeout = 3 * time.Second ) func init() { @@ -188,31 +189,67 @@ func joinCluster(ctx context.Context, nodes []config) error { client := raftadminpb.NewRaftAdminClient(conn) for _, n := range nodes[1:] { - var joined bool - for i := 0; i < joinRetries; i++ { - slog.Info("Attempting to join node", "id", n.raftID, "address", n.address) - _, err := client.AddVoter(ctx, &raftadminpb.AddVoterRequest{ - Id: n.raftID, - Address: n.address, - PreviousIndex: 0, - }) - if err == nil { - slog.Info("Successfully joined node", "id", n.raftID) - joined = true - break - } - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return joinClusterWaitError(errors.WithStack(err)) + if err := joinNodeWithRetry(ctx, client, n); err != nil { + return err + } + } + return nil +} + +func joinNodeWithRetry(ctx context.Context, client raftadminpb.RaftAdminClient, n config) error { + for i := 0; i < joinRetries; i++ { + if err := tryJoinNode(ctx, client, n); err == nil { + return nil + } else { + if ctx.Err() != nil { + // Retry loop should stop immediately once the parent context is canceled. 
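+				// joinRetryCancelResult wraps ctx.Err() via joinClusterWaitError so the caller
+				// sees the cancellation rather than the last AddVoter failure.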
+ return joinRetryCancelResult(ctx) } slog.Warn("Failed to join node, retrying...", "id", n.raftID, "err", err) - if err := waitForJoinRetry(ctx, joinRetryInterval); err != nil { - return joinClusterWaitError(err) - } } - if !joined { - return fmt.Errorf("failed to join node %s after retries", n.raftID) + if i == joinRetries-1 { + break + } + if err := waitForJoinRetry(ctx, joinRetryInterval); err != nil { + return joinRetryCancelResult(ctx) } } + if ctx.Err() != nil { + return joinRetryCancelResult(ctx) + } + return fmt.Errorf("failed to join node %s after retries", n.raftID) +} + +func joinRetryCancelResult(ctx context.Context) error { + if ctx == nil || ctx.Err() == nil { + return nil + } + return joinClusterWaitError(errors.WithStack(ctx.Err())) +} + +func tryJoinNode(ctx context.Context, client raftadminpb.RaftAdminClient, n config) error { + slog.Info("Attempting to join node", "id", n.raftID, "address", n.address) + addCtx, cancelAdd := context.WithTimeout(ctx, joinRPCTimeout) + defer cancelAdd() + future, err := client.AddVoter(addCtx, &raftadminpb.AddVoterRequest{ + Id: n.raftID, + Address: n.address, + PreviousIndex: 0, + }) + if err != nil { + return errors.WithStack(err) + } + + awaitCtx, cancelAwait := context.WithTimeout(ctx, joinRPCTimeout) + defer cancelAwait() + await, err := client.Await(awaitCtx, future) + if err != nil { + return errors.WithStack(err) + } + if await.GetError() != "" { + return errors.New(await.GetError()) + } + slog.Info("Successfully joined node", "id", n.raftID) return nil } diff --git a/docs/docker_multinode_manual_run.md b/docs/docker_multinode_manual_run.md new file mode 100644 index 00000000..2cc50660 --- /dev/null +++ b/docs/docker_multinode_manual_run.md @@ -0,0 +1,235 @@ +# Run a 4-5 Node Elastickv Cluster on Multiple VMs (Docker `run`, No Docker Compose) + +This guide explains how to run Elastickv as a Raft cluster across multiple VMs using only `docker run`. + +- No Docker Compose +- Not a single-VM deployment +- One Elastickv node per VM + +## English Summary + +Use `docker run` on 4 or 5 separate VMs and bootstrap the cluster from a fixed voter list via `--raftBootstrapMembers`. +Start all initial nodes (`n1` to `n4`/`n5`) with both `--raftBootstrapMembers` and `--raftBootstrap`, then wait for quorum before sending write traffic. +Use private VM IPs for all bind addresses and lock down network access because gRPC, Redis, and DynamoDB-compatible endpoints are unauthenticated by default. + +## Target Topology + +- 1 node = 1 VM +- Total nodes: 4 or 5 +- VMs must be able to reach each other over TCP (at minimum `50051/tcp`) +- Docker Engine installed on every VM + +Example (5 nodes): + +| Node ID | VM | IP | +| --- | --- | --- | +| n1 | vm1 | 10.0.0.11 | +| n2 | vm2 | 10.0.0.12 | +| n3 | vm3 | 10.0.0.13 | +| n4 | vm4 | 10.0.0.14 | +| n5 | vm5 | 10.0.0.15 | + +For a 4-node cluster, remove `n5`. + +## Fault Tolerance + +| Nodes | Quorum | Tolerated Simultaneous Failures | +| --- | --- | --- | +| 4 | 3 | 1 | +| 5 | 3 | 2 (recommended) | + +If fault tolerance is a priority, use 5 nodes. + +## Security Requirements (Stronger Defaults) + +These examples prioritize operational clarity, not hardening. Treat them as baseline only. + +- Mandatory isolation: do not expose any Elastickv cluster port to the public Internet or to shared/multi-tenant networks. +- Restrict inbound sources with firewall/security groups/NACLs so only Elastickv nodes and tightly controlled admin hosts can connect. 
+- Endpoints are unauthenticated and plaintext by default: gRPC (including `RaftAdmin`), Redis API, and DynamoDB-compatible API. +- Any principal with network access to these ports can read/modify data and reconfigure cluster membership. +- Run Elastickv on dedicated private subnets/VPCs for the cluster, without direct Internet routing and without cross-tenant sharing. +- Enforce network segmentation and, where possible, add TLS/mTLS and authentication via Elastickv features or a trusted terminating proxy/service mesh. +- Do not bind advertised service addresses to `0.0.0.0` or `localhost` in cluster flags; use each VM's routable private IP (for example, `10.0.0.11`). + +## 1) Pull the Image on All VMs + +```bash +docker pull ghcr.io/bootjp/elastickv:latest +``` + +## 2) Prepare Data Directory on All VMs + +```bash +sudo mkdir -p /var/lib/elastickv +``` + +## 3) Define Shared Cluster Variables + +`RAFT_TO_REDIS_MAP` is only for Redis leader routing. +It is not used for Raft transport membership. + +Raft node-to-node communication uses `--address` (gRPC transport). + +Shared `RAFT_TO_REDIS_MAP` (5-node example): + +```bash +RAFT_TO_REDIS_MAP="10.0.0.11:50051=10.0.0.11:6379,10.0.0.12:50051=10.0.0.12:6379,10.0.0.13:50051=10.0.0.13:6379,10.0.0.14:50051=10.0.0.14:6379,10.0.0.15:50051=10.0.0.15:6379" +``` + +Shared fixed bootstrap voters (5-node example): + +```bash +RAFT_BOOTSTRAP_MEMBERS="n1=10.0.0.11:50051,n2=10.0.0.12:50051,n3=10.0.0.13:50051,n4=10.0.0.14:50051,n5=10.0.0.15:50051" +``` + +For a 4-node cluster, remove the `n5` entry from both variables. + +## 4) Start Nodes with `docker run` + +This guide uses `--network host` and explicit VM private IPs. + +Binding guidance: + +- Set `--address`, `--redisAddress`, and `--dynamoAddress` to the VM private IP. +- Do not use `0.0.0.0` as the advertised address. +- Do not use `localhost` for cluster communication. + +`n1` (initial voter): + +```bash +docker rm -f elastickv 2>/dev/null || true + +docker run -d \ + --name elastickv \ + --restart unless-stopped \ + --network host \ + -v /var/lib/elastickv:/var/lib/elastickv \ + ghcr.io/bootjp/elastickv:latest /app \ + --address "10.0.0.11:50051" \ + --redisAddress "10.0.0.11:6379" \ + --dynamoAddress "10.0.0.11:8000" \ + --raftId "n1" \ + --raftDataDir "/var/lib/elastickv" \ + --raftRedisMap "${RAFT_TO_REDIS_MAP}" \ + --raftBootstrapMembers "${RAFT_BOOTSTRAP_MEMBERS}" \ + --raftBootstrap +``` + +`n2` (initial voter): + +```bash +docker rm -f elastickv 2>/dev/null || true + +docker run -d \ + --name elastickv \ + --restart unless-stopped \ + --network host \ + -v /var/lib/elastickv:/var/lib/elastickv \ + ghcr.io/bootjp/elastickv:latest /app \ + --address "10.0.0.12:50051" \ + --redisAddress "10.0.0.12:6379" \ + --dynamoAddress "10.0.0.12:8000" \ + --raftId "n2" \ + --raftDataDir "/var/lib/elastickv" \ + --raftRedisMap "${RAFT_TO_REDIS_MAP}" \ + --raftBootstrapMembers "${RAFT_BOOTSTRAP_MEMBERS}" \ + --raftBootstrap +``` + +Start `n3` to `n5` the same way by replacing: + +- `--address` +- `--redisAddress` +- `--dynamoAddress` +- `--raftId` + +## Startup Order and Quorum Caution + +Recommended startup sequence: + +1. Start `n1`. +2. Immediately start enough peers to form quorum (`n2` and `n3` at minimum). +3. Start remaining nodes (`n4`, `n5`). +4. Send write traffic only after quorum is confirmed and leader election is stable. + +Important behavior: + +- In a 5-node cluster, at least 3 voters must be up for writes to commit. +- In a 4-node cluster, at least 3 voters must be up for writes to commit. 
+- If fewer than quorum nodes are running, leader election/commit may stall and writes may fail or hang. + +## 5) Verify Cluster Convergence + +Run on `n1`: + +```bash +# Pin to an immutable digest. +GRPCURL_IMG="fullstorydev/grpcurl@sha256:085e183ca334eb4e81ca81ee12cbb2b2737505d1d77f5e33dabc5d066593d998" + +# Optional: re-check the current digest for v1.9.3 before use. +# TOKEN=$(curl -fsSL 'https://auth.docker.io/token?service=registry.docker.io&scope=repository:fullstorydev/grpcurl:pull' | jq -r .token) +# curl -fsSI -H "Authorization: Bearer ${TOKEN}" -H 'Accept: application/vnd.docker.distribution.manifest.list.v2+json' \ +# https://registry-1.docker.io/v2/fullstorydev/grpcurl/manifests/v1.9.3 | tr -d '\r' | grep -i docker-content-digest + +# Wait for every node gRPC endpoint +for ip in 10.0.0.11 10.0.0.12 10.0.0.13 10.0.0.14 10.0.0.15; do + until docker run --rm --network host "${GRPCURL_IMG}" \ + -plaintext "${ip}:50051" list >/dev/null 2>&1; do + sleep 1 + done +done + +# Check Raft members +docker run --rm --network host "${GRPCURL_IMG}" \ + -plaintext -d '{}' 10.0.0.11:50051 RaftAdmin/GetConfiguration +``` + +For a 4-node cluster, remove `10.0.0.15` from the loop. + +## 6) Validate Read/Write + +Check leader: + +```bash +docker run --rm --network host "${GRPCURL_IMG}" \ + -plaintext -d '{}' 10.0.0.11:50051 RaftAdmin/Leader +``` + +Write/read through Redis endpoints: + +```bash +redis-cli -h 10.0.0.11 -p 6379 SET health ok +redis-cli -h 10.0.0.12 -p 6379 GET health +``` + +## 7) Fault Tolerance Drill + +With 5 nodes, the cluster should continue serving with up to 2 node failures. + +Example: stop `n4` and `n5`: + +```bash +docker stop elastickv +``` + +Then verify writes/reads still succeed from remaining nodes: + +```bash +redis-cli -h 10.0.0.11 -p 6379 SET survive yes +redis-cli -h 10.0.0.12 -p 6379 GET survive +``` + +## Stop and Cleanup + +Stop/remove on each VM: + +```bash +docker rm -f elastickv 2>/dev/null || true +``` + +Remove persisted data (if required): + +```bash +sudo rm -rf /var/lib/elastickv/* +``` diff --git a/main.go b/main.go index fd77c5fe..d7366e42 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "log" "net" + "strings" "sync" "time" @@ -29,15 +30,16 @@ const ( ) var ( - myAddr = flag.String("address", "localhost:50051", "TCP host+port for this node") - redisAddr = flag.String("redisAddress", "localhost:6379", "TCP host+port for redis") - dynamoAddr = flag.String("dynamoAddress", "localhost:8000", "TCP host+port for DynamoDB-compatible API") - raftId = flag.String("raftId", "", "Node id used by Raft") - raftDir = flag.String("raftDataDir", "data/", "Raft data dir") - raftBootstrap = flag.Bool("raftBootstrap", false, "Whether to bootstrap the Raft cluster") - raftGroups = flag.String("raftGroups", "", "Comma-separated raft groups (groupID=host:port,...)") - shardRanges = flag.String("shardRanges", "", "Comma-separated shard ranges (start:end=groupID,...)") - raftRedisMap = flag.String("raftRedisMap", "", "Map of Raft address to Redis address (raftAddr=redisAddr,...)") + myAddr = flag.String("address", "localhost:50051", "TCP host+port for this node") + redisAddr = flag.String("redisAddress", "localhost:6379", "TCP host+port for redis") + dynamoAddr = flag.String("dynamoAddress", "localhost:8000", "TCP host+port for DynamoDB-compatible API") + raftId = flag.String("raftId", "", "Node id used by Raft") + raftDir = flag.String("raftDataDir", "data/", "Raft data dir") + raftBootstrap = flag.Bool("raftBootstrap", false, "Whether to bootstrap the 
Raft cluster") + raftBootstrapMembers = flag.String("raftBootstrapMembers", "", "Comma-separated bootstrap raft members (raftID=host:port,...)") + raftGroups = flag.String("raftGroups", "", "Comma-separated raft groups (groupID=host:port,...)") + shardRanges = flag.String("shardRanges", "", "Comma-separated shard ranges (start:end=groupID,...)") + raftRedisMap = flag.String("raftRedisMap", "", "Map of Raft address to Redis address (raftAddr=redisAddr,...)") ) func main() { @@ -59,8 +61,13 @@ func run() error { if err != nil { return err } + bootstrapServers, err := resolveBootstrapServers(*raftId, cfg.groups, *raftBootstrapMembers) + if err != nil { + return err + } + bootstrap := *raftBootstrap || len(bootstrapServers) > 0 - runtimes, shardGroups, err := buildShardGroups(*raftId, *raftDir, cfg.groups, cfg.multi, *raftBootstrap) + runtimes, shardGroups, err := buildShardGroups(*raftId, *raftDir, cfg.groups, cfg.multi, bootstrap, bootstrapServers) if err != nil { return err } @@ -90,14 +97,21 @@ func run() error { adapter.WithDistributionCoordinator(coordinate), ) - if err := startRaftServers(runCtx, &lc, eg, runtimes, shardStore, coordinate, distServer); err != nil { - return waitErrgroupAfterStartupFailure(cancel, eg, err) - } - if err := startRedisServer(runCtx, &lc, eg, *redisAddr, shardStore, coordinate, cfg.leaderRedis); err != nil { - return waitErrgroupAfterStartupFailure(cancel, eg, err) + runner := runtimeServerRunner{ + ctx: runCtx, + lc: &lc, + eg: eg, + cancel: cancel, + runtimes: runtimes, + shardStore: shardStore, + coordinate: coordinate, + distServer: distServer, + redisAddress: *redisAddr, + leaderRedis: cfg.leaderRedis, + dynamoAddress: *dynamoAddr, } - if err := startDynamoDBServer(runCtx, &lc, eg, *dynamoAddr, shardStore, coordinate); err != nil { - return waitErrgroupAfterStartupFailure(cancel, eg, err) + if err := runner.start(); err != nil { + return err } if err := eg.Wait(); err != nil { @@ -167,13 +181,49 @@ func buildLeaderRedis(groups []groupSpec, redisAddr string, raftRedisMap string) return leaderRedis, nil } -func buildShardGroups(raftID string, raftDir string, groups []groupSpec, multi bool, bootstrap bool) ([]*raftGroupRuntime, map[uint64]*kv.ShardGroup, error) { +var ( + ErrBootstrapMembersRequireSingleGroup = errors.New("flag --raftBootstrapMembers requires exactly one raft group") + ErrBootstrapMembersMissingLocalNode = errors.New("flag --raftBootstrapMembers must include local --raftId") + ErrBootstrapMembersLocalAddrMismatch = errors.New("flag --raftBootstrapMembers local address must match local raft group address") + ErrNoBootstrapMembersConfigured = errors.New("no bootstrap members configured") +) + +func resolveBootstrapServers(raftID string, groups []groupSpec, bootstrapMembers string) ([]raft.Server, error) { + if strings.TrimSpace(bootstrapMembers) == "" { + return nil, nil + } + if len(groups) != 1 { + return nil, errors.WithStack(ErrBootstrapMembersRequireSingleGroup) + } + + servers, err := parseRaftBootstrapMembers(bootstrapMembers) + if err != nil { + return nil, errors.Wrap(err, "failed to parse raft bootstrap members") + } + if len(servers) == 0 { + return nil, errors.WithStack(ErrNoBootstrapMembersConfigured) + } + + localAddr := groups[0].address + for _, s := range servers { + if string(s.ID) != raftID { + continue + } + if string(s.Address) != localAddr { + return nil, errors.Wrapf(ErrBootstrapMembersLocalAddrMismatch, "expected %q got %q", localAddr, s.Address) + } + return servers, nil + } + return nil, 
errors.Wrapf(ErrBootstrapMembersMissingLocalNode, "raftId=%q", raftID) +} + +func buildShardGroups(raftID string, raftDir string, groups []groupSpec, multi bool, bootstrap bool, bootstrapServers []raft.Server) ([]*raftGroupRuntime, map[uint64]*kv.ShardGroup, error) { runtimes := make([]*raftGroupRuntime, 0, len(groups)) shardGroups := make(map[uint64]*kv.ShardGroup, len(groups)) for _, g := range groups { st := store.NewMVCCStore() fsm := kv.NewKvFSM(st) - r, tm, closeStores, err := newRaftGroup(raftID, g, raftDir, multi, bootstrap, fsm) + r, tm, closeStores, err := newRaftGroup(raftID, g, raftDir, multi, bootstrap, bootstrapServers, fsm) if err != nil { for _, rt := range runtimes { rt.Close() @@ -369,3 +419,30 @@ func waitErrgroupAfterStartupFailure(cancel context.CancelFunc, eg *errgroup.Gro } return startupErr } + +type runtimeServerRunner struct { + ctx context.Context + lc *net.ListenConfig + eg *errgroup.Group + cancel context.CancelFunc + runtimes []*raftGroupRuntime + shardStore *kv.ShardStore + coordinate kv.Coordinator + distServer *adapter.DistributionServer + redisAddress string + leaderRedis map[raft.ServerAddress]string + dynamoAddress string +} + +func (r runtimeServerRunner) start() error { + if err := startRaftServers(r.ctx, r.lc, r.eg, r.runtimes, r.shardStore, r.coordinate, r.distServer); err != nil { + return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) + } + if err := startRedisServer(r.ctx, r.lc, r.eg, r.redisAddress, r.shardStore, r.coordinate, r.leaderRedis); err != nil { + return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) + } + if err := startDynamoDBServer(r.ctx, r.lc, r.eg, r.dynamoAddress, r.shardStore, r.coordinate); err != nil { + return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) + } + return nil +} diff --git a/main_bootstrap_e2e_test.go b/main_bootstrap_e2e_test.go new file mode 100644 index 00000000..4c86861f --- /dev/null +++ b/main_bootstrap_e2e_test.go @@ -0,0 +1,584 @@ +package main + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/Jille/raft-grpc-leader-rpc/leaderhealth" + "github.com/Jille/raftadmin" + "github.com/bootjp/elastickv/adapter" + "github.com/bootjp/elastickv/kv" + pb "github.com/bootjp/elastickv/proto" + "github.com/hashicorp/raft" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/reflection" +) + +type bootstrapE2EEndpoint struct { + id string + raftAddr string + redisAddr string + dynamoAddr string +} + +type bootstrapE2EListeners struct { + grpc net.Listener + redis net.Listener + dynamo net.Listener +} + +type bootstrapE2ENode struct { + id string + runtimes []*raftGroupRuntime + + shardStore *kv.ShardStore + cancel context.CancelFunc + eg *errgroup.Group +} + +func (n *bootstrapE2ENode) raft() *raft.Raft { + if n == nil || len(n.runtimes) == 0 { + return nil + } + return n.runtimes[0].raft +} + +func (n *bootstrapE2ENode) close() error { + if n == nil { + return nil + } + if n.cancel != nil { + n.cancel() + } + + var waitErr error + if n.eg != nil { + waitErr = n.eg.Wait() + } + + if n.shardStore != nil { + _ = n.shardStore.Close() + n.shardStore = nil + } + for _, rt := range n.runtimes { + if rt != nil { + rt.Close() + } + } + n.runtimes = nil + return waitErr +} + +func TestRaftBootstrapMembers_E2E_FixedClusterWithoutAddVoter(t *testing.T) { + const ( + startupAttempts = 5 + 
nodeCount = 4 + waitTimeout = 20 * time.Second + waitInterval = 100 * time.Millisecond + rpcTimeout = 2 * time.Second + ) + + baseDir := t.TempDir() + endpoints, nodes := startBootstrapE2ECluster(t, baseDir, nodeCount, startupAttempts) + t.Cleanup(func() { closeBootstrapE2ENodes(t, nodes) }) + + expected := bootstrapExpectedServers(endpoints) + waitForBootstrapClusterConfig(t, nodes, expected, waitTimeout, waitInterval) + leaderIdx := waitForSingleLeader(t, nodes, waitTimeout, waitInterval) + + clients, conns := rawKVClients(t, endpoints) + t.Cleanup(func() { + for _, conn := range conns { + _ = conn.Close() + } + }) + + writerIdx := (leaderIdx + 1) % len(clients) + key := []byte("bootstrap-members-e2e-key") + value := []byte("bootstrap-members-e2e-value") + + require.NoError(t, rawPutWithTimeout(clients[writerIdx], key, value, rpcTimeout)) + + for i := range clients { + client := clients[i] + require.Eventually(t, func() bool { + resp, getErr := rawGetWithTimeout(client, key, rpcTimeout) + if getErr != nil { + return false + } + return resp.Exists && bytes.Equal(resp.Value, value) + }, waitTimeout, waitInterval) + } +} + +func startBootstrapE2ECluster( + t *testing.T, + baseDir string, + nodeCount int, + startupAttempts int, +) ([]bootstrapE2EEndpoint, []*bootstrapE2ENode) { + t.Helper() + + var ( + lastErr error + nodes []*bootstrapE2ENode + ) + + for attempt := 0; attempt < startupAttempts; attempt++ { + endpoints, listeners := allocateBootstrapE2EEndpoints(t, nodeCount) + attemptDir := filepath.Join(baseDir, fmt.Sprintf("attempt-%d", attempt)) + started, err := tryStartBootstrapE2ECluster(attemptDir, endpoints, listeners) + if err == nil { + return endpoints, started + } + closeBootstrapE2ENodesIgnoreError(started) + closeBootstrapE2EListeners(listeners) + lastErr = err + if !isAddressInUseError(err) { + break + } + nodes = nil + } + + require.NoError(t, lastErr) + return nil, nodes +} + +func allocateBootstrapE2EEndpoints(t *testing.T, nodeCount int) ([]bootstrapE2EEndpoint, []bootstrapE2EListeners) { + t.Helper() + + var lc net.ListenConfig + endpoints := make([]bootstrapE2EEndpoint, 0, nodeCount) + listeners := make([]bootstrapE2EListeners, 0, nodeCount) + for i := 0; i < nodeCount; i++ { + grpcL, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + redisL, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + dynamoL, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + + endpoints = append(endpoints, bootstrapE2EEndpoint{ + id: fmt.Sprintf("n%d", i+1), + raftAddr: grpcL.Addr().String(), + redisAddr: redisL.Addr().String(), + dynamoAddr: dynamoL.Addr().String(), + }) + listeners = append(listeners, bootstrapE2EListeners{ + grpc: grpcL, + redis: redisL, + dynamo: dynamoL, + }) + } + return endpoints, listeners +} + +func tryStartBootstrapE2ECluster(baseDir string, endpoints []bootstrapE2EEndpoint, listeners []bootstrapE2EListeners) ([]*bootstrapE2ENode, error) { + bootstrapMembers := bootstrapMembersArg(endpoints) + nodes := make([]*bootstrapE2ENode, 0, len(endpoints)) + for i := range endpoints { + node, err := startBootstrapE2ENode(baseDir, endpoints[i], listeners[i], true, bootstrapMembers) + if err != nil { + return nodes, err + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func closeBootstrapE2EListeners(listeners []bootstrapE2EListeners) { + for _, lis := range listeners { + if lis.grpc != nil { + _ = lis.grpc.Close() + } + if lis.redis != nil { + _ = 
lis.redis.Close() + } + if lis.dynamo != nil { + _ = lis.dynamo.Close() + } + } +} + +func closeBootstrapE2ENodesIgnoreError(nodes []*bootstrapE2ENode) { + for _, n := range nodes { + if n == nil { + continue + } + _ = n.close() + } +} + +func isAddressInUseError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, syscall.EADDRINUSE) { + return true + } + var opErr *net.OpError + if errors.As(err, &opErr) && errors.Is(opErr.Err, syscall.EADDRINUSE) { + return true + } + var sysErr *os.SyscallError + if errors.As(err, &sysErr) && errors.Is(sysErr.Err, syscall.EADDRINUSE) { + return true + } + return false +} + +func bootstrapMembersArg(endpoints []bootstrapE2EEndpoint) string { + parts := make([]string, 0, len(endpoints)) + for _, ep := range endpoints { + parts = append(parts, fmt.Sprintf("%s=%s", ep.id, ep.raftAddr)) + } + return strings.Join(parts, ",") +} + +func bootstrapExpectedServers(endpoints []bootstrapE2EEndpoint) []raft.Server { + servers := make([]raft.Server, 0, len(endpoints)) + for _, ep := range endpoints { + servers = append(servers, raft.Server{ + Suffrage: raft.Voter, + ID: raft.ServerID(ep.id), + Address: raft.ServerAddress(ep.raftAddr), + }) + } + return servers +} + +func startBootstrapE2ENode( + baseDir string, + ep bootstrapE2EEndpoint, + listeners bootstrapE2EListeners, + bootstrap bool, + bootstrapMembers string, +) (*bootstrapE2ENode, error) { + cfg, err := parseRuntimeConfig(ep.raftAddr, ep.redisAddr, "", "", "") + if err != nil { + return nil, err + } + + bootstrapServers, err := resolveBootstrapServers(ep.id, cfg.groups, bootstrapMembers) + if err != nil { + return nil, err + } + bootstrap = bootstrap || len(bootstrapServers) > 0 + + runtimes, shardGroups, err := buildShardGroups(ep.id, baseDir, cfg.groups, cfg.multi, bootstrap, bootstrapServers) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + eg, runCtx := errgroup.WithContext(ctx) + + clock := kv.NewHLC() + shardStore := kv.NewShardStore(cfg.engine, shardGroups) + coordinate := kv.NewShardedCoordinator(cfg.engine, shardGroups, cfg.defaultGroup, clock, shardStore) + distCatalog, err := setupDistributionCatalog(runCtx, runtimes, cfg.engine) + if err != nil { + cancel() + _ = shardStore.Close() + for _, rt := range runtimes { + rt.Close() + } + return nil, err + } + + eg.Go(func() error { + return runDistributionCatalogWatcher(runCtx, distCatalog, cfg.engine) + }) + + distServer := adapter.NewDistributionServer( + cfg.engine, + distCatalog, + adapter.WithDistributionCoordinator(coordinate), + ) + + err = startRuntimeServersWithBoundListeners( + runCtx, + eg, + cancel, + runtimes, + shardStore, + coordinate, + distServer, + cfg.leaderRedis, + listeners, + ) + if err != nil { + _ = shardStore.Close() + for _, rt := range runtimes { + rt.Close() + } + return nil, err + } + + return &bootstrapE2ENode{ + id: ep.id, + runtimes: runtimes, + shardStore: shardStore, + cancel: cancel, + eg: eg, + }, nil +} + +func startRuntimeServersWithBoundListeners( + ctx context.Context, + eg *errgroup.Group, + cancel context.CancelFunc, + runtimes []*raftGroupRuntime, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, + distServer *adapter.DistributionServer, + leaderRedis map[raft.ServerAddress]string, + listeners bootstrapE2EListeners, +) error { + if len(runtimes) != 1 { + return waitErrgroupAfterStartupFailure(cancel, eg, fmt.Errorf("expected exactly one runtime, got %d", len(runtimes))) + } + rt := runtimes[0] + + if err := 
startBoundGRPCServer(ctx, eg, rt, shardStore, coordinate, distServer, listeners.grpc); err != nil { + return waitErrgroupAfterStartupFailure(cancel, eg, err) + } + if err := startBoundRedisServer(ctx, eg, listeners.redis, shardStore, coordinate, leaderRedis); err != nil { + return waitErrgroupAfterStartupFailure(cancel, eg, err) + } + if err := startBoundDynamoDBServer(ctx, eg, listeners.dynamo, shardStore, coordinate); err != nil { + return waitErrgroupAfterStartupFailure(cancel, eg, err) + } + return nil +} + +func startBoundGRPCServer( + ctx context.Context, + eg *errgroup.Group, + rt *raftGroupRuntime, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, + distServer *adapter.DistributionServer, + listener net.Listener, +) error { + if rt == nil || rt.raft == nil || rt.tm == nil { + return fmt.Errorf("raft runtime is not ready") + } + if listener == nil { + return fmt.Errorf("grpc listener is required") + } + + gs := grpc.NewServer() + trx := kv.NewTransaction(rt.raft) + grpcSvc := adapter.NewGRPCServer(shardStore, coordinate) + pb.RegisterRawKVServer(gs, grpcSvc) + pb.RegisterTransactionalKVServer(gs, grpcSvc) + pb.RegisterInternalServer(gs, adapter.NewInternal(trx, rt.raft, coordinate.Clock())) + pb.RegisterDistributionServer(gs, distServer) + rt.tm.Register(gs) + leaderhealth.Setup(rt.raft, gs, []string{"RawKV"}) + raftadmin.Register(gs, rt.raft) + reflection.Register(gs) + + srv := gs + lis := listener + grpcService := grpcSvc + eg.Go(func() error { + var closeOnce sync.Once + closeService := func() { + closeOnce.Do(func() { _ = grpcService.Close() }) + } + stop := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + srv.GracefulStop() + _ = lis.Close() + closeService() + case <-stop: + } + }() + err := srv.Serve(lis) + close(stop) + closeService() + if errors.Is(err, grpc.ErrServerStopped) || errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + return nil +} + +func startBoundRedisServer( + ctx context.Context, + eg *errgroup.Group, + listener net.Listener, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, + leaderRedis map[raft.ServerAddress]string, +) error { + if listener == nil { + return fmt.Errorf("redis listener is required") + } + redisServer := adapter.NewRedisServer(listener, shardStore, coordinate, leaderRedis) + eg.Go(func() error { + defer redisServer.Stop() + stop := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + redisServer.Stop() + case <-stop: + } + }() + err := redisServer.Run() + close(stop) + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + return nil +} + +func startBoundDynamoDBServer( + ctx context.Context, + eg *errgroup.Group, + listener net.Listener, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, +) error { + if listener == nil { + return fmt.Errorf("dynamodb listener is required") + } + dynamoServer := adapter.NewDynamoDBServer(listener, shardStore, coordinate) + eg.Go(func() error { + defer dynamoServer.Stop() + stop := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + dynamoServer.Stop() + case <-stop: + } + }() + err := dynamoServer.Run() + close(stop) + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + }) + return nil +} + +func closeBootstrapE2ENodes(t *testing.T, nodes []*bootstrapE2ENode) { + t.Helper() + for _, n := range nodes { + require.NoError(t, n.close()) + } +} + +func waitForBootstrapClusterConfig(t *testing.T, nodes []*bootstrapE2ENode, expected []raft.Server, waitTimeout, waitInterval time.Duration) { + 
t.Helper() + + require.Eventually(t, func() bool { + for _, n := range nodes { + r := n.raft() + if r == nil { + return false + } + future := r.GetConfiguration() + if err := future.Error(); err != nil { + return false + } + current := future.Configuration().Servers + if len(current) != len(expected) { + return false + } + for _, server := range expected { + if !containsRaftServer(current, server) { + return false + } + } + } + return true + }, waitTimeout, waitInterval) +} + +func waitForSingleLeader(t *testing.T, nodes []*bootstrapE2ENode, waitTimeout, waitInterval time.Duration) int { + t.Helper() + + leaderIdx := -1 + require.Eventually(t, func() bool { + idx := -1 + leaders := 0 + for i, n := range nodes { + r := n.raft() + if r == nil { + return false + } + if r.State() == raft.Leader { + idx = i + leaders++ + } + } + if leaders != 1 { + return false + } + leaderIdx = idx + return true + }, waitTimeout, waitInterval) + return leaderIdx +} + +func rawKVClients(t *testing.T, endpoints []bootstrapE2EEndpoint) ([]pb.RawKVClient, []*grpc.ClientConn) { + t.Helper() + + clients := make([]pb.RawKVClient, 0, len(endpoints)) + conns := make([]*grpc.ClientConn, 0, len(endpoints)) + for _, ep := range endpoints { + conn, err := grpc.NewClient(ep.raftAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + conns = append(conns, conn) + clients = append(clients, pb.NewRawKVClient(conn)) + } + return clients, conns +} + +func rawPutWithTimeout(client pb.RawKVClient, key []byte, value []byte, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + _, err := client.RawPut(ctx, &pb.RawPutRequest{Key: key, Value: value}) + return err +} + +func rawGetWithTimeout(client pb.RawKVClient, key []byte, timeout time.Duration) (*pb.RawGetResponse, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return client.RawGet(ctx, &pb.RawGetRequest{Key: key}) +} + +func containsRaftServer(servers []raft.Server, expected raft.Server) bool { + for _, s := range servers { + if s.ID == expected.ID && s.Address == expected.Address && s.Suffrage == expected.Suffrage { + return true + } + } + return false +} diff --git a/main_bootstrap_test.go b/main_bootstrap_test.go new file mode 100644 index 00000000..4ca617bb --- /dev/null +++ b/main_bootstrap_test.go @@ -0,0 +1,69 @@ +package main + +import ( + "testing" + + "github.com/hashicorp/raft" + "github.com/stretchr/testify/require" +) + +func TestResolveBootstrapServers(t *testing.T) { + t.Run("members imply fixed bootstrap servers", func(t *testing.T) { + servers, err := resolveBootstrapServers("n1", []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, "n1=10.0.0.11:50051") + require.NoError(t, err) + require.Equal(t, []raft.Server{ + {Suffrage: raft.Voter, ID: "n1", Address: "10.0.0.11:50051"}, + }, servers) + }) + + t.Run("empty members returns nil", func(t *testing.T) { + servers, err := resolveBootstrapServers("n1", []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, "") + require.NoError(t, err) + require.Nil(t, servers) + }) + + t.Run("single group fixed members", func(t *testing.T) { + servers, err := resolveBootstrapServers( + "n1", + []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, + "n1=10.0.0.11:50051,n2=10.0.0.12:50051", + ) + require.NoError(t, err) + require.Equal(t, []raft.Server{ + {Suffrage: raft.Voter, ID: "n1", Address: "10.0.0.11:50051"}, + {Suffrage: raft.Voter, ID: "n2", Address: "10.0.0.12:50051"}, + }, servers) 
+ }) + + t.Run("multiple groups are rejected", func(t *testing.T) { + _, err := resolveBootstrapServers( + "n1", + []groupSpec{{id: 1, address: "10.0.0.11:50051"}, {id: 2, address: "10.0.0.11:50052"}}, + "n1=10.0.0.11:50051,n2=10.0.0.12:50051", + ) + require.ErrorIs(t, err, ErrBootstrapMembersRequireSingleGroup) + }) + + t.Run("missing local member is rejected", func(t *testing.T) { + _, err := resolveBootstrapServers( + "n1", + []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, + "n2=10.0.0.12:50051", + ) + require.ErrorIs(t, err, ErrBootstrapMembersMissingLocalNode) + }) + + t.Run("local address mismatch is rejected", func(t *testing.T) { + _, err := resolveBootstrapServers( + "n1", + []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, + "n1=10.0.0.99:50051,n2=10.0.0.12:50051", + ) + require.ErrorIs(t, err, ErrBootstrapMembersLocalAddrMismatch) + }) + + t.Run("only separators are rejected", func(t *testing.T) { + _, err := resolveBootstrapServers("n1", []groupSpec{{id: 1, address: "10.0.0.11:50051"}}, " , , ") + require.ErrorIs(t, err, ErrNoBootstrapMembersConfigured) + }) +} diff --git a/multiraft_runtime.go b/multiraft_runtime.go index d39b0b7a..325b863f 100644 --- a/multiraft_runtime.go +++ b/multiraft_runtime.go @@ -76,7 +76,7 @@ func groupDataDir(baseDir, raftID string, groupID uint64, multi bool) string { return filepath.Join(baseDir, raftID, fmt.Sprintf("group-%d", groupID)) } -func newRaftGroup(raftID string, group groupSpec, baseDir string, multi bool, bootstrap bool, fsm raft.FSM) (*raft.Raft, *transport.Manager, func(), error) { +func newRaftGroup(raftID string, group groupSpec, baseDir string, multi bool, bootstrap bool, bootstrapServers []raft.Server, fsm raft.FSM) (*raft.Raft, *transport.Manager, func(), error) { c := raft.DefaultConfig() c.LocalID = raft.ServerID(raftID) c.HeartbeatTimeout = heartbeatTimeout @@ -127,15 +127,17 @@ func newRaftGroup(raftID string, group groupSpec, baseDir string, multi bool, bo } if bootstrap { - cfg := raft.Configuration{ - Servers: []raft.Server{ + servers := bootstrapServers + if len(servers) == 0 { + servers = []raft.Server{ { Suffrage: raft.Voter, ID: raft.ServerID(raftID), Address: raft.ServerAddress(group.address), }, - }, + } } + cfg := raft.Configuration{Servers: servers} f := r.BootstrapCluster(cfg) if err := f.Error(); err != nil { _ = r.Shutdown().Error() diff --git a/multiraft_runtime_test.go b/multiraft_runtime_test.go index e555cecb..282db3c5 100644 --- a/multiraft_runtime_test.go +++ b/multiraft_runtime_test.go @@ -37,6 +37,7 @@ func TestNewRaftGroupBootstrap(t *testing.T) { baseDir, true, // multi true, // bootstrap + nil, fsm, ) require.NoError(t, err) diff --git a/shard_config.go b/shard_config.go index 5716a64f..d0ecc03d 100644 --- a/shard_config.go +++ b/shard_config.go @@ -27,9 +27,10 @@ var ( ErrNoRaftGroupsConfigured = errors.New("no raft groups configured") ErrNoShardRangesConfigured = errors.New("no shard ranges configured") - ErrInvalidRaftGroupsEntry = errors.New("invalid raftGroups entry") - ErrInvalidShardRangesEntry = errors.New("invalid shardRanges entry") - ErrInvalidRaftRedisMapEntry = errors.New("invalid raftRedisMap entry") + ErrInvalidRaftGroupsEntry = errors.New("invalid raftGroups entry") + ErrInvalidShardRangesEntry = errors.New("invalid shardRanges entry") + ErrInvalidRaftRedisMapEntry = errors.New("invalid raftRedisMap entry") + ErrInvalidRaftBootstrapMembersEntry = errors.New("invalid raftBootstrapMembers entry") ) func parseRaftGroups(raw, defaultAddr string) ([]groupSpec, error) { @@ -138,6 +139,41 
@@ func parseRaftRedisMap(raw string) (map[raft.ServerAddress]string, error) { return out, nil } +func parseRaftBootstrapMembers(raw string) ([]raft.Server, error) { + servers := make([]raft.Server, 0) + if raw == "" { + return servers, nil + } + seen := make(map[raft.ServerID]struct{}) + parts := strings.Split(raw, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + kv := strings.SplitN(part, "=", splitParts) + if len(kv) != splitParts { + return nil, errors.Wrapf(ErrInvalidRaftBootstrapMembersEntry, "%q", part) + } + id := strings.TrimSpace(kv[0]) + addr := strings.TrimSpace(kv[1]) + if id == "" || addr == "" { + return nil, errors.Wrapf(ErrInvalidRaftBootstrapMembersEntry, "%q", part) + } + sid := raft.ServerID(id) + if _, exists := seen[sid]; exists { + return nil, errors.Wrapf(ErrInvalidRaftBootstrapMembersEntry, "duplicate id %q", id) + } + seen[sid] = struct{}{} + servers = append(servers, raft.Server{ + Suffrage: raft.Voter, + ID: sid, + Address: raft.ServerAddress(addr), + }) + } + return servers, nil +} + func defaultGroupID(groups []groupSpec) uint64 { min := uint64(0) for _, g := range groups { diff --git a/shard_config_test.go b/shard_config_test.go index 2f22ffa6..0865cd09 100644 --- a/shard_config_test.go +++ b/shard_config_test.go @@ -115,6 +115,33 @@ func TestParseRaftRedisMap(t *testing.T) { }) } +func TestParseRaftBootstrapMembers(t *testing.T) { + t.Run("parses members", func(t *testing.T) { + members, err := parseRaftBootstrapMembers("n1=10.0.0.11:50051, n2=10.0.0.12:50051") + require.NoError(t, err) + require.Equal(t, []raft.Server{ + {Suffrage: raft.Voter, ID: raft.ServerID("n1"), Address: raft.ServerAddress("10.0.0.11:50051")}, + {Suffrage: raft.Voter, ID: raft.ServerID("n2"), Address: raft.ServerAddress("10.0.0.12:50051")}, + }, members) + }) + + t.Run("trims whitespace", func(t *testing.T) { + members, err := parseRaftBootstrapMembers(" n1 = 10.0.0.11:50051 , n2=10.0.0.12:50051 ") + require.NoError(t, err) + require.Len(t, members, 2) + }) + + t.Run("duplicate id errors", func(t *testing.T) { + _, err := parseRaftBootstrapMembers("n1=a,n1=b") + require.ErrorIs(t, err, ErrInvalidRaftBootstrapMembersEntry) + }) + + t.Run("invalid entry errors", func(t *testing.T) { + _, err := parseRaftBootstrapMembers("n1=a,nope") + require.ErrorIs(t, err, ErrInvalidRaftBootstrapMembersEntry) + }) +} + func TestDefaultGroupID(t *testing.T) { require.Equal(t, uint64(1), defaultGroupID(nil)) require.Equal(t, uint64(2), defaultGroupID([]groupSpec{{id: 3}, {id: 2}}))