diff --git a/wsus/wsus_test.go b/wsus/wsus_test.go index 1e40ab4..21128eb 100644 --- a/wsus/wsus_test.go +++ b/wsus/wsus_test.go @@ -15,20 +15,10 @@ package wsus import ( - "fmt" - "os" - "os/exec" "reflect" "testing" ) -var ( - stoppedService = "Wecsvc" // This service should not be running by default. - runningService = "w32time" // This service should always be running by default. - testStoppedService = "TestService" // A test service that will get created. - badService = "veryfakeservice" // This service should not exist. -) - func TestSortedKeys(t *testing.T) { for _, tt := range []struct { in map[int]string @@ -45,7 +35,6 @@ func TestSortedKeys(t *testing.T) { } func TestSet(t *testing.T) { - clnUpdateFolder = func(name string) error { return nil } w := WSUS{ CurrentServer: "", ServerSelection: 0, @@ -70,90 +59,3 @@ func TestSet(t *testing.T) { } } } - -func TestCleanUpdateFolder(t *testing.T) { - // This is a directory that shouldn't exist. - badDir := os.Getenv("windir") + `\rollinginthedeep` - // The actual expected directory that will be cleared. - updateDir := os.Getenv("windir") + `\SoftwareDistribution` - for _, tt := range []struct { - dir string - wantErr bool - }{ - {badDir, true}, - {updateDir, false}, - } { - if err := cleanUpdateFolder(tt.dir); (err != nil) != tt.wantErr { - t.Errorf("cleanUpdateFolder(%s) = %v, want error presence = %v", tt.dir, err, tt.wantErr) - } - } -} - -func TestCleanUpdateFolderWithOpenFile(t *testing.T) { - strtService = func(name string) error { return nil } - stpService = func(name string) error { return nil } - dir, err := os.MkdirTemp("", "setfiretotherain") - if err != nil { - t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not make temp dir: %v", err) - } - defer os.RemoveAll(dir) - - f, err := os.CreateTemp(dir, "lyrics") - if err != nil { - t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not create temp file in %s: %v", dir, err) - } - - // Hold the file open so it can't be deleted. - d, err := os.Open(f.Name()) - if err != nil { - t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not open temp file(%s): %v", f.Name(), err) - } - if err := cleanUpdateFolder(dir); err == nil { - t.Errorf("cleanUpdateFolder(%v) returned nil, want error: %v", dir, err) - } - d.Close() -} - -func TestStartService(t *testing.T) { - for _, tt := range []struct { - service string - wantErr bool - }{ - {stoppedService, false}, - {runningService, false}, - {badService, true}, - } { - if err := startService(tt.service); (err != nil) != tt.wantErr { - t.Errorf("StartService(%s) = %v, want error presence = %v", tt.service, err, tt.wantErr) - } - } -} - -func createStoppedServiceHelper(t *testing.T) { - t.Helper() - cmdPath := `C:\Windows\System32\cmd.exe` - pwshPath := `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` - serviceParams := fmt.Sprintf("New-Service -Name %s -BinaryPathName '%s'", testStoppedService, cmdPath) - cmd := exec.Command(pwshPath, serviceParams) - if _, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("createStoppedServiceHelper setup failed: could not create test service: %v", err) - } -} - -func TestStopService(t *testing.T) { - createStoppedServiceHelper(t) - for _, tt := range []struct { - service string - wantErr bool - }{ - {runningService, false}, - {testStoppedService, false}, - {badService, true}, - } { - err := stopService(tt.service) - gotErr := err != nil - if gotErr != tt.wantErr { - t.Errorf("StopService(%s) = %v, want error presence = %v", tt.service, err, tt.wantErr) - } - } -} diff --git a/wsus/wsus_windows.go b/wsus/wsus_windows.go index b4a69e9..37222ea 100644 --- a/wsus/wsus_windows.go +++ b/wsus/wsus_windows.go @@ -20,24 +20,15 @@ package wsus import ( "fmt" "net/http" - "os" - "path/filepath" "time" "github.com/google/cabbie/cablib" "golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/svc/eventlog" - "golang.org/x/sys/windows/svc/mgr" - "golang.org/x/sys/windows/svc" ) var ( wlog *eventlog.Log - - // Test Stubs - clnUpdateFolder = cleanUpdateFolder - stpService = stopService - strtService = startService ) func responseTime(name string) time.Duration { @@ -62,93 +53,6 @@ func responseTime(name string) time.Duration { return time.Since(start) } -func cleanUpdateFolder(dir string) error { - if err := stpService("wuauserv"); err != nil { - return fmt.Errorf("StopService failure: %w", err) - } - d, err := os.Open(dir) - if err != nil { - return fmt.Errorf("os.Open(%s): %w", dir, err) - } - defer d.Close() - // Read all object names in the directory. - objects, err := d.Readdirnames(-1) - if err != nil { - return fmt.Errorf("Readdirnames: %w", err) - } - // Loop through the slice and delete each object. - for _, object := range objects { - if err := os.RemoveAll(filepath.Join(dir, object)); err != nil { - return fmt.Errorf("os.RemoveAll(%s): %w", filepath.Join(dir, object), err) - } - } - if err := strtService("wuauserv"); err != nil { - return fmt.Errorf("StartService failure: %w", err) - } - return nil -} - -// stopService attempts to stop local system services. -func stopService(name string) error { - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to service manager: %w", err) - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err != nil { - return fmt.Errorf("failed to open service (%s): %w", name, err) - } - defer s.Close() - // Although s.Control returns stat, if the service is already stopped it returns an error. - stat, err := s.Query() - if err != nil { - return fmt.Errorf("failed to query service (%s): %w", s.Name, err) - } - if stat.State == svc.Stopped { - return nil - } - stat, err = s.Control(svc.Stop) - if err != nil { - return fmt.Errorf("failed to send control message (%s): %w", s.Name, err) - } - retry := 0 - for stat.State != svc.Stopped { - time.Sleep(5 * time.Second) - retry++ - if retry > 12 { - return fmt.Errorf("timed out waiting for service %s to stop", s.Name) - } - stat, err = s.Query() - if err != nil { - return fmt.Errorf("failed to query service (%s): %w", s.Name, err) - } - } - return nil -} - -// startService attempts to start local system services. -func startService(name string) error { - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to service manager: %w", err) - } - defer m.Disconnect() - s, err := m.OpenService(name) - if err != nil { - return fmt.Errorf("failed to open service (%s): %w", name, err) - } - defer s.Close() - stat, err := s.Query() - if err != nil { - return fmt.Errorf("failed to query service (%s): %w", s.Name, err) - } - if stat.State == svc.Running { - return nil - } - return s.Start() -} - // Init will initialize the local update client with the desired WSUS config. func Init(servers []string) (*WSUS, error) { var w WSUS @@ -231,13 +135,8 @@ func (w *WSUS) Set(index int) error { return err } defer sk.Close() - if err := sk.SetDWordValue("UseWUServer", 1); err != nil { - return err - } - // The cleanUpdateFolder runs to fix error 0x80244011 from being thrown during update - // runs after WSUS servers are set. - updateDir := os.Getenv("windir") + `\SoftwareDistribution` - return clnUpdateFolder(updateDir) + + return sk.SetDWordValue("UseWUServer", 1) } // Clear sets WSUS client configurations back to Windows defaults.