diff options
| author | James O'Doherty <james@theodohertyfamily.com> | 2026-06-13 11:51:04 -0400 |
|---|---|---|
| committer | James O'Doherty <james@theodohertyfamily.com> | 2026-06-13 11:51:04 -0400 |
| commit | 29621ecbd1e77e6e1a70b6b3ea8fbe3a56e47df3 (patch) | |
| tree | fa54976bbb0c4e9db59c983e7cb4e60c5119d18b /internal/wireguard | |
| parent | f8afb7d5889f5c8b6ea256fd078fa8426d21c7be (diff) | |
refactor: implement dependency injection and enable parallel testing
This commit refactors the core system operations to use a manager-based
dependency injection pattern, eliminating global state and resolving
data races in the test suite.
Architecture:
- Introduced NetworkManager and NetworkOps interface in internal/network
to abstract netlink calls.
- Introduced MountOps and FileSystem interfaces in internal/namespace
to abstract mount and filesystem operations.
- Introduced TunnelManager in internal/wireguard to coordinate tunnel
lifecycle using the new abstractions.
- Updated internal/cli and internal/manager to use these managers.
Testing:
- Restored t.Parallel() to unit tests in internal/network and
internal/wireguard.
- Implemented setupParallelEnv and an enhanced mockFS in
wireguard_unit_test.go to ensure complete test isolation.
- Added bootstrap_test.go to verify launcher preparation logic in
internal/namespace without requiring syscall.Exec.
- Resolved data races in internal/network tests.
CLI:
- Added support for -h, --help, and -help flags for the main command.
Verification:
- Passed all tests (unit, integration, E2E).
- Verified zero data races with 'go test -race'.
- Passed golangci-lint and go vet.
Diffstat (limited to 'internal/wireguard')
| -rw-r--r-- | internal/wireguard/wireguard.go | 64 | ||||
| -rw-r--r-- | internal/wireguard/wireguard_unit_test.go | 195 |
2 files changed, 233 insertions, 26 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 8ffe794..67cc211 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -51,8 +51,24 @@ type Tunnel struct { dnsFile string } +// TunnelManager coordinates the creation and management of a WireGuard tunnel. +type TunnelManager struct { + MountOps namespace.MountOps + FS namespace.FileSystem + Net *network.NetworkManager +} + +// NewTunnelManager creates a new TunnelManager with production defaults. +func NewTunnelManager() *TunnelManager { + return &TunnelManager{ + MountOps: namespace.DefaultMountOps, + FS: namespace.DefaultFS, + Net: network.NewNetworkManager(), + } +} + // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. -func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { +func (tm *TunnelManager) StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { var cleanups []func() defer func() { if err != nil { @@ -66,11 +82,11 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS tunName := "tun0" mtu := 1420 - if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { + if err := tm.MountOps.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make mount namespace private: %v\n", err) } - if err := BlockHostServices(pm, profile); err != nil { + if err := tm.BlockHostServices(pm, profile); err != nil { fmt.Printf("warning: failed to block host services: %v\n", err) } @@ -118,18 +134,18 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS } // 4. Configure network interface - if err := network.ConfigureInterface(tunName, cfg.Address, mtu); err != nil { + if err := tm.Net.ConfigureInterface(tunName, cfg.Address, mtu); err != nil { return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } var dnsFile string profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile) - if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil { + if path, err := tm.ConfigureResolvConf(dnsServer, profileDir); err != nil { fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) } else { dnsFile = path cleanups = append(cleanups, func() { - if err := UnmountResolvConf(dnsFile); err != nil { + if err := tm.UnmountResolvConf(dnsFile); err != nil { fmt.Printf("warning: failed to unmount resolv.conf during cleanup: %v\n", err) } }) @@ -143,12 +159,12 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS } // Close shuts down the userspace WireGuard device and closes the TUN interface. -func (t *Tunnel) Close() { +func (t *Tunnel) Close(tm *TunnelManager) { if t.Device != nil { t.Device.Close() } if t.dnsFile != "" { - if err := UnmountResolvConf(t.dnsFile); err != nil { + if err := tm.UnmountResolvConf(t.dnsFile); err != nil { fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err) } } @@ -216,14 +232,12 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { } // ConfigureResolvConf creates a temporary resolv.conf file and bind-mounts it to /etc/resolv.conf. -func ConfigureResolvConf(dns string, profileDir string) (string, error) { +func (tm *TunnelManager) ConfigureResolvConf(dns string, profileDir string) (string, error) { if dns == "" { return "", nil } - // Create the temporary resolv.conf file within the profile directory - // so it can be cleaned up during namespace unpinning. - tmpFile, err := os.CreateTemp(profileDir, "resolvconf") + tmpFile, err := tm.FS.CreateTemp(profileDir, "resolvconf") if err != nil { return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err) } @@ -235,11 +249,11 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) { } _ = tmpFile.Close() - if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { + if err := tm.MountOps.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err) } - if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { + if err := tm.MountOps.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } @@ -247,16 +261,14 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) { } // UnmountResolvConf unmounts the bind-mounted resolv.conf and removes the temporary file. -func UnmountResolvConf(path string) error { +func (tm *TunnelManager) UnmountResolvConf(path string) error { if path == "" { return nil } - // Attempt to unmount. If it fails, it might already be unmounted - // or the namespace might be gone. - _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH) + _ = tm.MountOps.Unmount("/etc/resolv.conf", unix.MNT_DETACH) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + if err := tm.FS.Remove(path); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err) } return nil @@ -264,18 +276,18 @@ func UnmountResolvConf(path string) error { // BlockHostServices bind-mounts empty files/directories over sensitive host services // to prevent access from within the isolated namespace. -func BlockHostServices(pm *paths.PathManager, profile string) error { +func (tm *TunnelManager) BlockHostServices(pm *paths.PathManager, profile string) error { blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block") - if err := os.MkdirAll(blockDirBase, 0755); err != nil { + if err := tm.FS.MkdirAll(blockDirBase, 0755); err != nil { return fmt.Errorf("failed to create block base directory: %w", err) } - tmpDir, err := os.MkdirTemp(blockDirBase, "dir-") + tmpDir, err := tm.FS.MkdirTemp(blockDirBase, "dir-") if err != nil { return fmt.Errorf("failed to create temp block dir: %w", err) } - tmpFile, err := os.CreateTemp(blockDirBase, "file-") + tmpFile, err := tm.FS.CreateTemp(blockDirBase, "file-") if err != nil { return fmt.Errorf("failed to create temp block file: %w", err) } @@ -283,16 +295,16 @@ func BlockHostServices(pm *paths.PathManager, profile string) error { _ = tmpFile.Close() for _, p := range namespace.GetBlockPaths() { - stat, err := os.Stat(p) + stat, err := tm.FS.Stat(p) if err == nil { source := tmpFileName if stat.IsDir() { source = tmpDir } - if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil { + if err := tm.MountOps.Mount(source, p, "", unix.MS_BIND, ""); err != nil { fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err) } else { - _ = unix.Mount("", p, "", unix.MS_PRIVATE, "") + _ = tm.MountOps.Mount("", p, "", unix.MS_PRIVATE, "") } } } diff --git a/internal/wireguard/wireguard_unit_test.go b/internal/wireguard/wireguard_unit_test.go new file mode 100644 index 0000000..1ad7f65 --- /dev/null +++ b/internal/wireguard/wireguard_unit_test.go @@ -0,0 +1,195 @@ +//go:build linux + +package wireguard + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "git.theodohertyfamily.com/wg-wrap/internal/namespace" + "git.theodohertyfamily.com/wg-wrap/internal/paths" + "golang.org/x/sys/unix" +) + +// mockMountOps records mount calls for verification. +type mockMountOps struct { + mounts []mountCall + unmounts []string +} + +type mountCall struct { + source string + target string + flags uintptr +} + +func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error { + m.mounts = append(m.mounts, mountCall{source, target, flags}) + return nil +} + +func (m *mockMountOps) Unmount(target string, flags int) error { + m.unmounts = append(m.unmounts, target) + return nil +} + +// mockFS mimics a filesystem using a temporary directory. +type mockFS struct { + root string +} + +func (m *mockFS) fullPath(path string) string { + if filepath.IsAbs(path) { + path = strings.TrimPrefix(path, "/") + } + return filepath.Join(m.root, path) +} + +func (m *mockFS) Stat(name string) (os.FileInfo, error) { + return os.Stat(m.fullPath(name)) +} + +func (m *mockFS) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(m.fullPath(path), perm) +} + +func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) { + return os.CreateTemp(m.fullPath(dir), pattern) +} + +func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) { + dirPath := m.fullPath(dir) + res, err := os.MkdirTemp(dirPath, pattern) + if err != nil { + return "", err + } + // Return path relative to root for consistency if needed, but OS calls return absolute. + // The TunnelManager expects a path it can pass to Remove() later. + // If we return absolute, mockFS.Remove must handle it. + return res, nil +} + +func (m *mockFS) Remove(name string) error { + // If the path is absolute and starts with our root, we can remove it directly. + // Otherwise, we use fullPath to ensure it's within root. + if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) { + return os.Remove(name) + } + return os.Remove(m.fullPath(name)) +} + +// setupParallelEnv creates a clean, isolated test environment for a single test. +func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) { + t.Helper() + + tmpDir := t.TempDir() + hostRoot := t.TempDir() + + pm := paths.NewPathManager("", tmpDir) + mockMounts := &mockMountOps{} + mockFS := &mockFS{root: hostRoot} + + tm := &TunnelManager{ + MountOps: mockMounts, + FS: mockFS, + } + + return pm, mockMounts, mockFS, tm +} + +func TestBlockHostServices_Logic(t *testing.T) { + t.Parallel() + + pm, mockMounts, mockFS, tm := setupParallelEnv(t) + profile := "test-profile" + + blockPaths := namespace.GetBlockPaths() + for _, p := range blockPaths { + fullPath := mockFS.fullPath(p) + _ = os.MkdirAll(filepath.Dir(fullPath), 0755) + _ = os.WriteFile(fullPath, []byte("data"), 0644) + } + + err := tm.BlockHostServices(pm, profile) + if err != nil { + t.Fatalf("BlockHostServices failed: %v", err) + } + + for _, p := range blockPaths { + found := false + for _, m := range mockMounts.mounts { + if m.target == p { + found = true + break + } + } + if !found { + t.Errorf("path %s was not blocked", p) + } + } +} + +func TestConfigureResolvConf_Logic(t *testing.T) { + t.Parallel() + + _, mockMounts, mockFS, tm := setupParallelEnv(t) + + // Use a directory relative to mockFS.root + profileDir := "profile-dir" + _ = mockFS.MkdirAll(profileDir, 0755) + dnsServer := "1.1.1.1" + + path, err := tm.ConfigureResolvConf(dnsServer, profileDir) + if err != nil { + t.Fatalf("ConfigureResolvConf failed: %v", err) + } + + // path is returned from os.CreateTemp, which is absolute. + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("resolv.conf file was not created at %s", path) + } + + if len(mockMounts.mounts) < 2 { + t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts)) + } + + if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND { + t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0]) + } + + if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE { + t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1]) + } +} + +func TestUnmountResolvConf_Logic(t *testing.T) { + t.Parallel() + + _, mockMounts, mockFS, tm := setupParallelEnv(t) + + // Create a file within the mock FS root + profileDir := "profile-dir" + _ = mockFS.MkdirAll(profileDir, 0755) + tmpFile := "resolv.conf.tmp" + fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile)) + + err := os.WriteFile(fullPath, []byte("test"), 0644) + if err != nil { + t.Fatalf("failed to write temp resolv.conf: %v", err) + } + + err = tm.UnmountResolvConf(fullPath) + if err != nil { + t.Errorf("UnmountResolvConf failed: %v", err) + } + + if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" { + t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts) + } + + if _, err := os.Stat(fullPath); !os.IsNotExist(err) { + t.Error("temporary resolv.conf file was not removed") + } +} |
