Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
.credential
.*.credential
.env
data
config.toml
config.toml
16 changes: 10 additions & 6 deletions common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (

type Config struct {
/* Constants */
AppVersion string
UserAgent string
AppVersion string
DriveSDKVersion string
UserAgent string

/* Login */
FirstLoginCredential *FirstLoginCredentialData
Expand Down Expand Up @@ -44,8 +45,9 @@ type ReusableCredentialData struct {

func NewConfigWithDefaultValues() *Config {
return &Config{
AppVersion: "",
UserAgent: "",
AppVersion: "",
DriveSDKVersion: "",
UserAgent: "",

FirstLoginCredential: &FirstLoginCredentialData{
Username: "",
Expand Down Expand Up @@ -75,6 +77,7 @@ func NewConfigWithDefaultValues() *Config {

func NewConfigForIntegrationTests() *Config {
appVersion := os.Getenv("PROTON_API_BRIDGE_APP_VERSION")
driveSDKVersion := os.Getenv("PROTON_API_BRIDGE_DRIVE_SDK_VERSION")
userAgent := os.Getenv("PROTON_API_BRIDGE_USER_AGENT")

username := os.Getenv("PROTON_API_BRIDGE_TEST_USERNAME")
Expand All @@ -93,8 +96,9 @@ func NewConfigForIntegrationTests() *Config {
saltedKeyPass := os.Getenv("PROTON_API_BRIDGE_TEST_SALTEDKEYPASS")

return &Config{
AppVersion: appVersion,
UserAgent: userAgent,
AppVersion: appVersion,
DriveSDKVersion: driveSDKVersion,
UserAgent: userAgent,

FirstLoginCredential: &FirstLoginCredentialData{
Username: username,
Expand Down
21 changes: 21 additions & 0 deletions common/proton_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,33 @@ package common

import (
"github.com/rclone/go-proton-api"

"github.com/go-resty/resty/v2"
)

// Applies to all API calls made by the shared proton.Manager.
const defaultAPIRequestRetryCount = 3

type preRequestHookClient interface {
AddPreRequestHook(resty.RequestMiddleware)
}

func attachDriveSDKHeaderHook(client preRequestHookClient, driveSDKVersion string) {
if driveSDKVersion == "" {
return
}

client.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error {
req.SetHeader("x-pm-drive-sdk-version", driveSDKVersion)
return nil
})
}

func getProtonManager(appVersion string, userAgent string) *proton.Manager {
/* Notes on API calls: if the app version is not specified, the api calls will be rejected. */
options := []proton.Option{
proton.WithAppVersion(appVersion),
proton.WithRetryCount(defaultAPIRequestRetryCount),
proton.WithUserAgent(userAgent),
}
m := proton.New(options...)
Expand Down
124 changes: 124 additions & 0 deletions common/proton_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package common

import (
"context"
"net/http"
"net/http/httptest"
"reflect"
"sync/atomic"
"testing"
"time"
"unsafe"

"github.com/go-resty/resty/v2"
"github.com/rclone/go-proton-api"
)

type fakePreRequestHookClient struct {
hooks []resty.RequestMiddleware
}

func (f *fakePreRequestHookClient) AddPreRequestHook(hook resty.RequestMiddleware) {
f.hooks = append(f.hooks, hook)
}

func TestAttachDriveSDKHeaderHookSkipsEmptyVersion(t *testing.T) {
fakeClient := &fakePreRequestHookClient{}
attachDriveSDKHeaderHook(fakeClient, "")

if len(fakeClient.hooks) != 0 {
t.Fatalf("expected no hooks, got %d", len(fakeClient.hooks))
}
}

func TestAttachDriveSDKHeaderHookSetsHeader(t *testing.T) {
fakeClient := &fakePreRequestHookClient{}
attachDriveSDKHeaderHook(fakeClient, "js@0.10.0")

if len(fakeClient.hooks) != 1 {
t.Fatalf("expected one hook, got %d", len(fakeClient.hooks))
}

r := resty.New().R().SetContext(context.Background())
if err := fakeClient.hooks[0](resty.New(), r); err != nil {
t.Fatalf("expected nil error, got %v", err)
}

if got := r.Header.Get("x-pm-drive-sdk-version"); got != "js@0.10.0" {
t.Fatalf("expected drive sdk header to be set, got %q", got)
}
}

func TestGetProtonManagerRetriesRequests(t *testing.T) {
t.Run("succeeds after retry budget is consumed", func(t *testing.T) {
// Objective (positive): prove getProtonManager enables manager-level retries by succeeding
// only after the initial call plus default retry count have been attempted.
var callCount atomic.Int32

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/core/v4/addresses" {
t.Fatalf("expected request path %q, got %q", "/core/v4/addresses", r.URL.Path)
}

if callCount.Add(1) <= int32(defaultAPIRequestRetryCount) {
w.Header().Set("Retry-After", "-10")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`{"Error":"temporary outage"}`))
return
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"Addresses":[]}`))
}))
defer server.Close()

// Arrange
manager := getProtonManager("bridge-test", "bridge-test-agent")
configureManagerForRetryTest(t, manager, server.URL)

client := manager.NewClient("", "", "")
defer client.Close()
defer manager.Close()

// Act
addresses, err := client.GetAddresses(context.Background())

// Assert
if err != nil {
t.Fatalf("expected request to succeed after retries, got error: %v", err)
}
if len(addresses) != 0 {
t.Fatalf("expected no addresses from fake response, got %d", len(addresses))
}

expectedCalls := int32(defaultAPIRequestRetryCount + 1)
if got := callCount.Load(); got != expectedCalls {
t.Fatalf("expected %d calls (initial + retries), got %d", expectedCalls, got)
}
})

}

func configureManagerForRetryTest(t *testing.T, manager *proton.Manager, baseURL string) {
t.Helper()

managerValue := reflect.ValueOf(manager)
if managerValue.Kind() != reflect.Pointer || managerValue.IsNil() {
t.Fatal("expected non-nil proton manager pointer")
}

rcField := managerValue.Elem().FieldByName("rc")
if !rcField.IsValid() {
t.Fatal("expected manager to contain resty client field")
}

rcValue := reflect.NewAt(rcField.Type(), unsafe.Pointer(rcField.UnsafeAddr())).Elem()
rc, ok := rcValue.Interface().(*resty.Client)
if !ok || rc == nil {
t.Fatal("expected manager resty client to be accessible")
}

rc.SetBaseURL(baseURL)
rc.SetRetryWaitTime(time.Millisecond)
rc.SetRetryMaxWaitTime(time.Millisecond)
}
2 changes: 2 additions & 0 deletions common/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func Login(ctx context.Context, config *Config, authHandler proton.AuthHandler,

if config.UseReusableLogin {
c = m.NewClient(config.ReusableCredential.UID, config.ReusableCredential.AccessToken, config.ReusableCredential.RefreshToken)
attachDriveSDKHeaderHook(c, config.DriveSDKVersion)
c.AddAuthHandler(authHandler)
c.AddDeauthHandler(deAuthHandler)

Expand Down Expand Up @@ -90,6 +91,7 @@ func Login(ctx context.Context, config *Config, authHandler proton.AuthHandler,
if err != nil {
return nil, nil, nil, nil, nil, nil, err
}
attachDriveSDKHeaderHook(c, config.DriveSDKVersion)
c.AddAuthHandler(authHandler)
c.AddDeauthHandler(deAuthHandler)

Expand Down
116 changes: 116 additions & 0 deletions compat_reflect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package proton_api_bridge

import (
"errors"
"fmt"
"reflect"
)

type methodCompatibilityError struct {
err error
}

func (e *methodCompatibilityError) Error() string {
return e.err.Error()
}

func (e *methodCompatibilityError) Unwrap() error {
return e.err
}

func newMethodCompatibilityError(format string, args ...any) error {
return &methodCompatibilityError{err: fmt.Errorf(format, args...)}
}

func isMethodCompatibilityError(err error) bool {
var compatErr *methodCompatibilityError
return errors.As(err, &compatErr)
}

func findAndCallMethod(target any, methodName string, args ...any) (_ []reflect.Value, called bool, err error) {
defer func() {
if recovered := recover(); recovered != nil {
called = true
err = fmt.Errorf("%s panic: %v", methodName, recovered)
}
}()

if target == nil {
return nil, false, nil
}

targetValue := reflect.ValueOf(target)
if !targetValue.IsValid() {
return nil, false, nil
}

method := targetValue.MethodByName(methodName)
if !method.IsValid() {
return nil, false, nil
}

methodType := method.Type()
if methodType.NumIn() != len(args) {
return nil, true, newMethodCompatibilityError("%s has incompatible argument count", methodName)
}

callArgs := make([]reflect.Value, len(args))
for i := range args {
paramType := methodType.In(i)
argValue, err := getCallableValue(paramType, args[i])
if err != nil {
return nil, true, newMethodCompatibilityError("%s argument %d: %w", methodName, i, err)
}
callArgs[i] = argValue
}

return method.Call(callArgs), true, nil
}

func getCallableValue(paramType reflect.Type, arg any) (reflect.Value, error) {
if arg == nil {
switch paramType.Kind() {
case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
return reflect.Zero(paramType), nil
default:
return reflect.Value{}, fmt.Errorf("nil not assignable to %s", paramType)
}
}

v := reflect.ValueOf(arg)
if v.Type().AssignableTo(paramType) {
return v, nil
}
if v.Type().ConvertibleTo(paramType) {
return v.Convert(paramType), nil
}

return reflect.Value{}, fmt.Errorf("%s not assignable to %s", v.Type(), paramType)
}

func extractErrorResult(value reflect.Value) (error, error) {
errorType := reflect.TypeOf((*error)(nil)).Elem()
if !value.Type().Implements(errorType) {
return nil, fmt.Errorf("result does not implement error")
}

if isNilableKind(value.Kind()) && value.IsNil() {
return nil, nil
}

err, ok := value.Interface().(error)
if !ok {
return nil, fmt.Errorf("result cannot be converted to error")
}

return err, nil
}

func isNilableKind(kind reflect.Kind) bool {
switch kind {
case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
return true
default:
return false
}
}
Loading