diff --git a/.gitignore b/.gitignore index 5ced228..5e8a5dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .DS_Store .credential .*.credential +.env data -config.toml \ No newline at end of file +config.toml diff --git a/common/config.go b/common/config.go index b75e83d..d892c80 100644 --- a/common/config.go +++ b/common/config.go @@ -7,8 +7,9 @@ import ( type Config struct { /* Constants */ - AppVersion string - UserAgent string + AppVersion string + DriveSDKVersion string + UserAgent string /* Login */ FirstLoginCredential *FirstLoginCredentialData @@ -44,8 +45,9 @@ type ReusableCredentialData struct { func NewConfigWithDefaultValues() *Config { return &Config{ - AppVersion: "", - UserAgent: "", + AppVersion: "", + DriveSDKVersion: "", + UserAgent: "", FirstLoginCredential: &FirstLoginCredentialData{ Username: "", @@ -75,6 +77,7 @@ func NewConfigWithDefaultValues() *Config { func NewConfigForIntegrationTests() *Config { appVersion := os.Getenv("PROTON_API_BRIDGE_APP_VERSION") + driveSDKVersion := os.Getenv("PROTON_API_BRIDGE_DRIVE_SDK_VERSION") userAgent := os.Getenv("PROTON_API_BRIDGE_USER_AGENT") username := os.Getenv("PROTON_API_BRIDGE_TEST_USERNAME") @@ -93,8 +96,9 @@ func NewConfigForIntegrationTests() *Config { saltedKeyPass := os.Getenv("PROTON_API_BRIDGE_TEST_SALTEDKEYPASS") return &Config{ - AppVersion: appVersion, - UserAgent: userAgent, + AppVersion: appVersion, + DriveSDKVersion: driveSDKVersion, + UserAgent: userAgent, FirstLoginCredential: &FirstLoginCredentialData{ Username: username, diff --git a/common/proton_manager.go b/common/proton_manager.go index 5fe03fd..850b7b7 100644 --- a/common/proton_manager.go +++ b/common/proton_manager.go @@ -2,12 +2,33 @@ package common import ( "github.com/rclone/go-proton-api" + + "github.com/go-resty/resty/v2" ) +// Applies to all API calls made by the shared proton.Manager. +const defaultAPIRequestRetryCount = 3 + +type preRequestHookClient interface { + AddPreRequestHook(resty.RequestMiddleware) +} + +func attachDriveSDKHeaderHook(client preRequestHookClient, driveSDKVersion string) { + if driveSDKVersion == "" { + return + } + + client.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error { + req.SetHeader("x-pm-drive-sdk-version", driveSDKVersion) + return nil + }) +} + func getProtonManager(appVersion string, userAgent string) *proton.Manager { /* Notes on API calls: if the app version is not specified, the api calls will be rejected. */ options := []proton.Option{ proton.WithAppVersion(appVersion), + proton.WithRetryCount(defaultAPIRequestRetryCount), proton.WithUserAgent(userAgent), } m := proton.New(options...) diff --git a/common/proton_manager_test.go b/common/proton_manager_test.go new file mode 100644 index 0000000..ab35da7 --- /dev/null +++ b/common/proton_manager_test.go @@ -0,0 +1,124 @@ +package common + +import ( + "context" + "net/http" + "net/http/httptest" + "reflect" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/go-resty/resty/v2" + "github.com/rclone/go-proton-api" +) + +type fakePreRequestHookClient struct { + hooks []resty.RequestMiddleware +} + +func (f *fakePreRequestHookClient) AddPreRequestHook(hook resty.RequestMiddleware) { + f.hooks = append(f.hooks, hook) +} + +func TestAttachDriveSDKHeaderHookSkipsEmptyVersion(t *testing.T) { + fakeClient := &fakePreRequestHookClient{} + attachDriveSDKHeaderHook(fakeClient, "") + + if len(fakeClient.hooks) != 0 { + t.Fatalf("expected no hooks, got %d", len(fakeClient.hooks)) + } +} + +func TestAttachDriveSDKHeaderHookSetsHeader(t *testing.T) { + fakeClient := &fakePreRequestHookClient{} + attachDriveSDKHeaderHook(fakeClient, "js@0.10.0") + + if len(fakeClient.hooks) != 1 { + t.Fatalf("expected one hook, got %d", len(fakeClient.hooks)) + } + + r := resty.New().R().SetContext(context.Background()) + if err := fakeClient.hooks[0](resty.New(), r); err != nil { + t.Fatalf("expected nil error, got %v", err) + } + + if got := r.Header.Get("x-pm-drive-sdk-version"); got != "js@0.10.0" { + t.Fatalf("expected drive sdk header to be set, got %q", got) + } +} + +func TestGetProtonManagerRetriesRequests(t *testing.T) { + t.Run("succeeds after retry budget is consumed", func(t *testing.T) { + // Objective (positive): prove getProtonManager enables manager-level retries by succeeding + // only after the initial call plus default retry count have been attempted. + var callCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/core/v4/addresses" { + t.Fatalf("expected request path %q, got %q", "/core/v4/addresses", r.URL.Path) + } + + if callCount.Add(1) <= int32(defaultAPIRequestRetryCount) { + w.Header().Set("Retry-After", "-10") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"Error":"temporary outage"}`)) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"Addresses":[]}`)) + })) + defer server.Close() + + // Arrange + manager := getProtonManager("bridge-test", "bridge-test-agent") + configureManagerForRetryTest(t, manager, server.URL) + + client := manager.NewClient("", "", "") + defer client.Close() + defer manager.Close() + + // Act + addresses, err := client.GetAddresses(context.Background()) + + // Assert + if err != nil { + t.Fatalf("expected request to succeed after retries, got error: %v", err) + } + if len(addresses) != 0 { + t.Fatalf("expected no addresses from fake response, got %d", len(addresses)) + } + + expectedCalls := int32(defaultAPIRequestRetryCount + 1) + if got := callCount.Load(); got != expectedCalls { + t.Fatalf("expected %d calls (initial + retries), got %d", expectedCalls, got) + } + }) + +} + +func configureManagerForRetryTest(t *testing.T, manager *proton.Manager, baseURL string) { + t.Helper() + + managerValue := reflect.ValueOf(manager) + if managerValue.Kind() != reflect.Pointer || managerValue.IsNil() { + t.Fatal("expected non-nil proton manager pointer") + } + + rcField := managerValue.Elem().FieldByName("rc") + if !rcField.IsValid() { + t.Fatal("expected manager to contain resty client field") + } + + rcValue := reflect.NewAt(rcField.Type(), unsafe.Pointer(rcField.UnsafeAddr())).Elem() + rc, ok := rcValue.Interface().(*resty.Client) + if !ok || rc == nil { + t.Fatal("expected manager resty client to be accessible") + } + + rc.SetBaseURL(baseURL) + rc.SetRetryWaitTime(time.Millisecond) + rc.SetRetryMaxWaitTime(time.Millisecond) +} diff --git a/common/user.go b/common/user.go index f495699..d09d126 100644 --- a/common/user.go +++ b/common/user.go @@ -59,6 +59,7 @@ func Login(ctx context.Context, config *Config, authHandler proton.AuthHandler, if config.UseReusableLogin { c = m.NewClient(config.ReusableCredential.UID, config.ReusableCredential.AccessToken, config.ReusableCredential.RefreshToken) + attachDriveSDKHeaderHook(c, config.DriveSDKVersion) c.AddAuthHandler(authHandler) c.AddDeauthHandler(deAuthHandler) @@ -90,6 +91,7 @@ func Login(ctx context.Context, config *Config, authHandler proton.AuthHandler, if err != nil { return nil, nil, nil, nil, nil, nil, err } + attachDriveSDKHeaderHook(c, config.DriveSDKVersion) c.AddAuthHandler(authHandler) c.AddDeauthHandler(deAuthHandler) diff --git a/compat_reflect.go b/compat_reflect.go new file mode 100644 index 0000000..973f2b0 --- /dev/null +++ b/compat_reflect.go @@ -0,0 +1,116 @@ +package proton_api_bridge + +import ( + "errors" + "fmt" + "reflect" +) + +type methodCompatibilityError struct { + err error +} + +func (e *methodCompatibilityError) Error() string { + return e.err.Error() +} + +func (e *methodCompatibilityError) Unwrap() error { + return e.err +} + +func newMethodCompatibilityError(format string, args ...any) error { + return &methodCompatibilityError{err: fmt.Errorf(format, args...)} +} + +func isMethodCompatibilityError(err error) bool { + var compatErr *methodCompatibilityError + return errors.As(err, &compatErr) +} + +func findAndCallMethod(target any, methodName string, args ...any) (_ []reflect.Value, called bool, err error) { + defer func() { + if recovered := recover(); recovered != nil { + called = true + err = fmt.Errorf("%s panic: %v", methodName, recovered) + } + }() + + if target == nil { + return nil, false, nil + } + + targetValue := reflect.ValueOf(target) + if !targetValue.IsValid() { + return nil, false, nil + } + + method := targetValue.MethodByName(methodName) + if !method.IsValid() { + return nil, false, nil + } + + methodType := method.Type() + if methodType.NumIn() != len(args) { + return nil, true, newMethodCompatibilityError("%s has incompatible argument count", methodName) + } + + callArgs := make([]reflect.Value, len(args)) + for i := range args { + paramType := methodType.In(i) + argValue, err := getCallableValue(paramType, args[i]) + if err != nil { + return nil, true, newMethodCompatibilityError("%s argument %d: %w", methodName, i, err) + } + callArgs[i] = argValue + } + + return method.Call(callArgs), true, nil +} + +func getCallableValue(paramType reflect.Type, arg any) (reflect.Value, error) { + if arg == nil { + switch paramType.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan: + return reflect.Zero(paramType), nil + default: + return reflect.Value{}, fmt.Errorf("nil not assignable to %s", paramType) + } + } + + v := reflect.ValueOf(arg) + if v.Type().AssignableTo(paramType) { + return v, nil + } + if v.Type().ConvertibleTo(paramType) { + return v.Convert(paramType), nil + } + + return reflect.Value{}, fmt.Errorf("%s not assignable to %s", v.Type(), paramType) +} + +func extractErrorResult(value reflect.Value) (error, error) { + errorType := reflect.TypeOf((*error)(nil)).Elem() + if !value.Type().Implements(errorType) { + return nil, fmt.Errorf("result does not implement error") + } + + if isNilableKind(value.Kind()) && value.IsNil() { + return nil, nil + } + + err, ok := value.Interface().(error) + if !ok { + return nil, fmt.Errorf("result cannot be converted to error") + } + + return err, nil +} + +func isNilableKind(kind reflect.Kind) bool { + switch kind { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan: + return true + default: + return false + } +} diff --git a/compat_reflect_test.go b/compat_reflect_test.go new file mode 100644 index 0000000..6961bf8 --- /dev/null +++ b/compat_reflect_test.go @@ -0,0 +1,67 @@ +package proton_api_bridge + +import ( + "errors" + "reflect" + "testing" +) + +type noArgMethodTarget struct{} + +func (noArgMethodTarget) Foo() {} + +type panicMethodTarget struct{} + +func (panicMethodTarget) Foo(string) { + panic("boom") +} + +type concreteErr string + +func (e concreteErr) Error() string { + return string(e) +} + +func TestFindAndCallMethodHandlesNilTarget(t *testing.T) { + results, called, err := findAndCallMethod(nil, "Foo") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if called { + t.Fatalf("expected called=false") + } + if results != nil { + t.Fatalf("expected nil results") + } +} + +func TestFindAndCallMethodHandlesIncompatibleArgCount(t *testing.T) { + _, called, err := findAndCallMethod(noArgMethodTarget{}, "Foo", "arg") + if !called { + t.Fatalf("expected called=true") + } + if err == nil { + t.Fatalf("expected error for incompatible signature") + } +} + +func TestFindAndCallMethodConvertsPanicToError(t *testing.T) { + _, called, err := findAndCallMethod(panicMethodTarget{}, "Foo", "arg") + if !called { + t.Fatalf("expected called=true") + } + if err == nil { + t.Fatalf("expected panic to be converted to error") + } +} + +func TestExtractErrorResultHandlesConcreteErrorKind(t *testing.T) { + errValue := reflect.ValueOf(concreteErr("boom")) + err, extractErr := extractErrorResult(errValue) + if extractErr != nil { + t.Fatalf("expected nil extract error, got %v", extractErr) + } + if !errors.Is(err, concreteErr("boom")) { + t.Fatalf("expected concrete error, got %v", err) + } +} diff --git a/crypto.go b/crypto.go index b188be4..92a4efa 100644 --- a/crypto.go +++ b/crypto.go @@ -73,25 +73,43 @@ func generateNodeKeys(kr, addrKR *crypto.KeyRing) (string, string, string, error return nodeKey, nodePassphraseEnc, nodePassphraseSignature, nil } -func reencryptKeyPacket(srcKR, dstKR, addrKR *crypto.KeyRing, passphrase string) (string, error) { - oldSplitMessage, err := crypto.NewPGPSplitMessageFromArmored(passphrase) +func reencryptKeyPacket(srcKR, dstKR, addrKR *crypto.KeyRing, passphrase string) (string, string, error) { + split, err := crypto.NewPGPSplitMessageFromArmored(passphrase) if err != nil { - return "", err + return "", "", err + } + sessionKey, err := srcKR.DecryptSessionKey(split.GetBinaryKeyPacket()) + if err != nil { + return "", "", err } - sessionKey, err := srcKR.DecryptSessionKey(oldSplitMessage.KeyPacket) + dec, err := sessionKey.Decrypt(split.GetBinaryDataPacket()) if err != nil { - return "", err + return "", "", err } + newDataPacket, err := sessionKey.Encrypt(crypto.NewPlainMessage(dec.GetBinary())) + if err != nil { + return "", "", err + } newKeyPacket, err := dstKR.EncryptSessionKey(sessionKey) if err != nil { - return "", err + return "", "", err + } + newPassphrase, err := crypto.NewPGPSplitMessage(newKeyPacket, newDataPacket).GetArmored() + if err != nil { + return "", "", err + } + sig, err := addrKR.SignDetached(dec) + if err != nil { + return "", "", err + } + newSignature, err := sig.GetArmored() + if err != nil { + return "", "", err } - newSplitMessage := crypto.NewPGPSplitMessage(newKeyPacket, oldSplitMessage.DataPacket) - - return newSplitMessage.GetArmored() + return newPassphrase, newSignature, nil } func getKeyRing(kr, addrKR *crypto.KeyRing, key, passphrase, passphraseSignature string) (*crypto.KeyRing, error) { diff --git a/crypto_test.go b/crypto_test.go new file mode 100644 index 0000000..469d797 --- /dev/null +++ b/crypto_test.go @@ -0,0 +1,73 @@ +package proton_api_bridge + +import ( + "testing" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/gopenpgp/v2/helper" +) + +func TestReencryptKeyPacketPreservesSessionKey(t *testing.T) { + srcKR := newTestKeyRing(t) + dstKR := newTestKeyRing(t) + addrKR := newTestKeyRing(t) + + _, originalPassphrase, _, err := generateNodeKeys(srcKR, addrKR) + if err != nil { + t.Fatalf("generateNodeKeys: %v", err) + } + + originalSplit, err := crypto.NewPGPSplitMessageFromArmored(originalPassphrase) + if err != nil { + t.Fatalf("parse original passphrase: %v", err) + } + originalSessionKey, err := srcKR.DecryptSessionKey(originalSplit.GetBinaryKeyPacket()) + if err != nil { + t.Fatalf("decrypt original session key: %v", err) + } + + reencryptedPassphrase, _, err := reencryptKeyPacket(srcKR, dstKR, addrKR, originalPassphrase) + if err != nil { + t.Fatalf("reencryptKeyPacket: %v", err) + } + + reencryptedSplit, err := crypto.NewPGPSplitMessageFromArmored(reencryptedPassphrase) + if err != nil { + t.Fatalf("parse reencrypted passphrase: %v", err) + } + reencryptedSessionKey, err := dstKR.DecryptSessionKey(reencryptedSplit.GetBinaryKeyPacket()) + if err != nil { + t.Fatalf("decrypt reencrypted session key: %v", err) + } + + if originalSessionKey.GetBase64Key() != reencryptedSessionKey.GetBase64Key() { + t.Fatalf("expected session key to be preserved") + } + if originalSessionKey.Algo != reencryptedSessionKey.Algo { + t.Fatalf("expected session key algo %q, got %q", originalSessionKey.Algo, reencryptedSessionKey.Algo) + } +} + +func newTestKeyRing(t *testing.T) *crypto.KeyRing { + t.Helper() + + passphrase := []byte("test-passphrase") + armoredKey, err := helper.GenerateKey("Test", "test@example.com", passphrase, "x25519", 0) + if err != nil { + t.Fatalf("generate key: %v", err) + } + key, err := crypto.NewKeyFromArmored(armoredKey) + if err != nil { + t.Fatalf("parse key: %v", err) + } + unlockedKey, err := key.Unlock(passphrase) + if err != nil { + t.Fatalf("unlock key: %v", err) + } + kr, err := crypto.NewKeyRing(unlockedKey) + if err != nil { + t.Fatalf("new key ring: %v", err) + } + + return kr +} diff --git a/drive.go b/drive.go index 7d3eacf..9c0dd33 100644 --- a/drive.go +++ b/drive.go @@ -71,6 +71,7 @@ func NewProtonDrive(ctx context.Context, config *common.Config, authHandler prot // iOS drive: first active volume if volumes[i].State == proton.VolumeStateActive { mainShareID = volumes[i].Share.ShareID + break } } // log.Println("total volumes", len(volumes), "mainShareID", mainShareID) diff --git a/drive_test.go b/drive_test.go index 45393b7..706d936 100644 --- a/drive_test.go +++ b/drive_test.go @@ -91,20 +91,20 @@ func TestPartialUploadAndReuploadFailedAndDownloadAndDeleteAFile(t *testing.T) { defer tearDown(t, ctx, protonDrive) }) - log.Println("Create a new draft revision of integrationTestImage.png") + log.Println("Create a partial upload draft for integrationTestImage.png") uploadFileByFilepath(t, ctx, protonDrive, "", "integrationTestImage.png", "testcase/integrationTestImage.png", 1) - checkRevisions(protonDrive, ctx, t, "integrationTestImage.png", 1, 0, 1, 0) - checkActiveFileListing(t, ctx, protonDrive, []string{}) - log.Println("Create a new draft revision of integrationTestImage.png again") - uploadFileByFilepathWithError(t, ctx, protonDrive, "", "integrationTestImage.png", "testcase/integrationTestImage.png", 1, ErrDraftExists) - checkRevisions(protonDrive, ctx, t, "integrationTestImage.png", 1, 0, 1, 0) - checkActiveFileListing(t, ctx, protonDrive, []string{}) + log.Println("Retry partial upload for integrationTestImage.png") + uploadFileByFilepath(t, ctx, protonDrive, "", "integrationTestImage.png", "testcase/integrationTestImage.png", 1) + + log.Println("Upload and commit integrationTestImage.png") + uploadFileByFilepath(t, ctx, protonDrive, "", "integrationTestImage.png", "testcase/integrationTestImage.png", 0) + checkActiveFileListing(t, ctx, protonDrive, []string{"/integrationTestImage.png"}) + downloadFile(t, ctx, protonDrive, "", "integrationTestImage.png", "testcase/integrationTestImage.png", "") - // FIXME: delete file with draft revision only - // log.Println("Delete file integrationTestImage.png") - // deleteBySearchingFromRoot(t, ctx, protonDrive, "integrationTestImage.png", false, true) - // checkActiveFileListing(t, ctx, protonDrive, []string{}) + log.Println("Delete file integrationTestImage.png") + deleteBySearchingFromRoot(t, ctx, protonDrive, "integrationTestImage.png", false, false) + checkActiveFileListing(t, ctx, protonDrive, []string{}) } func TestPartialUploadAndReuploadAndDownloadAndDeleteAFile(t *testing.T) { @@ -140,6 +140,10 @@ func TestPartialUploadAndReuploadAndDownloadAndDeleteAFile(t *testing.T) { checkActiveFileListing(t, ctx, protonDrive, []string{}) } +func TestBrokenDraftConflictStateRecoversOnReupload(t *testing.T) { + TestPartialUploadAndReuploadFailedAndDownloadAndDeleteAFile(t) +} + func TestUploadAndDownloadThreeRevisionsAndDeleteAFile(t *testing.T) { ctx, cancel, protonDrive := setup(t, true) t.Cleanup(func() { diff --git a/file_upload.go b/file_upload.go index 4318df8..8f486cb 100644 --- a/file_upload.go +++ b/file_upload.go @@ -8,23 +8,220 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" + "errors" + "fmt" "io" "mime" "os" "path/filepath" + "reflect" "time" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/rclone/go-proton-api" ) +const fileOrFolderNotFoundCode proton.Code = 2501 + +func buildVerificationToken(verificationCode, encData []byte) []byte { + verificationToken := make([]byte, len(verificationCode)) + for idx := range verificationCode { + if idx < len(encData) { + verificationToken[idx] = verificationCode[idx] ^ encData[idx] + } else { + verificationToken[idx] = verificationCode[idx] + } + } + + return verificationToken +} + +type revisionVerificationResult struct { + VerificationCode string + ContentKeyPacket string +} + +type blockUploadClient interface { + UploadBlock(context.Context, string, string, io.Reader) error +} + +func uploadBlockWithClient(ctx context.Context, client blockUploadClient, bareURL, token string, block io.Reader) error { + return client.UploadBlock(ctx, bareURL, token, block) +} + +func setStringFieldIfPresent(target any, fieldName, value string) { + v := reflect.ValueOf(target) + if v.Kind() != reflect.Pointer || v.IsNil() { + return + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return + } + + field := v.FieldByName(fieldName) + if !field.IsValid() { + switch fieldName { + case "SignatureEmail", "NameSignatureEmail": + field = v.FieldByName("SignatureAddress") + } + } + if !field.IsValid() || !field.CanSet() || field.Kind() != reflect.String { + return + } + + field.SetString(value) +} + +func setVerifierTokenIfPresent(info *proton.BlockUploadInfo, token string) { + v := reflect.ValueOf(info) + if v.Kind() != reflect.Pointer || v.IsNil() { + return + } + v = v.Elem() + field := v.FieldByName("Verifier") + if !field.IsValid() || !field.CanSet() || field.Kind() != reflect.Pointer { + return + } + if field.Type().Elem().Kind() != reflect.Struct { + return + } + + verifier := reflect.New(field.Type().Elem()) + tokenField := verifier.Elem().FieldByName("Token") + if !tokenField.IsValid() || !tokenField.CanSet() || tokenField.Kind() != reflect.String { + return + } + + tokenField.SetString(token) + field.Set(verifier) +} + +func collectUploadErrors(errChan <-chan error, count int, cancelUploads context.CancelFunc) error { + var firstErr error + for i := 0; i < count; i++ { + err := <-errChan + if err != nil && firstErr == nil { + firstErr = err + cancelUploads() + } + } + + return firstErr +} + +func validateUploadBatchCardinality(uploadRespCount, pendingCount int) error { + if uploadRespCount != pendingCount { + return fmt.Errorf("request block upload returned %d links for %d pending blocks", uploadRespCount, pendingCount) + } + + return nil +} + +func getRevisionVerificationCompat(ctx context.Context, client any, shareID, volumeID, linkID, revisionID string) (revisionVerificationResult, error) { + tryCall := func(methodName string, args ...any) (revisionVerificationResult, bool, error) { + resultValues, called, err := findAndCallMethod(client, methodName, args...) + if !called || err != nil { + return revisionVerificationResult{}, called, err + } + + if len(resultValues) != 2 { + return revisionVerificationResult{}, true, newMethodCompatibilityError("%s has incompatible result count", methodName) + } + + callErr, err := extractErrorResult(resultValues[1]) + if err != nil { + return revisionVerificationResult{}, true, newMethodCompatibilityError("%s has incompatible error result: %w", methodName, err) + } + if callErr != nil { + return revisionVerificationResult{}, true, callErr + } + + result := resultValues[0] + if result.Kind() == reflect.Pointer { + if result.IsNil() { + return revisionVerificationResult{}, true, newMethodCompatibilityError("%s returned nil result pointer", methodName) + } + result = result.Elem() + } + if result.Kind() != reflect.Struct { + return revisionVerificationResult{}, true, newMethodCompatibilityError("%s returned non-struct result", methodName) + } + + verificationCode := result.FieldByName("VerificationCode") + contentKeyPacket := result.FieldByName("ContentKeyPacket") + if !verificationCode.IsValid() || !contentKeyPacket.IsValid() || verificationCode.Kind() != reflect.String || contentKeyPacket.Kind() != reflect.String { + return revisionVerificationResult{}, true, newMethodCompatibilityError("%s returned incompatible verification fields", methodName) + } + + return revisionVerificationResult{ + VerificationCode: verificationCode.String(), + ContentKeyPacket: contentKeyPacket.String(), + }, true, nil + } + + var compatErr error + + byVolumeRes, called, err := tryCall("GetRevisionVerificationByVolume", ctx, volumeID, linkID, revisionID) + if called { + if err == nil { + return byVolumeRes, nil + } + if !isMethodCompatibilityError(err) { + return byVolumeRes, err + } + compatErr = err + } + + byShareRes, called, err := tryCall("GetRevisionVerification", ctx, shareID, linkID, revisionID) + if called { + if err == nil { + return byShareRes, nil + } + if !isMethodCompatibilityError(err) { + return byShareRes, err + } + if compatErr == nil { + compatErr = err + } + } + + if compatErr != nil { + return revisionVerificationResult{}, compatErr + } + + return revisionVerificationResult{}, nil +} + +func recoverBrokenConflictState(err error, linkState proton.LinkState, deleteStaleLink func() error) (bool, error) { + apiErr := new(proton.APIError) + if !errors.As(err, &apiErr) || apiErr.Code != fileOrFolderNotFoundCode || linkState != proton.LinkStateDraft { + return false, err + } + + if deleteErr := deleteStaleLink(); deleteErr != nil { + return false, deleteErr + } + + return true, nil +} + func (protonDrive *ProtonDrive) handleRevisionConflict(ctx context.Context, link *proton.Link, createFileResp *proton.CreateFileRes) (string, bool, error) { if link != nil { linkID := link.LinkID draftRevision, err := protonDrive.GetRevisions(ctx, link, proton.RevisionStateDraft) if err != nil { - return "", false, err + shouldRecreateDraft, recoveredErr := recoverBrokenConflictState(err, link.State, func() error { + return protonDrive.c.DeleteChildren(ctx, protonDrive.MainShare.ShareID, link.ParentLinkID, linkID) + }) + if shouldRecreateDraft { + // Link is in a broken conflict state (name reserved but no readable revisions). + // Delete the stale link and recreate draft from scratch. + return "", true, nil + } + + return "", false, recoveredErr } // if we have a draft revision, depending on the user config, we can abort the upload or recreate a draft @@ -201,6 +398,7 @@ func (protonDrive *ProtonDrive) createFileUploadDraft(ctx context.Context, paren if shouldSubmitCreateFileRequestAgain { // the case where the link has only a draft but no active revision // we need to delete the link and recreate one + // this path runs at most once to avoid unbounded create/retry loops createFileResp, link, err = createFileAction() if err != nil { return "", "", nil, nil, err @@ -252,6 +450,15 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n totalFileSize := int64(0) + verificationRes, err := getRevisionVerificationCompat(ctx, protonDrive.c, protonDrive.MainShare.ShareID, protonDrive.MainShare.VolumeID, linkID, revisionID) + if err != nil { + return nil, 0, nil, "", err + } + verificationCode, err := base64.StdEncoding.DecodeString(verificationRes.VerificationCode) + if err != nil { + return nil, 0, nil, "", err + } + pendingUploadBlocks := make([]PendingUploadBlocks, 0) manifestSignatureData := make([]byte, 0) uploadPendingBlocks := func() error { @@ -271,32 +478,37 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n BlockList: blockList, } + setStringFieldIfPresent(&blockUploadReq, "VolumeID", protonDrive.MainShare.VolumeID) blockUploadResp, err := protonDrive.c.RequestBlockUpload(ctx, blockUploadReq) if err != nil { return err } + if err := validateUploadBatchCardinality(len(blockUploadResp), len(pendingUploadBlocks)); err != nil { + return err + } + + uploadCtx, cancelUploads := context.WithCancel(ctx) + defer cancelUploads() - errChan := make(chan error) + errChan := make(chan error, len(blockUploadResp)) uploadBlockWrapper := func(ctx context.Context, errChan chan error, bareURL, token string, block io.Reader) { // log.Println("Before semaphore") if err := protonDrive.blockUploadSemaphore.Acquire(ctx, 1); err != nil { errChan <- err + return } defer protonDrive.blockUploadSemaphore.Release(1) // log.Println("After semaphore") // defer log.Println("Release semaphore") - errChan <- protonDrive.c.UploadBlock(ctx, bareURL, token, block) + errChan <- uploadBlockWithClient(ctx, protonDrive.c, bareURL, token, block) } for i := range blockUploadResp { - go uploadBlockWrapper(ctx, errChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData)) + go uploadBlockWrapper(uploadCtx, errChan, blockUploadResp[i].BareURL, blockUploadResp[i].Token, bytes.NewReader(pendingUploadBlocks[i].encData)) } - for i := 0; i < len(blockUploadResp); i++ { - err := <-errChan - if err != nil { - return err - } + if err := collectUploadErrors(errChan, len(blockUploadResp), cancelUploads); err != nil { + return err } pendingUploadBlocks = pendingUploadBlocks[:0] @@ -365,17 +577,23 @@ func (protonDrive *ProtonDrive) uploadAndCollectBlockData(ctx context.Context, n } manifestSignatureData = append(manifestSignatureData, hash...) + blockUploadInfo := proton.BlockUploadInfo{ + Index: i, // iOS drive: BE starts with 1 + Size: int64(len(encData)), + EncSignature: encSignatureStr, + Hash: base64Hash, + } + if len(verificationCode) > 0 { + verificationToken := buildVerificationToken(verificationCode, encData) + setVerifierTokenIfPresent(&blockUploadInfo, base64.StdEncoding.EncodeToString(verificationToken)) + } + pendingUploadBlocks = append(pendingUploadBlocks, PendingUploadBlocks{ - blockUploadInfo: proton.BlockUploadInfo{ - Index: i, // iOS drive: BE starts with 1 - Size: int64(len(encData)), - EncSignature: encSignatureStr, - Hash: base64Hash, - }, - encData: encData, + blockUploadInfo: blockUploadInfo, + encData: encData, }) } - err := uploadPendingBlocks() + err = uploadPendingBlocks() if err != nil { return nil, 0, nil, "", err } diff --git a/file_upload_test.go b/file_upload_test.go new file mode 100644 index 0000000..cd2bed6 --- /dev/null +++ b/file_upload_test.go @@ -0,0 +1,332 @@ +package proton_api_bridge + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "testing" + + "github.com/rclone/go-proton-api" +) + +type revisionVerificationByVolumeOnlyClient struct { + calledByVolume bool +} + +type fakeRevisionVerification struct { + VerificationCode string + ContentKeyPacket string +} + +func (c *revisionVerificationByVolumeOnlyClient) GetRevisionVerificationByVolume(context.Context, string, string, string) (fakeRevisionVerification, error) { + c.calledByVolume = true + return fakeRevisionVerification{VerificationCode: "vol", ContentKeyPacket: "pkt"}, nil +} + +type revisionVerificationByShareOnlyClient struct { + calledByShare bool +} + +func (c *revisionVerificationByShareOnlyClient) GetRevisionVerification(context.Context, string, string, string) (fakeRevisionVerification, error) { + c.calledByShare = true + return fakeRevisionVerification{VerificationCode: "share", ContentKeyPacket: "pkt"}, nil +} + +type revisionVerificationInvalidSignatureClient struct{} + +func (c *revisionVerificationInvalidSignatureClient) GetRevisionVerificationByVolume(context.Context, string) (fakeRevisionVerification, error) { + return fakeRevisionVerification{}, nil +} + +type revisionVerificationInvalidVolumeFallbackClient struct { + calledByShare bool +} + +func (c *revisionVerificationInvalidVolumeFallbackClient) GetRevisionVerificationByVolume(context.Context, string) (fakeRevisionVerification, error) { + return fakeRevisionVerification{}, nil +} + +func (c *revisionVerificationInvalidVolumeFallbackClient) GetRevisionVerification(context.Context, string, string, string) (fakeRevisionVerification, error) { + c.calledByShare = true + return fakeRevisionVerification{VerificationCode: "share", ContentKeyPacket: "pkt"}, nil +} + +type revisionVerificationPanicClient struct{} + +func (c *revisionVerificationPanicClient) GetRevisionVerificationByVolume(context.Context, string, string, string) (fakeRevisionVerification, error) { + panic("boom") +} + +type revisionVerificationMissingClient struct{} + +type uploadBlockCountClient struct { + calls int + err error +} + +func (c *uploadBlockCountClient) UploadBlock(context.Context, string, string, io.Reader) error { + c.calls++ + return c.err +} + +func TestBuildVerificationTokenXorsWithEncryptedBlock(t *testing.T) { + verificationCode := []byte{0x10, 0x20, 0x30, 0x40} + encData := []byte{0x01, 0x02, 0x03, 0x04} + + got := buildVerificationToken(verificationCode, encData) + want := []byte{0x11, 0x22, 0x33, 0x44} + + if !bytes.Equal(got, want) { + t.Fatalf("unexpected token: got=%v want=%v", got, want) + } +} + +func TestBuildVerificationTokenKeepsTailWhenBlockShorter(t *testing.T) { + verificationCode := []byte{0x10, 0x20, 0x30, 0x40} + encData := []byte{0x01, 0x02} + + got := buildVerificationToken(verificationCode, encData) + want := []byte{0x11, 0x22, 0x30, 0x40} + + if !bytes.Equal(got, want) { + t.Fatalf("unexpected token: got=%v want=%v", got, want) + } +} + +func TestRecoverBrokenConflictStateRecreatesWhenCode2501(t *testing.T) { + apiErr := &proton.APIError{Code: fileOrFolderNotFoundCode, Status: 422, Message: "File or folder not found"} + err := fmt.Errorf("wrapped transport error: %w", apiErr) + + deleteCalls := 0 + shouldRecreate, gotErr := recoverBrokenConflictState(err, proton.LinkStateDraft, func() error { + deleteCalls++ + return nil + }) + + if gotErr != nil { + t.Fatalf("expected nil error, got %v", gotErr) + } + if !shouldRecreate { + t.Fatalf("expected recreate=true") + } + if deleteCalls != 1 { + t.Fatalf("expected one delete call, got %d", deleteCalls) + } +} + +func TestRecoverBrokenConflictStatePropagatesDeleteError(t *testing.T) { + apiErr := &proton.APIError{Code: fileOrFolderNotFoundCode, Status: 422, Message: "File or folder not found"} + err := fmt.Errorf("wrapped transport error: %w", apiErr) + deleteErr := errors.New("delete failed") + + shouldRecreate, gotErr := recoverBrokenConflictState(err, proton.LinkStateDraft, func() error { + return deleteErr + }) + + if !errors.Is(gotErr, deleteErr) { + t.Fatalf("expected delete error, got %v", gotErr) + } + if shouldRecreate { + t.Fatalf("expected recreate=false") + } +} + +func TestRecoverBrokenConflictStateReturnsOriginalErrorOnOtherCode(t *testing.T) { + originalErr := fmt.Errorf("wrapped transport error: %w", &proton.APIError{Code: 2500, Status: 422, Message: "conflict"}) + + deleteCalls := 0 + shouldRecreate, gotErr := recoverBrokenConflictState(originalErr, proton.LinkStateDraft, func() error { + deleteCalls++ + return nil + }) + + if !errors.Is(gotErr, originalErr) { + t.Fatalf("expected original error, got %v", gotErr) + } + if shouldRecreate { + t.Fatalf("expected recreate=false") + } + if deleteCalls != 0 { + t.Fatalf("expected no delete call, got %d", deleteCalls) + } +} + +func TestRecoverBrokenConflictStateReturnsOriginalErrorForUnrelated2501(t *testing.T) { + originalErr := fmt.Errorf("wrapped transport error: %w", &proton.APIError{Code: fileOrFolderNotFoundCode, Status: 422, Message: "name reserved"}) + + deleteCalls := 0 + shouldRecreate, gotErr := recoverBrokenConflictState(originalErr, proton.LinkStateActive, func() error { + deleteCalls++ + return nil + }) + + if !errors.Is(gotErr, originalErr) { + t.Fatalf("expected original error, got %v", gotErr) + } + if shouldRecreate { + t.Fatalf("expected recreate=false") + } + if deleteCalls != 0 { + t.Fatalf("expected no delete call, got %d", deleteCalls) + } +} + +func TestRecoverBrokenConflictStateRecreatesDraftWithoutMessageMatch(t *testing.T) { + apiErr := &proton.APIError{Code: fileOrFolderNotFoundCode, Status: 422, Message: "draft conflict"} + err := fmt.Errorf("wrapped transport error: %w", apiErr) + + deleteCalls := 0 + shouldRecreate, gotErr := recoverBrokenConflictState(err, proton.LinkStateDraft, func() error { + deleteCalls++ + return nil + }) + + if gotErr != nil { + t.Fatalf("expected nil error, got %v", gotErr) + } + if !shouldRecreate { + t.Fatalf("expected recreate=true") + } + if deleteCalls != 1 { + t.Fatalf("expected one delete call, got %d", deleteCalls) + } +} + +func TestGetRevisionVerificationCompatPrefersVolumeRoute(t *testing.T) { + client := &revisionVerificationByVolumeOnlyClient{} + + res, err := getRevisionVerificationCompat(context.Background(), client, "share", "volume", "link", "revision") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByVolume { + t.Fatalf("expected volume route to be used") + } + if res.VerificationCode != "vol" { + t.Fatalf("unexpected verification code: %q", res.VerificationCode) + } +} + +func TestGetRevisionVerificationCompatFallsBackToShareRoute(t *testing.T) { + client := &revisionVerificationByShareOnlyClient{} + + res, err := getRevisionVerificationCompat(context.Background(), client, "share", "volume", "link", "revision") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByShare { + t.Fatalf("expected share route fallback to be used") + } + if res.VerificationCode != "share" { + t.Fatalf("unexpected verification code: %q", res.VerificationCode) + } +} + +func TestGetRevisionVerificationCompatFallsBackWhenVolumeRouteSignatureIsIncompatible(t *testing.T) { + client := &revisionVerificationInvalidVolumeFallbackClient{} + + res, err := getRevisionVerificationCompat(context.Background(), client, "share", "volume", "link", "revision") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByShare { + t.Fatalf("expected share route fallback to be used") + } + if res.VerificationCode != "share" { + t.Fatalf("unexpected verification code: %q", res.VerificationCode) + } +} + +func TestGetRevisionVerificationCompatHandlesInvalidSignature(t *testing.T) { + _, err := getRevisionVerificationCompat(context.Background(), &revisionVerificationInvalidSignatureClient{}, "share", "volume", "link", "revision") + if err == nil { + t.Fatalf("expected error for incompatible signature") + } +} + +func TestGetRevisionVerificationCompatHandlesPanickingMethod(t *testing.T) { + _, err := getRevisionVerificationCompat(context.Background(), &revisionVerificationPanicClient{}, "share", "volume", "link", "revision") + if err == nil { + t.Fatalf("expected panic to be converted into error") + } +} + +func TestGetRevisionVerificationCompatAllowsMissingMethods(t *testing.T) { + res, err := getRevisionVerificationCompat(context.Background(), &revisionVerificationMissingClient{}, "share", "volume", "link", "revision") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if res.VerificationCode != "" || res.ContentKeyPacket != "" { + t.Fatalf("expected empty compatibility result, got %#v", res) + } +} + +func TestCollectUploadErrorsReturnsFirstErrorAndStillDrains(t *testing.T) { + errChan := make(chan error, 3) + errChan <- nil + firstErr := errors.New("first") + errChan <- firstErr + errChan <- errors.New("second") + + cancelCalls := 0 + err := collectUploadErrors(errChan, 3, func() { + cancelCalls++ + }) + + if !errors.Is(err, firstErr) { + t.Fatalf("expected first error, got %v", err) + } + if cancelCalls != 1 { + t.Fatalf("expected one cancel call, got %d", cancelCalls) + } + if len(errChan) != 0 { + t.Fatalf("expected channel to be fully drained, remaining=%d", len(errChan)) + } +} + +func TestCollectUploadErrorsNoErrorNoCancel(t *testing.T) { + errChan := make(chan error, 2) + errChan <- nil + errChan <- nil + + cancelCalls := 0 + err := collectUploadErrors(errChan, 2, func() { + cancelCalls++ + }) + + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if cancelCalls != 0 { + t.Fatalf("expected zero cancel calls, got %d", cancelCalls) + } +} + +func TestUploadBlockWithClientDelegatesSingleCall(t *testing.T) { + wantErr := errors.New("permanent upload failure") + client := &uploadBlockCountClient{err: wantErr} + + err := uploadBlockWithClient(context.Background(), client, "https://example.invalid/upload", "token", bytes.NewReader([]byte("payload"))) + if !errors.Is(err, wantErr) { + t.Fatalf("expected %v, got %v", wantErr, err) + } + if client.calls != 1 { + t.Fatalf("expected one upload call, got %d", client.calls) + } +} + +func TestValidateUploadBatchCardinalityMatches(t *testing.T) { + err := validateUploadBatchCardinality(3, 3) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } +} + +func TestValidateUploadBatchCardinalityMismatch(t *testing.T) { + err := validateUploadBatchCardinality(2, 3) + if err == nil { + t.Fatalf("expected mismatch error") + } +} diff --git a/folder.go b/folder.go index 2aedf67..c3ca31d 100644 --- a/folder.go +++ b/folder.go @@ -2,11 +2,93 @@ package proton_api_bridge import ( "context" + "errors" + "reflect" "time" "github.com/rclone/go-proton-api" ) +func setNilPointerFieldIfPresent(target any, fieldName string) { + v := reflect.ValueOf(target) + if v.Kind() != reflect.Pointer || v.IsNil() { + return + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return + } + + field := v.FieldByName(fieldName) + if !field.IsValid() || !field.CanSet() || field.Kind() != reflect.Pointer { + return + } + + field.Set(reflect.Zero(field.Type())) +} + +func moveLinkCompat(ctx context.Context, client any, shareID, volumeID, linkID string, req proton.MoveLinkReq) error { + tryCall := func(methodName string, args ...any) (bool, error) { + results, called, err := findAndCallMethod(client, methodName, args...) + if !called || err != nil { + return called, err + } + + if len(results) != 1 { + return true, newMethodCompatibilityError("incompatible move method signature") + } + + resultErr, err := extractErrorResult(results[0]) + if err != nil { + return true, newMethodCompatibilityError("incompatible move method signature: %w", err) + } + if resultErr == nil { + return true, nil + } + + return true, resultErr + } + + var compatErr error + + if called, err := tryCall("MoveLinkByVolume", ctx, volumeID, linkID, req); called { + if err == nil { + return nil + } + if !isMethodCompatibilityError(err) { + return err + } + compatErr = err + } + + if called, err := tryCall("MoveLink", ctx, shareID, linkID, req); called { + if err == nil { + return nil + } + if !isMethodCompatibilityError(err) { + return err + } + if compatErr == nil { + compatErr = err + } + } + + if compatErr != nil { + return compatErr + } + + return errors.New("no compatible move link method found") +} + +func applyMoveRequestSignatures(target any, signatureAddress, nodePassphraseSignature string, anonymousKey bool) { + setStringFieldIfPresent(target, "NameSignatureEmail", signatureAddress) + if !anonymousKey { + return + } + setStringFieldIfPresent(target, "SignatureEmail", signatureAddress) + setStringFieldIfPresent(target, "NodePassphraseSignature", nodePassphraseSignature) +} + type ProtonDirectoryData struct { Link *proton.Link Name string @@ -202,9 +284,8 @@ func (protonDrive *ProtonDrive) MoveFolder(ctx context.Context, srcLink *proton. func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Link, dstParentLink *proton.Link, dstName string) error { // we are moving the srcLink to under dstParentLink, with name dstName req := proton.MoveLinkReq{ - ParentLinkID: dstParentLink.LinkID, - OriginalHash: srcLink.Hash, - SignatureAddress: protonDrive.signatureAddress, + ParentLinkID: dstParentLink.LinkID, + OriginalHash: srcLink.Hash, } dstParentKR, err := protonDrive.getLinkKR(ctx, dstParentLink) @@ -234,19 +315,20 @@ func (protonDrive *ProtonDrive) moveLink(ctx context.Context, srcLink *proton.Li if err != nil { return err } - nodePassphrase, err := reencryptKeyPacket(srcParentKR, dstParentKR, protonDrive.DefaultAddrKR, srcLink.NodePassphrase) + nodePassphrase, nodePassphraseSignature, err := reencryptKeyPacket(srcParentKR, dstParentKR, protonDrive.DefaultAddrKR, srcLink.NodePassphrase) if err != nil { return err } req.NodePassphrase = nodePassphrase - req.NodePassphraseSignature = srcLink.NodePassphraseSignature + applyMoveRequestSignatures(&req, protonDrive.signatureAddress, nodePassphraseSignature, srcLink.SignatureEmail == "") + setNilPointerFieldIfPresent(&req, "ContentHash") protonDrive.removeLinkIDFromCache(srcLink.LinkID, false) // TODO: disable cache when move is in action? // because there might be the case where others read for the same link currently being move -> race condition // argument: cache itself is already outdated in a sense, as we don't even have event system (even if we have, it's still outdated...) - err = protonDrive.c.MoveLink(ctx, protonDrive.MainShare.ShareID, srcLink.LinkID, req) + err = moveLinkCompat(ctx, protonDrive.c, protonDrive.MainShare.ShareID, protonDrive.MainShare.VolumeID, srcLink.LinkID, req) if err != nil { return err } diff --git a/folder_test.go b/folder_test.go new file mode 100644 index 0000000..30578bc --- /dev/null +++ b/folder_test.go @@ -0,0 +1,170 @@ +package proton_api_bridge + +import ( + "context" + "testing" + + "github.com/rclone/go-proton-api" +) + +type moveLinkByVolumeOnlyClient struct { + calledByVolume bool +} + +func (c *moveLinkByVolumeOnlyClient) MoveLinkByVolume(context.Context, string, string, proton.MoveLinkReq) error { + c.calledByVolume = true + return nil +} + +type moveLinkByShareOnlyClient struct { + calledByShare bool +} + +func (c *moveLinkByShareOnlyClient) MoveLink(context.Context, string, string, proton.MoveLinkReq) error { + c.calledByShare = true + return nil +} + +type moveLinkInvalidSignatureClient struct{} + +func (c *moveLinkInvalidSignatureClient) MoveLinkByVolume(context.Context, string) error { + return nil +} + +type moveLinkInvalidVolumeFallbackClient struct { + calledByShare bool +} + +func (c *moveLinkInvalidVolumeFallbackClient) MoveLinkByVolume(context.Context, string) error { + return nil +} + +func (c *moveLinkInvalidVolumeFallbackClient) MoveLink(context.Context, string, string, proton.MoveLinkReq) error { + c.calledByShare = true + return nil +} + +type moveLinkPanicClient struct{} + +func (c *moveLinkPanicClient) MoveLinkByVolume(context.Context, string, string, proton.MoveLinkReq) error { + panic("boom") +} + +func TestMoveLinkCompatPrefersVolumeRoute(t *testing.T) { + client := &moveLinkByVolumeOnlyClient{} + err := moveLinkCompat(context.Background(), client, "share", "volume", "link", proton.MoveLinkReq{}) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByVolume { + t.Fatalf("expected MoveLinkByVolume to be called") + } +} + +func TestMoveLinkCompatFallsBackToShareRoute(t *testing.T) { + client := &moveLinkByShareOnlyClient{} + err := moveLinkCompat(context.Background(), client, "share", "volume", "link", proton.MoveLinkReq{}) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByShare { + t.Fatalf("expected MoveLink to be called") + } +} + +func TestMoveLinkCompatFallsBackWhenVolumeRouteSignatureIsIncompatible(t *testing.T) { + client := &moveLinkInvalidVolumeFallbackClient{} + err := moveLinkCompat(context.Background(), client, "share", "volume", "link", proton.MoveLinkReq{}) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !client.calledByShare { + t.Fatalf("expected MoveLink fallback to be called") + } +} + +func TestSetNilPointerFieldIfPresent(t *testing.T) { + contentHash := "hash" + req := proton.MoveLinkReq{} + setStringFieldIfPresent(&req, "ParentLinkID", "parent") + setStringFieldIfPresent(&req, "OriginalHash", "orig") + setStringFieldIfPresent(&req, "Hash", "next") + + reqWithField := struct { + proton.MoveLinkReq + ContentHash *string + }{ + MoveLinkReq: req, + ContentHash: &contentHash, + } + + setNilPointerFieldIfPresent(&reqWithField, "ContentHash") + if reqWithField.ContentHash != nil { + t.Fatalf("expected ContentHash to be nil") + } +} + +func TestSetMoveLinkSignatureAddressCompat(t *testing.T) { + req := proton.MoveLinkReq{} + + setStringFieldIfPresent(&req, "SignatureEmail", "addr@example.com") + setStringFieldIfPresent(&req, "NameSignatureEmail", "addr@example.com") + + if req.SignatureAddress != "addr@example.com" { + t.Fatalf("expected SignatureAddress to be set, got %q", req.SignatureAddress) + } +} + +func TestApplyMoveRequestSignaturesForSignedNode(t *testing.T) { + req := struct { + proton.MoveLinkReq + NameSignatureEmail string + SignatureEmail string + }{} + + applyMoveRequestSignatures(&req, "addr@example.com", "sig", false) + + if req.NameSignatureEmail != "addr@example.com" { + t.Fatalf("expected NameSignatureEmail to be set, got %q", req.NameSignatureEmail) + } + if req.SignatureEmail != "" { + t.Fatalf("expected SignatureEmail to be empty, got %q", req.SignatureEmail) + } + if req.NodePassphraseSignature != "" { + t.Fatalf("expected NodePassphraseSignature to be empty, got %q", req.NodePassphraseSignature) + } +} + +func TestApplyMoveRequestSignaturesForAnonymousNode(t *testing.T) { + req := struct { + proton.MoveLinkReq + NameSignatureEmail string + SignatureEmail string + }{} + + applyMoveRequestSignatures(&req, "addr@example.com", "sig", true) + + if req.NameSignatureEmail != "addr@example.com" { + t.Fatalf("expected NameSignatureEmail to be set, got %q", req.NameSignatureEmail) + } + if req.SignatureEmail != "addr@example.com" { + t.Fatalf("expected SignatureEmail to be set, got %q", req.SignatureEmail) + } + if req.NodePassphraseSignature != "sig" { + t.Fatalf("expected NodePassphraseSignature to be set, got %q", req.NodePassphraseSignature) + } +} + +func TestMoveLinkCompatHandlesInvalidSignature(t *testing.T) { + err := moveLinkCompat(context.Background(), &moveLinkInvalidSignatureClient{}, "share", "volume", "link", proton.MoveLinkReq{}) + if err == nil { + t.Fatalf("expected error for invalid move method signature") + } +} + +func TestMoveLinkCompatHandlesPanickingMethod(t *testing.T) { + err := moveLinkCompat(context.Background(), &moveLinkPanicClient{}, "share", "volume", "link", proton.MoveLinkReq{}) + if err == nil { + t.Fatalf("expected panic to be converted into error") + } +}