diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index e469a35018..96af153041 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -3,11 +3,8 @@ package shell import ( "fmt" "os" - "sort" - "strconv" "strings" - "github.com/crc-org/crc/v2/pkg/crc/logging" crcos "github.com/crc-org/crc/v2/pkg/os" ) @@ -154,83 +151,3 @@ func IsWindowsSubsystemLinux() bool { } return false } - -// detectShellByInvokingCommand is a utility method that tries to detect current shell in use by invoking `ps` command. -// This method is extracted so that it could be used by unix systems as well as Windows (in case of WSL). It executes -// the command provided in the method arguments and then passes the output to inspectProcessOutputForRecentlyUsedShell -// for evaluation. -// -// It receives two arguments: -// - defaultShell : default shell to revert back to in case it's unable to detect. -// - command: command to be executed -// - args: a string array containing command arguments -// -// It returns a string value representing current shell. -func detectShellByInvokingCommand(defaultShell string, command string, args []string) string { - stdOut, _, err := CommandRunner.Run(command, args...) - if err != nil { - return defaultShell - } - - detectedShell := inspectProcessOutputForRecentlyUsedShell(stdOut) - if detectedShell == "" { - return defaultShell - } - logging.Debugf("Detected shell: %s", detectedShell) - return detectedShell -} - -// inspectProcessOutputForRecentlyUsedShell inspects output of ps command to detect currently active shell session. -// -// It parses the output into a struct, filters process types by name then reverse sort it with pid and returns the first element. -// -// It takes one argument: -// -// - psCommandOutput: output of ps command executed on a particular shell session -// -// It returns: -// -// - a string value (one of `zsh`, `bash` or `fish`) for current shell environment in use. If it's not able to determine -// underlying shell type, it returns and empty string. -// -// This method tries to check all processes open and filters out shell sessions (one of `zsh`, `bash` or `fish) -// It then returns first shell process. -// -// For example, if ps command gives this output: -// -// 2908 ps -// 2889 fish -// 823 bash -// -// Then this method would return `fish` as it's the first shell process. -func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { - type ProcessOutput struct { - processID int - output string - } - var processOutputs []ProcessOutput - lines := strings.Split(psCommandOutput, "\n") - for _, line := range lines { - lineParts := strings.Split(strings.TrimSpace(line), " ") - if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || - strings.Contains(lineParts[1], "bash") || - strings.Contains(lineParts[1], "fish")) { - parsedProcessID, err := strconv.Atoi(lineParts[0]) - if err == nil { - processOutputs = append(processOutputs, ProcessOutput{ - processID: parsedProcessID, - output: lineParts[1], - }) - } - } - } - // Reverse sort the processes by PID (higher to lower) - sort.Slice(processOutputs, func(i, j int) bool { - return processOutputs[i].processID > processOutputs[j].processID - }) - - if len(processOutputs) > 0 { - return processOutputs[0].output - } - return "" -} diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index e75f0f1896..279d40cd88 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -6,16 +6,38 @@ package shell import ( "errors" "fmt" + "os" "path/filepath" + + "github.com/shirou/gopsutil/v3/process" + "github.com/spf13/cast" ) var ( - ErrUnknownShell = errors.New("Error: Unknown shell") + ErrUnknownShell = errors.New("Error: Unknown shell") + getCurrentProcessID = os.Getpid + processSupplier AbstractProcessSupplier = &ProcessSupplier{} + ProcessDepthLimit = 10 ) +type AbstractProcess interface { + Name() (string, error) + Ppid() (int32, error) +} + +type AbstractProcessSupplier interface { + NewProcess(pid int32) (AbstractProcess, error) +} + +type ProcessSupplier struct{} + +func (p *ProcessSupplier) NewProcess(pid int32) (AbstractProcess, error) { + return process.NewProcess(pid) +} + // detect detects user's current shell. func detect() (string, error) { - detectedShell := detectShellByInvokingCommand("", "ps", []string{"-o", "pid=,comm="}) + detectedShell := detectShellByCheckingProcessTree(getCurrentProcessID()) if detectedShell == "" { fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n") return "", ErrUnknownShell @@ -23,3 +45,47 @@ func detect() (string, error) { return filepath.Base(detectedShell), nil } + +func detectShellByCheckingProcessTree(pid int) string { + level := 0 + for level < ProcessDepthLimit { + processName, err := getProcessName(pid) + if err != nil { + return "" + } + if processName == "zsh" || processName == "bash" || processName == "fish" { + // fmt.Printf("Found : %s\n", processName) + return processName + } + pid, err = getParentProcessID(pid) + if err != nil { + return "" + } + level++ + } + return "" +} + +func getProcessName(pid int) (string, error) { + p, err := processSupplier.NewProcess(cast.ToInt32(pid)) + if err != nil { + return "", err + } + processName, err := p.Name() + if err != nil { + return "", err + } + return processName, nil +} + +func getParentProcessID(pid int) (int, error) { + p, err := processSupplier.NewProcess(cast.ToInt32(pid)) + if err != nil { + return 0, err + } + ppid, err := p.Ppid() + if err != nil { + return 0, err + } + return cast.ToInt(ppid), nil +} diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index a11b5a0db5..08762ee9a5 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -5,80 +5,192 @@ package shell import ( "bytes" + "errors" "os" "testing" "github.com/stretchr/testify/assert" ) +type MockedProcess struct { + name string + ppid int32 + nameGetFails bool + ppidGetFails bool +} + +func (m *MockedProcess) Ppid() (int32, error) { + if m.ppidGetFails { + return -1, errors.New("failed to get the pid") + } + return m.ppid, nil +} + +func (m *MockedProcess) Name() (string, error) { + if m.nameGetFails { + return "", errors.New("failed to get the name") + } + return m.name, nil +} + +type MockProcessSupplier struct { + processMap map[int32]AbstractProcess + errorToReturn error +} + +func (m *MockProcessSupplier) NewProcess(pid int32) (AbstractProcess, error) { + if m.errorToReturn != nil { + return nil, m.errorToReturn + } + return m.processMap[pid], nil +} + +func NewMockProcessSupplier(processMap map[int32]AbstractProcess, err error) *MockProcessSupplier { + return &MockProcessSupplier{processMap: processMap, errorToReturn: err} +} + func TestUnknownShell(t *testing.T) { - // Given - mockCommandExecutor := NewMockCommandRunnerWithOutputErr("", "", nil) - CommandRunner = mockCommandExecutor - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w + tests := []struct { + name string + processSupplier AbstractProcessSupplier + expectedShellType string + }{ + { + "failure to get process details for given pid", + NewMockProcessSupplier(map[int32]AbstractProcess{}, errors.New("Unable to find Process for given pid")), + "", + }, + { + "failure while getting name of process", + NewMockProcessSupplier(map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + nameGetFails: true, + }, + }, nil), + "", + }, + { + "failure while getting ppid of process", + NewMockProcessSupplier(map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + ppidGetFails: true, + }, + }, nil), + "", + }, + { + "failure while getting ppid of process", + NewMockProcessSupplier(map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + name: "unknown", + ppid: 1001, + }, + }, nil), + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + getCurrentProcessID = func() int { + return 1003 + } + ProcessDepthLimit = 2 + processSupplier = tt.processSupplier + originalStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w - // When - shell, err := detect() + // When + shell, err := detect() - // Then - assert.Equal(t, err, ErrUnknownShell) - err = w.Close() - assert.NoError(t, err) - os.Stdout = originalStdout - var buf bytes.Buffer - nBytesRead, err := buf.ReadFrom(r) - assert.NoError(t, err) - assert.Greater(t, nBytesRead, int64(0)) - assert.Equal(t, "The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n", buf.String()) - assert.Equal(t, "ps", mockCommandExecutor.commandName) - assert.Equal(t, []string{"-o", "pid=,comm="}, mockCommandExecutor.commandArgs) - assert.Empty(t, shell) + // Then + assert.Equal(t, err, ErrUnknownShell) + err = w.Close() + assert.NoError(t, err) + os.Stdout = originalStdout + var buf bytes.Buffer + nBytesRead, err := buf.ReadFrom(r) + assert.NoError(t, err) + assert.Greater(t, nBytesRead, int64(0)) + assert.Equal(t, "The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n", buf.String()) + assert.Empty(t, shell) + }) + } } -func TestDetect_GivenPsOutputContainsShell_ThenReturnShellProcessWithMostRecentPid(t *testing.T) { +func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) { tests := []struct { name string - psCommandOutput string + psMap map[int32]AbstractProcess expectedShellType string }{ { "bash shell, then detect bash shell", - " 31162 ps\n 435 bash", + map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + name: "bash", + ppid: 1000, + }, + }, "bash", }, { "zsh shell, then detect zsh shell", - " 31259 ps\n31253 zsh", + map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + name: "zsh", + ppid: 1000, + }, + }, "zsh", }, { "fish shell, then detect fish shell", - " 31372 ps\n 31352 fish", - "fish", - }, - {"bash and zsh shell, then detect zsh with more recent process id", - " 31259 ps\n 31253 zsh\n 435 bash", - "zsh", - }, - {"bash and fish shell, then detect fish shell with more recent process id", - " 31372 ps\n 31352 fish\n 435 bash", + map[int32]AbstractProcess{ + 1003: &MockedProcess{ + name: "crc", + ppid: 1002, + }, + 1002: &MockedProcess{ + name: "fish", + ppid: 1000, + }, + }, "fish", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Given - mockCommandExecutor := NewMockCommandRunnerWithOutputErr(tt.psCommandOutput, "", nil) - CommandRunner = mockCommandExecutor - + getCurrentProcessID = func() int { + return 1003 + } + processSupplier = NewMockProcessSupplier(tt.psMap, nil) // When shell, err := detect() // Then - assert.Equal(t, "ps", mockCommandExecutor.commandName) - assert.Equal(t, []string{"-o", "pid=,comm="}, mockCommandExecutor.commandArgs) assert.Equal(t, tt.expectedShellType, shell) assert.NoError(t, err) }) diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 85f64f6228..a513f80435 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -2,9 +2,12 @@ package shell import ( "fmt" + "github.com/crc-org/crc/v2/pkg/crc/logging" "math" "os" "path/filepath" + "sort" + "strconv" "strings" "syscall" "unsafe" @@ -101,3 +104,83 @@ func detect() (string, error) { return shellType(shell, "cmd"), nil } + +// detectShellByInvokingCommand is a utility method that tries to detect current shell in use by invoking `ps` command. +// This method is extracted so that it could be used by unix systems as well as Windows (in case of WSL). It executes +// the command provided in the method arguments and then passes the output to inspectProcessOutputForRecentlyUsedShell +// for evaluation. +// +// It receives two arguments: +// - defaultShell : default shell to revert back to in case it's unable to detect. +// - command: command to be executed +// - args: a string array containing command arguments +// +// It returns a string value representing current shell. +func detectShellByInvokingCommand(defaultShell string, command string, args []string) string { + stdOut, _, err := CommandRunner.Run(command, args...) + if err != nil { + return defaultShell + } + + detectedShell := inspectProcessOutputForRecentlyUsedShell(stdOut) + if detectedShell == "" { + return defaultShell + } + logging.Debugf("Detected shell: %s", detectedShell) + return detectedShell +} + +// inspectProcessOutputForRecentlyUsedShell inspects output of ps command to detect currently active shell session. +// +// It parses the output into a struct, filters process types by name then reverse sort it with pid and returns the first element. +// +// It takes one argument: +// +// - psCommandOutput: output of ps command executed on a particular shell session +// +// It returns: +// +// - a string value (one of `zsh`, `bash` or `fish`) for current shell environment in use. If it's not able to determine +// underlying shell type, it returns and empty string. +// +// This method tries to check all processes open and filters out shell sessions (one of `zsh`, `bash` or `fish) +// It then returns first shell process. +// +// For example, if ps command gives this output: +// +// 2908 ps +// 2889 fish +// 823 bash +// +// Then this method would return `fish` as it's the first shell process. +func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { + type ProcessOutput struct { + processID int + output string + } + var processOutputs []ProcessOutput + lines := strings.Split(psCommandOutput, "\n") + for _, line := range lines { + lineParts := strings.Split(strings.TrimSpace(line), " ") + if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || + strings.Contains(lineParts[1], "bash") || + strings.Contains(lineParts[1], "fish")) { + parsedProcessID, err := strconv.Atoi(lineParts[0]) + if err == nil { + processOutputs = append(processOutputs, ProcessOutput{ + processID: parsedProcessID, + output: lineParts[1], + }) + } + } + } + // Reverse sort the processes by PID (higher to lower) + sort.Slice(processOutputs, func(i, j int) bool { + return processOutputs[i].processID > processOutputs[j].processID + }) + + if len(processOutputs) > 0 { + return processOutputs[0].output + } + return "" +}