diff --git a/cmd/app/attachment.go b/cmd/app/attachment.go index d67b1e89..41e4c467 100644 --- a/cmd/app/attachment.go +++ b/cmd/app/attachment.go @@ -36,15 +36,15 @@ func (ar attachmentReader) ReadFile(path string) (io.Reader, error) { } data, err := io.ReadAll(file) + closeErr := file.Close() if err != nil { return nil, err } + if closeErr != nil { + return nil, closeErr + } - defer file.Close() - - reader := bytes.NewReader(data) - - return reader, nil + return bytes.NewReader(data), nil } type FileUploader interface { diff --git a/cmd/app/emoji.go b/cmd/app/emoji.go index 27e29188..29996b3f 100644 --- a/cmd/app/emoji.go +++ b/cmd/app/emoji.go @@ -104,14 +104,13 @@ func (a emojiService) deleteEmojiFromNote(w http.ResponseWriter, r *http.Request /* postEmojiOnNote adds an emojis to a note based on the note's ID */ func (a emojiService) postEmojiOnNote(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() body, err := io.ReadAll(r.Body) if err != nil { handleError(w, err, "Could not read request body", http.StatusBadRequest) return } - defer r.Body.Close() - var emojiPost CreateNoteEmojiPost err = json.Unmarshal(body, &emojiPost) diff --git a/cmd/app/git/git.go b/cmd/app/git/git.go index 1fcc7087..ed9aacf5 100644 --- a/cmd/app/git/git.go +++ b/cmd/app/git/git.go @@ -12,6 +12,8 @@ type GitManager interface { GetProjectUrlFromNativeGitCmd(remote string) (url string, err error) GetCurrentBranchNameFromNativeGitCmd() (string, error) GetLatestCommitOnRemote(remote string, branchName string) (string, error) + GetMRHeadCommit(mrIID int64) (string, error) + FetchMRHead(remote string, mrIID int64) error } type GitData struct { @@ -68,7 +70,7 @@ func NewGitData(remote string, gitlabUrl string, g GitManager) (GitData, error) // remove part of the hostname from the parsed namespace url_re := regexp.MustCompile(`[^\/]\/([^\/].*)$`) url_matches := url_re.FindStringSubmatch(gitlabUrl) - var namespace string = matches[1] + namespace := matches[1] if len(url_matches) == 2 { namespace = strings.TrimLeft(strings.TrimPrefix(namespace, url_matches[1]), "/") } @@ -125,6 +127,39 @@ func (g Git) RefreshProjectInfo(remote string) error { return nil } +/* Fetches the head ref of a merge request so that fork MR commits are available locally */ +func (g Git) FetchMRHead(remote string, mrIID int64) error { + ref := fmt.Sprintf("refs/merge-requests/%d/head", mrIID) + cmd := exec.Command("git", "fetch", remote, fmt.Sprintf("%s:%s", ref, ref)) + _, err := cmd.Output() + if err != nil { + return fmt.Errorf("failed to fetch MR head ref: %v", err) + } + return nil +} + +/* +FetchMrHead fetches the MR head ref when a specific MR is chosen (mrIID != 0). +Returns an error only if the fetch was attempted and failed; returns nil if skipped or succeeded. +*/ +func FetchMrHead(g GitManager, remote string, mrIID int64) error { + if mrIID == 0 { + return nil + } + return g.FetchMRHead(remote, mrIID) +} + +/* Resolves the commit SHA from a locally stored MR head ref */ +func (g Git) GetMRHeadCommit(mrIID int64) (string, error) { + ref := fmt.Sprintf("refs/merge-requests/%d/head", mrIID) + cmd := exec.Command("git", "rev-parse", ref) + out, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("failed to resolve MR head ref: %v", err) + } + return strings.TrimSpace(string(out)), nil +} + func (g Git) GetLatestCommitOnRemote(remote string, branchName string) (string, error) { cmd := exec.Command("git", "log", "-1", "--format=%H", fmt.Sprintf("%s/%s", remote, branchName)) diff --git a/cmd/app/git/git_test.go b/cmd/app/git/git_test.go index 9158c901..f6ec2f55 100644 --- a/cmd/app/git/git_test.go +++ b/cmd/app/git/git_test.go @@ -2,6 +2,9 @@ package git import ( "errors" + "os" + "os/exec" + "strings" "testing" ) @@ -24,6 +27,14 @@ func (f FakeGitManager) GetLatestCommitOnRemote(remote string, branchName string return "", nil } +func (f FakeGitManager) GetMRHeadCommit(mrIID int64) (string, error) { + return "", nil +} + +func (f FakeGitManager) FetchMRHead(remote string, mrIID int64) error { + return nil +} + func (f FakeGitManager) GetProjectUrlFromNativeGitCmd(string) (url string, err error) { return f.RemoteUrl, nil } @@ -286,3 +297,113 @@ func TestExtractGitInfo_FailToGetCurrentBranchName(t *testing.T) { } }) } + +type fetchTrackingManager struct { + fetchCalled bool + fetchErr error + FakeGitManager +} + +func (f *fetchTrackingManager) FetchMRHead(remote string, mrIID int64) error { + f.fetchCalled = true + return f.fetchErr +} + +func TestFetchMrHead_SkipsWhenMrIIDIsZero(t *testing.T) { + g := &fetchTrackingManager{} + err := FetchMrHead(g, "origin", 0) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if g.fetchCalled { + t.Error("Expected FetchMRHead not to be called when mrIID is 0") + } +} + +func TestFetchMrHead_FetchesWhenMrIIDIsNonZero(t *testing.T) { + g := &fetchTrackingManager{} + err := FetchMrHead(g, "origin", 42) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !g.fetchCalled { + t.Error("Expected FetchMRHead to be called when mrIID is non-zero") + } +} + +func TestFetchMrHead_ReturnsErrorOnFetchFailure(t *testing.T) { + g := &fetchTrackingManager{ + fetchErr: errors.New("fetch failed"), + } + err := FetchMrHead(g, "origin", 42) + if err == nil { + t.Error("Expected an error, got nil") + } + if err.Error() != "fetch failed" { + t.Errorf("Expected 'fetch failed', got '%s'", err.Error()) + } +} + +// setupTempGitRepo creates a temp git repo with a commit and an MR head ref, +// changes the working directory to it, and returns the expected commit SHA. +// The caller's working directory is restored via t.Cleanup. +func setupTempGitRepo(t *testing.T, mrIID string) string { + t.Helper() + + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + cmds := [][]string{ + {"git", "init"}, + {"git", "config", "user.email", "test@test.com"}, + {"git", "config", "user.name", "Test"}, + {"git", "commit", "--allow-empty", "-m", "init"}, + {"git", "update-ref", "refs/merge-requests/" + mrIID + "/head", "HEAD"}, + } + for _, args := range cmds { + cmd := exec.Command(args[0], args[1:]...) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("command %v failed: %v\n%s", args, err, out) + } + } + + out, err := exec.Command("git", "rev-parse", "HEAD").Output() + if err != nil { + t.Fatal(err) + } + return strings.TrimSpace(string(out)) +} + +func TestGetMRHeadCommit_Success(t *testing.T) { + expectedSHA := setupTempGitRepo(t, "42") + + g := Git{} + sha, err := g.GetMRHeadCommit(42) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if sha != expectedSHA { + t.Errorf("Expected SHA %s, got %s", expectedSHA, sha) + } +} + +func TestGetMRHeadCommit_NonExistentRef(t *testing.T) { + setupTempGitRepo(t, "42") + + g := Git{} + _, err := g.GetMRHeadCommit(9999) + if err == nil { + t.Fatal("Expected an error for non-existent ref, got nil") + } + if !strings.Contains(err.Error(), "failed to resolve MR head ref") { + t.Errorf("Expected error about resolving MR head ref, got: %s", err.Error()) + } +} diff --git a/cmd/app/label.go b/cmd/app/label.go index 957fbcdb..55bbaf1f 100644 --- a/cmd/app/label.go +++ b/cmd/app/label.go @@ -86,13 +86,12 @@ func (a labelService) getLabels(w http.ResponseWriter, r *http.Request) { func (a labelService) updateLabels(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + defer func() { _ = r.Body.Close() }() body, err := io.ReadAll(r.Body) if err != nil { handleError(w, err, "Could not read request body", http.StatusBadRequest) return } - - defer r.Body.Close() var labelUpdateRequest LabelUpdateRequest err = json.Unmarshal(body, &labelUpdateRequest) diff --git a/cmd/app/list_discussions_test.go b/cmd/app/list_discussions_test.go index ab142a4c..17f4f371 100644 --- a/cmd/app/list_discussions_test.go +++ b/cmd/app/list_discussions_test.go @@ -111,7 +111,7 @@ func TestListDiscussions(t *testing.T) { withMethodCheck(http.MethodPost), ) data := getDiscussionsList(t, svc, request) - assert(t, data.SuccessResponse.Message, "Discussions retrieved") + assert(t, data.Message, "Discussions retrieved") assert(t, len(data.Discussions), 2) assert(t, data.Discussions[0].Notes[0].Author.Username, "hcramer4") assert(t, data.Discussions[1].Notes[0].Author.Username, "hcramer2") diff --git a/cmd/app/logging.go b/cmd/app/logging.go index d85820db..679f3b91 100644 --- a/cmd/app/logging.go +++ b/cmd/app/logging.go @@ -53,7 +53,7 @@ func (l LoggingServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func logRequest(prefix string, r *http.Request) { file := openLogFile() - defer file.Close() + defer func() { _ = file.Close() }() token := r.Header.Get("Private-Token") r.Header.Set("Private-Token", "REDACTED") res, err := httputil.DumpRequest(r, true) @@ -67,7 +67,7 @@ func logRequest(prefix string, r *http.Request) { func logResponse(prefix string, r *http.Response) { file := openLogFile() - defer file.Close() + defer func() { _ = file.Close() }() res, err := httputil.DumpResponse(r, true) if err != nil { diff --git a/cmd/app/merge_requests_by_username.go b/cmd/app/merge_requests_by_username.go index 8a9c85e2..ca39c805 100644 --- a/cmd/app/merge_requests_by_username.go +++ b/cmd/app/merge_requests_by_username.go @@ -125,7 +125,7 @@ func (a mergeRequestListerByUsernameService) getMrs(payload *gitlab.ListProjectM return []*gitlab.BasicMergeRequest{}, GenericError{endpoint: "/merge_requests_by_username"} } - defer res.Body.Close() + _ = res.Body.Close() return mrs, err } diff --git a/cmd/app/middleware.go b/cmd/app/middleware.go index 1ea4b0c8..70b288ff 100644 --- a/cmd/app/middleware.go +++ b/cmd/app/middleware.go @@ -103,12 +103,16 @@ func (m withMrMiddleware) handle(next http.Handler) http.Handler { // If the merge request is already attached, skip the middleware logic if m.data.projectInfo.MergeId == 0 { options := gitlab.ListProjectMergeRequestsOptions{ - Scope: gitlab.Ptr("all"), - SourceBranch: &m.data.gitInfo.BranchName, + Scope: gitlab.Ptr("all"), } if pluginOptions.ChosenMrIID != 0 { + // When an MR was explicitly chosen by IID (e.g. via choose_merge_request), + // only filter by IID. Skipping SourceBranch allows fork MRs to be found, + // since their source branch doesn't exist in the local repository. options.IIDs = gitlab.Ptr([]int64{pluginOptions.ChosenMrIID}) + } else { + options.SourceBranch = &m.data.gitInfo.BranchName } mergeRequests, _, err := m.client.ListProjectMergeRequests(m.data.projectInfo.ProjectId, &options) @@ -175,9 +179,9 @@ func formatValidationErrors(errs validator.ValidationErrors) error { } switch e.Tag() { case "required": - s.WriteString(fmt.Sprintf("%s is required", e.Field())) + fmt.Fprintf(&s, "%s is required", e.Field()) default: - s.WriteString(fmt.Sprintf("The field '%s' failed on validation on the '%s' tag", e.Field(), e.Tag())) + fmt.Fprintf(&s, "The field '%s' failed on validation on the '%s' tag", e.Field(), e.Tag()) } } diff --git a/cmd/app/middleware_test.go b/cmd/app/middleware_test.go index 598c3d59..16055e6d 100644 --- a/cmd/app/middleware_test.go +++ b/cmd/app/middleware_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/harrisoncramer/gitlab.nvim/cmd/app/git" + gitlab "gitlab.com/gitlab-org/api/client-go" ) type FakePayload struct { @@ -14,6 +15,17 @@ type FakePayload struct { type fakeHandler struct{} +// capturingMergeRequestLister records the options passed to ListProjectMergeRequests +// so tests can verify the filter logic in the withMr middleware. +type capturingMergeRequestLister struct { + capturedOpts *gitlab.ListProjectMergeRequestsOptions +} + +func (f *capturingMergeRequestLister) ListProjectMergeRequests(pid interface{}, opt *gitlab.ListProjectMergeRequestsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.BasicMergeRequest, *gitlab.Response, error) { + f.capturedOpts = opt + return []*gitlab.BasicMergeRequest{{IID: 10}}, makeResponse(http.StatusOK), nil +} + func (f fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) data := SuccessResponse{Message: "Some message"} @@ -90,6 +102,39 @@ func TestWithMrMiddleware(t *testing.T) { assert(t, data.Message, "Multiple MRs found") assert(t, data.Details, "please call gitlab.choose_merge_request()") }) + t.Run("Filters by IIDs when ChosenMrIID is set", func(t *testing.T) { + pluginOptions.ChosenMrIID = 42 + defer func() { pluginOptions.ChosenMrIID = 0 }() + + request := makeRequest(t, http.MethodGet, "/foo", nil) + lister := &capturingMergeRequestLister{} + d := data{ + projectInfo: &ProjectInfo{}, + gitInfo: &git.GitData{BranchName: "foo"}, + } + mw := withMr(d, lister) + handler := middleware(fakeHandler{}, mw) + getSuccessData(t, handler, request) + + assert(t, (*lister.capturedOpts.IIDs)[0], int64(42)) + assert(t, lister.capturedOpts.SourceBranch == nil, true) + }) + t.Run("Filters by SourceBranch when ChosenMrIID is not set", func(t *testing.T) { + pluginOptions.ChosenMrIID = 0 + + request := makeRequest(t, http.MethodGet, "/foo", nil) + lister := &capturingMergeRequestLister{} + d := data{ + projectInfo: &ProjectInfo{}, + gitInfo: &git.GitData{BranchName: "foo"}, + } + mw := withMr(d, lister) + handler := middleware(fakeHandler{}, mw) + getSuccessData(t, handler, request) + + assert(t, *lister.capturedOpts.SourceBranch, "foo") + assert(t, lister.capturedOpts.IIDs == nil, true) + }) } func TestValidatorMiddleware(t *testing.T) { diff --git a/cmd/app/pipeline.go b/cmd/app/pipeline.go index 24fb6935..7ef08ae2 100644 --- a/cmd/app/pipeline.go +++ b/cmd/app/pipeline.go @@ -84,7 +84,10 @@ func (a pipelineService) GetPipelineAndJobs(w http.ResponseWriter, r *http.Reque w.Header().Set("Content-Type", "application/json") commit, err := a.gitService.GetLatestCommitOnRemote(pluginOptions.ConnectionSettings.Remote, a.gitInfo.BranchName) - + if err != nil && pluginOptions.ChosenMrIID != 0 { + // Fall back to the MR head ref for fork MRs where the branch doesn't exist on origin + commit, err = a.gitService.GetMRHeadCommit(pluginOptions.ChosenMrIID) + } if err != nil { handleError(w, err, "Error getting commit on remote branch", http.StatusInternalServerError) return diff --git a/cmd/app/test_helpers.go b/cmd/app/test_helpers.go index 2839f77b..aad66338 100644 --- a/cmd/app/test_helpers.go +++ b/cmd/app/test_helpers.go @@ -133,6 +133,14 @@ func (f FakeGitManager) GetLatestCommitOnRemote(remote string, branchName string return "", nil } +func (f FakeGitManager) GetMRHeadCommit(mrIID int64) (string, error) { + return "", nil +} + +func (f FakeGitManager) FetchMRHead(remote string, mrIID int64) error { + return nil +} + func (f FakeGitManager) GetProjectUrlFromNativeGitCmd(string) (url string, err error) { return f.RemoteUrl, nil } diff --git a/cmd/main.go b/cmd/main.go index 3aae8dbe..bf7f1528 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -34,6 +34,11 @@ func main() { log.Fatalf("Failure initializing plugin: %v", err) } + // Fetch MR head ref when a specific MR is chosen, so fork MR commits are available locally for diffview + if fetchErr := git.FetchMrHead(gitManager, pluginOptions.ConnectionSettings.Remote, pluginOptions.ChosenMrIID); fetchErr != nil { + log.Printf("Warning: could not fetch MR head ref: %v", fetchErr) + } + client, err := app.NewClient() if err != nil { log.Fatalf("Failed to initialize Gitlab client: %v", err) diff --git a/lua/gitlab/actions/merge_requests.lua b/lua/gitlab/actions/merge_requests.lua index d9af5e29..760aefa4 100644 --- a/lua/gitlab/actions/merge_requests.lua +++ b/lua/gitlab/actions/merge_requests.lua @@ -4,6 +4,54 @@ local git = require("gitlab.git") local u = require("gitlab.utils") local M = {} +---Sets the chosen MR and restarts the server, optionally opening the reviewer. +---@param choice table +---@param opts ChooseMergeRequestOptions +local activate_mr = function(choice, opts) + vim.schedule(function() + state.chosen_mr_iid = choice.iid + require("gitlab.server").restart(function() + if opts.open_reviewer then + require("gitlab").review() + end + end) + end) +end + +---Selects a fork MR by restarting the server without switching branches. +---@param choice table +---@param opts ChooseMergeRequestOptions +local select_fork_mr = function(choice, opts) + activate_mr(choice, opts) +end + +---Selects a local MR by switching to its source branch and restarting the server. +---@param choice table +---@param opts ChooseMergeRequestOptions +local select_local_mr = function(choice, opts) + if choice.source_branch ~= git.get_current_branch() then + local has_clean_tree, clean_tree_err = git.has_clean_tree() + if clean_tree_err ~= nil then + return + elseif not has_clean_tree then + u.notify( + "Cannot switch branch when working tree has changes, please stash or commit and push", + vim.log.levels.ERROR + ) + return + end + end + + vim.schedule(function() + local _, branch_switch_err = git.switch_branch(choice.source_branch) + if branch_switch_err ~= nil then + return + end + + activate_mr(choice, opts) + end) +end + ---@class ChooseMergeRequestOptions ---@field open_reviewer? boolean ---@field label? string[] @@ -30,34 +78,13 @@ M.choose_merge_request = function(opts) reviewer.close() end - if choice.source_branch ~= git.get_current_branch() then - local has_clean_tree, clean_tree_err = git.has_clean_tree() - if clean_tree_err ~= nil then - return - elseif not has_clean_tree then - u.notify( - "Cannot switch branch when working tree has changes, please stash or commit and push", - vim.log.levels.ERROR - ) - return - end - end - - vim.schedule(function() - local _, branch_switch_err = git.switch_branch(choice.source_branch) - if branch_switch_err ~= nil then - return - end + local is_fork_mr = choice.source_project_id ~= choice.target_project_id - vim.schedule(function() - state.chosen_mr_iid = choice.iid - require("gitlab.server").restart(function() - if opts.open_reviewer then - require("gitlab").review() - end - end) - end) - end) + if is_fork_mr then + select_fork_mr(choice, opts) + else + select_local_mr(choice, opts) + end end) end diff --git a/lua/gitlab/hunks.lua b/lua/gitlab/hunks.lua index a8231a98..691d0efe 100644 --- a/lua/gitlab/hunks.lua +++ b/lua/gitlab/hunks.lua @@ -91,8 +91,9 @@ end ---Parse git diff hunks. ---@param base_sha string Git base SHA of merge request. +---@param head_sha string Git head SHA of merge request. ---@return HunksAndDiff -local parse_hunks_and_diff = function(base_sha) +local parse_hunks_and_diff = function(base_sha, head_sha) local hunks = {} local all_diff_output = {} @@ -104,6 +105,7 @@ local parse_hunks_and_diff = function(base_sha) "--no-color", "--no-ext-diff", base_sha, + head_sha, "--", reviewer.get_current_file_oldpath(), reviewer.get_current_file_path(), @@ -234,7 +236,7 @@ end ---@param new_sha_focused boolean ---@return string|nil function M.get_modification_type(old_line, new_line, new_sha_focused) - local hunk_and_diff_data = parse_hunks_and_diff(state.INFO.diff_refs.base_sha) + local hunk_and_diff_data = parse_hunks_and_diff(state.INFO.diff_refs.base_sha, state.INFO.diff_refs.head_sha) if hunk_and_diff_data.hunks == nil then return end diff --git a/lua/gitlab/indicators/diagnostics.lua b/lua/gitlab/indicators/diagnostics.lua index ccdd9363..bed0021e 100644 --- a/lua/gitlab/indicators/diagnostics.lua +++ b/lua/gitlab/indicators/diagnostics.lua @@ -117,6 +117,9 @@ M.place_diagnostics = function(bufnr) if not state.settings.discussion_signs.enabled then return end + if bufnr == nil or bufnr == 0 then + return + end local view = diffview_lib.get_current_view() if view == nil then u.notify("Could not find Diffview view", vim.log.levels.ERROR)