Skip to content

Commit

Permalink
Modify WSUS library to include cleanup of the software distribution f…
Browse files Browse the repository at this point in the history
…older to resolve windows update errors if WSUS servers are set.

PiperOrigin-RevId: 623629239
  • Loading branch information
mjoliver authored and copybara-github committed Apr 10, 2024
1 parent 2b3ad5c commit b78ed00
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 201 deletions.
98 changes: 0 additions & 98 deletions wsus/wsus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,10 @@
package wsus

import (
"fmt"
"os"
"os/exec"
"reflect"
"testing"
)

var (
stoppedService = "Wecsvc" // This service should not be running by default.
runningService = "w32time" // This service should always be running by default.
testStoppedService = "TestService" // A test service that will get created.
badService = "veryfakeservice" // This service should not exist.
)

func TestSortedKeys(t *testing.T) {
for _, tt := range []struct {
in map[int]string
Expand All @@ -45,7 +35,6 @@ func TestSortedKeys(t *testing.T) {
}

func TestSet(t *testing.T) {
clnUpdateFolder = func(name string) error { return nil }
w := WSUS{
CurrentServer: "",
ServerSelection: 0,
Expand All @@ -70,90 +59,3 @@ func TestSet(t *testing.T) {
}
}
}

func TestCleanUpdateFolder(t *testing.T) {
// This is a directory that shouldn't exist.
badDir := os.Getenv("windir") + `\rollinginthedeep`
// The actual expected directory that will be cleared.
updateDir := os.Getenv("windir") + `\SoftwareDistribution`
for _, tt := range []struct {
dir string
wantErr bool
}{
{badDir, true},
{updateDir, false},
} {
if err := cleanUpdateFolder(tt.dir); (err != nil) != tt.wantErr {
t.Errorf("cleanUpdateFolder(%s) = %v, want error presence = %v", tt.dir, err, tt.wantErr)
}
}
}

func TestCleanUpdateFolderWithOpenFile(t *testing.T) {
strtService = func(name string) error { return nil }
stpService = func(name string) error { return nil }
dir, err := os.MkdirTemp("", "setfiretotherain")
if err != nil {
t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not make temp dir: %v", err)
}
defer os.RemoveAll(dir)

f, err := os.CreateTemp(dir, "lyrics")
if err != nil {
t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not create temp file in %s: %v", dir, err)
}

// Hold the file open so it can't be deleted.
d, err := os.Open(f.Name())
if err != nil {
t.Fatalf("TestCleanUpdateFolderWithOpenFile setup failed: could not open temp file(%s): %v", f.Name(), err)
}
if err := cleanUpdateFolder(dir); err == nil {
t.Errorf("cleanUpdateFolder(%v) returned nil, want error: %v", dir, err)
}
d.Close()
}

func TestStartService(t *testing.T) {
for _, tt := range []struct {
service string
wantErr bool
}{
{stoppedService, false},
{runningService, false},
{badService, true},
} {
if err := startService(tt.service); (err != nil) != tt.wantErr {
t.Errorf("StartService(%s) = %v, want error presence = %v", tt.service, err, tt.wantErr)
}
}
}

func createStoppedServiceHelper(t *testing.T) {
t.Helper()
cmdPath := `C:\Windows\System32\cmd.exe`
pwshPath := `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
serviceParams := fmt.Sprintf("New-Service -Name %s -BinaryPathName '%s'", testStoppedService, cmdPath)
cmd := exec.Command(pwshPath, serviceParams)
if _, err := cmd.CombinedOutput(); err != nil {
t.Fatalf("createStoppedServiceHelper setup failed: could not create test service: %v", err)
}
}

func TestStopService(t *testing.T) {
createStoppedServiceHelper(t)
for _, tt := range []struct {
service string
wantErr bool
}{
{runningService, false},
{testStoppedService, false},
{badService, true},
} {
err := stopService(tt.service)
gotErr := err != nil
if gotErr != tt.wantErr {
t.Errorf("StopService(%s) = %v, want error presence = %v", tt.service, err, tt.wantErr)
}
}
}
105 changes: 2 additions & 103 deletions wsus/wsus_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,15 @@ package wsus
import (
"fmt"
"net/http"
"os"
"path/filepath"
"time"

"github.com/google/cabbie/cablib"
"golang.org/x/sys/windows/registry"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
"golang.org/x/sys/windows/svc"
)

var (
wlog *eventlog.Log

// Test Stubs
clnUpdateFolder = cleanUpdateFolder
stpService = stopService
strtService = startService
)

func responseTime(name string) time.Duration {
Expand All @@ -62,93 +53,6 @@ func responseTime(name string) time.Duration {
return time.Since(start)
}

func cleanUpdateFolder(dir string) error {
if err := stpService("wuauserv"); err != nil {
return fmt.Errorf("StopService failure: %w", err)
}
d, err := os.Open(dir)
if err != nil {
return fmt.Errorf("os.Open(%s): %w", dir, err)
}
defer d.Close()
// Read all object names in the directory.
objects, err := d.Readdirnames(-1)
if err != nil {
return fmt.Errorf("Readdirnames: %w", err)
}
// Loop through the slice and delete each object.
for _, object := range objects {
if err := os.RemoveAll(filepath.Join(dir, object)); err != nil {
return fmt.Errorf("os.RemoveAll(%s): %w", filepath.Join(dir, object), err)
}
}
if err := strtService("wuauserv"); err != nil {
return fmt.Errorf("StartService failure: %w", err)
}
return nil
}

// stopService attempts to stop local system services.
func stopService(name string) error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %w", err)
}
defer m.Disconnect()
s, err := m.OpenService(name)
if err != nil {
return fmt.Errorf("failed to open service (%s): %w", name, err)
}
defer s.Close()
// Although s.Control returns stat, if the service is already stopped it returns an error.
stat, err := s.Query()
if err != nil {
return fmt.Errorf("failed to query service (%s): %w", s.Name, err)
}
if stat.State == svc.Stopped {
return nil
}
stat, err = s.Control(svc.Stop)
if err != nil {
return fmt.Errorf("failed to send control message (%s): %w", s.Name, err)
}
retry := 0
for stat.State != svc.Stopped {
time.Sleep(5 * time.Second)
retry++
if retry > 12 {
return fmt.Errorf("timed out waiting for service %s to stop", s.Name)
}
stat, err = s.Query()
if err != nil {
return fmt.Errorf("failed to query service (%s): %w", s.Name, err)
}
}
return nil
}

// startService attempts to start local system services.
func startService(name string) error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("failed to connect to service manager: %w", err)
}
defer m.Disconnect()
s, err := m.OpenService(name)
if err != nil {
return fmt.Errorf("failed to open service (%s): %w", name, err)
}
defer s.Close()
stat, err := s.Query()
if err != nil {
return fmt.Errorf("failed to query service (%s): %w", s.Name, err)
}
if stat.State == svc.Running {
return nil
}
return s.Start()
}

// Init will initialize the local update client with the desired WSUS config.
func Init(servers []string) (*WSUS, error) {
var w WSUS
Expand Down Expand Up @@ -231,13 +135,8 @@ func (w *WSUS) Set(index int) error {
return err
}
defer sk.Close()
if err := sk.SetDWordValue("UseWUServer", 1); err != nil {
return err
}
// The cleanUpdateFolder runs to fix error 0x80244011 from being thrown during update
// runs after WSUS servers are set.
updateDir := os.Getenv("windir") + `\SoftwareDistribution`
return clnUpdateFolder(updateDir)

return sk.SetDWordValue("UseWUServer", 1)
}

// Clear sets WSUS client configurations back to Windows defaults.
Expand Down

0 comments on commit b78ed00

Please sign in to comment.