Skip to content

Commit

Permalink
refactor (shell) : Refactor unix/linux shell detection mechanism (crc…
Browse files Browse the repository at this point in the history
…-org#4562)

+ Move parsing ps output and picking recent pid approach to shell_windows. It would only
  be used in case of WSL.
+ Instead of parsing ps output and picking up recent pid. Inspect parent process
  of current process and keep going up until we find a shell process. Pick that
  shell process as currently active shell.

Signed-off-by: Rohan Kumar <[email protected]>
  • Loading branch information
rohanKanojia committed Jan 17, 2025
1 parent 8f40e84 commit f31fa0d
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 171 deletions.
83 changes: 0 additions & 83 deletions pkg/os/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ package shell
import (
"fmt"
"os"
"sort"
"strconv"
"strings"

"github.com/crc-org/crc/v2/pkg/crc/logging"
crcos "github.com/crc-org/crc/v2/pkg/os"
)

Expand Down Expand Up @@ -154,83 +151,3 @@ func IsWindowsSubsystemLinux() bool {
}
return false
}

// detectShellByInvokingCommand is a utility method that tries to detect current shell in use by invoking `ps` command.
// This method is extracted so that it could be used by unix systems as well as Windows (in case of WSL). It executes
// the command provided in the method arguments and then passes the output to inspectProcessOutputForRecentlyUsedShell
// for evaluation.
//
// It receives two arguments:
// - defaultShell : default shell to revert back to in case it's unable to detect.
// - command: command to be executed
// - args: a string array containing command arguments
//
// It returns a string value representing current shell.
func detectShellByInvokingCommand(defaultShell string, command string, args []string) string {
stdOut, _, err := CommandRunner.Run(command, args...)
if err != nil {
return defaultShell
}

detectedShell := inspectProcessOutputForRecentlyUsedShell(stdOut)
if detectedShell == "" {
return defaultShell
}
logging.Debugf("Detected shell: %s", detectedShell)
return detectedShell
}

// inspectProcessOutputForRecentlyUsedShell inspects output of ps command to detect currently active shell session.
//
// It parses the output into a struct, filters process types by name then reverse sort it with pid and returns the first element.
//
// It takes one argument:
//
// - psCommandOutput: output of ps command executed on a particular shell session
//
// It returns:
//
// - a string value (one of `zsh`, `bash` or `fish`) for current shell environment in use. If it's not able to determine
// underlying shell type, it returns and empty string.
//
// This method tries to check all processes open and filters out shell sessions (one of `zsh`, `bash` or `fish)
// It then returns first shell process.
//
// For example, if ps command gives this output:
//
// 2908 ps
// 2889 fish
// 823 bash
//
// Then this method would return `fish` as it's the first shell process.
func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string {
type ProcessOutput struct {
processID int
output string
}
var processOutputs []ProcessOutput
lines := strings.Split(psCommandOutput, "\n")
for _, line := range lines {
lineParts := strings.Split(strings.TrimSpace(line), " ")
if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") ||
strings.Contains(lineParts[1], "bash") ||
strings.Contains(lineParts[1], "fish")) {
parsedProcessID, err := strconv.Atoi(lineParts[0])
if err == nil {
processOutputs = append(processOutputs, ProcessOutput{
processID: parsedProcessID,
output: lineParts[1],
})
}
}
}
// Reverse sort the processes by PID (higher to lower)
sort.Slice(processOutputs, func(i, j int) bool {
return processOutputs[i].processID > processOutputs[j].processID
})

if len(processOutputs) > 0 {
return processOutputs[0].output
}
return ""
}
47 changes: 0 additions & 47 deletions pkg/os/shell/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,50 +179,3 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) {
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
}

func TestInspectProcessForRecentlyUsedShell(t *testing.T) {
tests := []struct {
name string
psCommandOutput string
expectedShellType string
}{
{
"nothing provided, then return empty string",
"",
"",
},
{
"bash shell, then detect bash shell",
" 31162 ps\n 435 bash",
"bash",
},
{
"zsh shell, then detect zsh shell",
" 31259 ps\n31253 zsh",
"zsh",
},
{
"fish shell, then detect fish shell",
" 31372 ps\n 31352 fish",
"fish",
},
{"bash and zsh shell, then detect zsh with more recent process id",
" 31259 ps\n 31253 zsh\n 435 bash",
"zsh",
},
{"bash and fish shell, then detect fish shell with more recent process id",
" 31372 ps\n 31352 fish\n 435 bash",
"fish",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Given + When
result := inspectProcessOutputForRecentlyUsedShell(tt.psCommandOutput)
// Then
if result != tt.expectedShellType {
t.Errorf("%s inspectProcessOutputForRecentlyUsedShell() = %s; want %s", tt.name, result, tt.expectedShellType)
}
})
}
}
69 changes: 67 additions & 2 deletions pkg/os/shell/shell_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,85 @@ package shell
import (
"errors"
"fmt"
"os"
"path/filepath"

"github.com/shirou/gopsutil/v3/process"
"github.com/spf13/cast"
)

var (
ErrUnknownShell = errors.New("Error: Unknown shell")
ErrUnknownShell = errors.New("Error: Unknown shell")
getCurrentProcessID = os.Getpid
processSupplier AbstractProcessSupplier = &ProcessSupplier{}
ProcessDepthLimit = 10
)

type AbstractProcess interface {
Name() (string, error)
Ppid() (int32, error)
}

type AbstractProcessSupplier interface {
NewProcess(pid int32) (AbstractProcess, error)
}

type ProcessSupplier struct{}

func (p *ProcessSupplier) NewProcess(pid int32) (AbstractProcess, error) {
return process.NewProcess(pid)
}

// detect detects user's current shell.
func detect() (string, error) {
detectedShell := detectShellByInvokingCommand("", "ps", []string{"-o", "pid=,comm="})
detectedShell := detectShellByCheckingProcessTree(getCurrentProcessID())
if detectedShell == "" {
fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n")
return "", ErrUnknownShell
}

return filepath.Base(detectedShell), nil
}

func detectShellByCheckingProcessTree(pid int) string {
level := 0
for level < ProcessDepthLimit {
processName, err := getProcessName(pid)
if err != nil {
return ""
}
if processName == "zsh" || processName == "bash" || processName == "fish" {
return processName
}
pid, err = getParentProcessID(pid)
if err != nil {
return ""
}
level++
}
return ""
}

func getProcessName(pid int) (string, error) {
p, err := processSupplier.NewProcess(cast.ToInt32(pid))
if err != nil {
return "", err
}
processName, err := p.Name()
if err != nil {
return "", err
}
return processName, nil
}

func getParentProcessID(pid int) (int, error) {
p, err := processSupplier.NewProcess(cast.ToInt32(pid))
if err != nil {
return 0, err
}
ppid, err := p.Ppid()
if err != nil {
return 0, err
}
return cast.ToInt(ppid), nil
}
Loading

0 comments on commit f31fa0d

Please sign in to comment.