diff options
| author | James O'Doherty <james@theodohertyfamily.com> | 2026-06-13 11:51:04 -0400 |
|---|---|---|
| committer | James O'Doherty <james@theodohertyfamily.com> | 2026-06-13 11:51:04 -0400 |
| commit | 29621ecbd1e77e6e1a70b6b3ea8fbe3a56e47df3 (patch) | |
| tree | fa54976bbb0c4e9db59c983e7cb4e60c5119d18b | |
| parent | f8afb7d5889f5c8b6ea256fd078fa8426d21c7be (diff) | |
refactor: implement dependency injection and enable parallel testing
This commit refactors the core system operations to use a manager-based
dependency injection pattern, eliminating global state and resolving
data races in the test suite.
Architecture:
- Introduced NetworkManager and NetworkOps interface in internal/network
to abstract netlink calls.
- Introduced MountOps and FileSystem interfaces in internal/namespace
to abstract mount and filesystem operations.
- Introduced TunnelManager in internal/wireguard to coordinate tunnel
lifecycle using the new abstractions.
- Updated internal/cli and internal/manager to use these managers.
Testing:
- Restored t.Parallel() to unit tests in internal/network and
internal/wireguard.
- Implemented setupParallelEnv and an enhanced mockFS in
wireguard_unit_test.go to ensure complete test isolation.
- Added bootstrap_test.go to verify launcher preparation logic in
internal/namespace without requiring syscall.Exec.
- Resolved data races in internal/network tests.
CLI:
- Added support for -h, --help, and -help flags for the main command.
Verification:
- Passed all tests (unit, integration, E2E).
- Verified zero data races with 'go test -race'.
- Passed golangci-lint and go vet.
| -rw-r--r-- | internal/cli/cli.go | 10 | ||||
| -rw-r--r-- | internal/manager/manager.go | 4 | ||||
| -rw-r--r-- | internal/namespace/bootstrap_test.go | 111 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 152 | ||||
| -rw-r--r-- | internal/network/network.go | 74 | ||||
| -rw-r--r-- | internal/network/network_test.go | 163 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 64 | ||||
| -rw-r--r-- | internal/wireguard/wireguard_unit_test.go | 195 |
8 files changed, 698 insertions, 75 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 5beb989..b38d0d9 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -56,7 +56,13 @@ func (a *App) Route() error { return fmt.Errorf("no command provided") } - switch a.Args[1] { + firstArg := a.Args[1] + if firstArg == "-h" || firstArg == "--help" || firstArg == "-help" { + a.printUsage() + return nil + } + + switch firstArg { case "show-config": return a.showConfig() case "profile": @@ -71,7 +77,7 @@ func (a *App) Route() error { return a.testLifecycle() default: a.printUsage() - return fmt.Errorf("unknown command: %s", a.Args[1]) + return fmt.Errorf("unknown command: %s", firstArg) } } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index ffb02c0..270b99e 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -185,11 +185,11 @@ func (m *Manager) Execute(cfg *config.Config, verbose bool) error { } } - tunnel, err := wireguard.StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer) + tunnel, err := wireguard.NewTunnelManager().StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer) if err != nil { return fmt.Errorf("failed to start WireGuard tunnel: %w", err) } - defer tunnel.Close() + defer tunnel.Close(wireguard.NewTunnelManager()) if err := m.NS.PinNamespace(m.PM, cfg.Profile); err != nil { fmt.Fprintf(os.Stderr, "warning: failed to pin namespace: %v\n", err) diff --git a/internal/namespace/bootstrap_test.go b/internal/namespace/bootstrap_test.go new file mode 100644 index 0000000..eff04e2 --- /dev/null +++ b/internal/namespace/bootstrap_test.go @@ -0,0 +1,111 @@ +//go:build linux + +package namespace + +import ( + "os" + "strings" + "testing" +) + +func TestPrepareBootstrap(t *testing.T) { + t.Parallel() + + // We are testing the environment and argument preparation. + // This will actually open some FDs (memfd, netns, socket), + // which is fine for an integration-style unit test. + config, err := PrepareBootstrap() + if err != nil { + t.Fatalf("PrepareBootstrap failed: %v", err) + } + + // 1. Verify ExecPath is correct (should be a fd in /proc/self/fd) + if !strings.HasPrefix(config.ExecPath, "/proc/self/fd/") { + t.Errorf("expected ExecPath to be in /proc/self/fd/, got %s", config.ExecPath) + } + + // 2. Verify Arguments are preserved and expanded + if len(config.Args) == 0 || config.Args[0] == "" { + t.Error("args should not be empty") + } + + // 3. Verify Host NetNS FD is present in Env + foundNetNs := false + for _, env := range config.Env { + if strings.HasPrefix(env, "WG_WRAP_HOST_NETNS_FD=") { + foundNetNs = true + break + } + } + if !foundNetNs { + t.Error("WG_WRAP_HOST_NETNS_FD missing from environment") + } + + // 4. Verify Host Socket FD is present in Env + foundSocket := false + for _, env := range config.Env { + if strings.HasPrefix(env, "WG_WRAP_HOST_SOCKET_FD=") { + foundSocket = true + break + } + } + if !foundSocket { + t.Error("WG_WRAP_HOST_SOCKET_FD missing from environment") + } + + // 5. Verify FDs are tracked for cleanup + if len(config.Fds) < 2 { + t.Errorf("expected at least 2 FDs (launcher, netns), got %d", len(config.Fds)) + } +} + +func TestPrepareBootstrapJoin(t *testing.T) { + t.Parallel() + + targetPid := 1234 + config, err := PrepareBootstrapJoin(targetPid) + if err != nil { + t.Fatalf("PrepareBootstrapJoin failed: %v", err) + } + + // 1. Verify Join PID is present in Env + foundJoinPid := false + expectedPidEnv := "WG_WRAP_JOIN_PID=1234" + for _, env := range config.Env { + if env == expectedPidEnv { + foundJoinPid = true + break + } + } + if !foundJoinPid { + t.Errorf("expected %s in environment", expectedPidEnv) + } + + // 2. Verify Joined flag is present + foundJoined := false + for _, env := range config.Env { + if env == "WG_WRAP_JOINED=1" { + foundJoined = true + break + } + } + if !foundJoined { + t.Error("WG_WRAP_JOINED=1 missing from environment") + } +} + +func TestPrepareBootstrap_NullByteValidation(t *testing.T) { + // Temporarily inject a null byte into os.Args to test validation + // Note: os.Args is a slice, so we can modify it. + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = append(os.Args, "bad\x00arg") + + _, err := PrepareBootstrap() + if err == nil { + t.Error("expected error when argument contains null byte, got nil") + } else if !strings.Contains(err.Error(), "contains null byte") { + t.Errorf("expected null byte error, got: %v", err) + } +} diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index 45eba73..368775f 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -32,6 +32,66 @@ import ( "golang.org/x/sys/unix" ) +// MountOps abstracts the filesystem mount operations. +type MountOps interface { + Mount(source, target, fstype string, flags uintptr, data string) error + Unmount(target string, flags int) error +} + +// realMountOps is the production implementation using unix.Mount. +type realMountOps struct{} + +func (r *realMountOps) Mount(source, target, fstype string, flags uintptr, data string) error { + return unix.Mount(source, target, fstype, flags, data) +} +func (r *realMountOps) Unmount(target string, flags int) error { + return unix.Unmount(target, flags) +} + +// DefaultMountOps is the global instance used by the package functions. +var DefaultMountOps MountOps = &realMountOps{} + +// FileSystem abstracts the basic filesystem operations used for isolation. +type FileSystem interface { + Stat(name string) (os.FileInfo, error) + MkdirAll(path string, perm os.FileMode) error + CreateTemp(dir, pattern string) (*os.File, error) + MkdirTemp(dir, pattern string) (string, error) + Remove(name string) error +} + +// realFS is the production implementation using the os package. +type realFS struct{} + +func (r *realFS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) } +func (r *realFS) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} +func (r *realFS) CreateTemp(dir, pattern string) (*os.File, error) { + return os.CreateTemp(dir, pattern) +} +func (r *realFS) MkdirTemp(dir, pattern string) (string, error) { + return os.MkdirTemp(dir, pattern) +} +func (r *realFS) Remove(name string) error { return os.Remove(name) } + +// DefaultFS is the global instance used by the package functions. +var DefaultFS FileSystem = &realFS{} + +// ResetDefaults restores the default implementations of MountOps and FileSystem. +func ResetDefaults() { + DefaultMountOps = &realMountOps{} + DefaultFS = &realFS{} +} + +// BootstrapConfig contains the environment and arguments needed to execute the bootstrap launcher. +type BootstrapConfig struct { + Args []string + Env []string + Fds []int + ExecPath string +} + //go:embed launcher.bin var launcherBytes []byte @@ -104,48 +164,59 @@ func Bootstrap() (err error) { return nil } - var fdsToClose []int + config, err := PrepareBootstrap() + if err != nil { + return err + } + defer func() { - if err != nil { - for _, fd := range fdsToClose { - _ = syscall.Close(fd) - } + for _, fd := range config.Fds { + _ = syscall.Close(fd) } }() + err = syscall.Exec(config.ExecPath, config.Args, config.Env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} + +// PrepareBootstrap calculates the environment and arguments needed for the bootstrap launcher. +func PrepareBootstrap() (*BootstrapConfig, error) { // 0. Validate current arguments for null bytes before proceeding. for i, arg := range os.Args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j) } } } self, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) + return nil, fmt.Errorf("failed to get executable path: %w", err) } execFd, err := prepareLauncher() if err != nil { - return err + return nil, err } - fdsToClose = append(fdsToClose, execFd) // Clear close-on-exec if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) } - // 3. Prepare arguments for the launcher. + // Prepare arguments for the launcher. args := []string{self} args = append(args, os.Args[1:]...) for i, arg := range args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) } } } @@ -153,9 +224,8 @@ func Bootstrap() (err error) { // Open the host network namespace file descriptor before unsharing. hostNetFd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0) if err != nil { - return fmt.Errorf("failed to open host netns: %w", err) + return nil, fmt.Errorf("failed to open host netns: %w", err) } - fdsToClose = append(fdsToClose, hostNetFd) // Clear close-on-exec if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil { @@ -174,18 +244,17 @@ func Bootstrap() (err error) { _, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC) } env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd)) - fdsToClose = append(fdsToClose, int(hostSocketFd)) _ = conn.Close() } } } - err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) - if err != nil { - return fmt.Errorf("launcher exec failed: %w", err) - } - - return nil + return &BootstrapConfig{ + Args: args, + Env: env, + Fds: []int{execFd, hostNetFd}, + ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd), + }, nil } // BootstrapJoin joins the namespaces of the target PID and replaces the current process. @@ -194,33 +263,44 @@ func BootstrapJoin(targetPid int) (err error) { return nil } - var fdsToClose []int + config, err := PrepareBootstrapJoin(targetPid) + if err != nil { + return err + } + defer func() { - if err != nil { - for _, fd := range fdsToClose { - _ = syscall.Close(fd) - } + for _, fd := range config.Fds { + _ = syscall.Close(fd) } }() + err = syscall.Exec(config.ExecPath, config.Args, config.Env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} + +// PrepareBootstrapJoin calculates the environment and arguments needed to join a namespace. +func PrepareBootstrapJoin(targetPid int) (*BootstrapConfig, error) { for i, arg := range os.Args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j) } } } self, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) + return nil, fmt.Errorf("failed to get executable path: %w", err) } execFd, err := prepareLauncher() if err != nil { - return err + return nil, err } - fdsToClose = append(fdsToClose, execFd) if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) @@ -232,7 +312,7 @@ func BootstrapJoin(targetPid int) (err error) { for i, arg := range args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) } } } @@ -242,12 +322,12 @@ func BootstrapJoin(targetPid int) (err error) { "WG_WRAP_JOINED=1", ) - err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) - if err != nil { - return fmt.Errorf("launcher exec failed: %w", err) - } - - return nil + return &BootstrapConfig{ + Args: args, + Env: env, + Fds: []int{execFd}, + ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd), + }, nil } func prepareLauncher() (int, error) { diff --git a/internal/network/network.go b/internal/network/network.go index 6afcf5e..e9dce77 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -16,9 +16,54 @@ type InterfaceInfo struct { Index int } +// NetworkOps abstracts the low-level netlink operations. +type NetworkOps interface { + LinkList() ([]netlink.Link, error) + LinkByName(name string) (netlink.Link, error) + LinkSetMTU(link netlink.Link, mtu int) error + LinkSetUp(link netlink.Link) error + AddrAdd(link netlink.Link, addr *netlink.Addr) error + RouteAdd(route *netlink.Route) error + RouteReplace(route *netlink.Route) error +} + +// realNetworkOps is the production implementation using netlink. +type realNetworkOps struct{} + +func (r *realNetworkOps) LinkList() ([]netlink.Link, error) { return netlink.LinkList() } +func (r *realNetworkOps) LinkByName(name string) (netlink.Link, error) { + return netlink.LinkByName(name) +} +func (r *realNetworkOps) LinkSetMTU(link netlink.Link, mtu int) error { + return netlink.LinkSetMTU(link, mtu) +} +func (r *realNetworkOps) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } +func (r *realNetworkOps) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + return netlink.AddrAdd(link, addr) +} + +func (r *realNetworkOps) RouteAdd(route *netlink.Route) error { return netlink.RouteAdd(route) } +func (r *realNetworkOps) RouteReplace(route *netlink.Route) error { return netlink.RouteReplace(route) } + +// DefaultNetworkOps is the global instance used by the package functions. +// It can be replaced during tests. +var DefaultNetworkOps NetworkOps = &realNetworkOps{} + +// NetworkManager coordinates network configuration within a namespace. +type NetworkManager struct { + Ops NetworkOps +} + +// NewNetworkManager creates a new NetworkManager with production defaults. +func NewNetworkManager() *NetworkManager { + return &NetworkManager{ + Ops: DefaultNetworkOps, + } +} + // ListInterfaces returns a list of all network interfaces present in the current namespace. -func ListInterfaces() ([]InterfaceInfo, error) { - links, err := netlink.LinkList() +func (nm *NetworkManager) ListInterfaces() ([]InterfaceInfo, error) { + links, err := nm.Ops.LinkList() if err != nil { return nil, fmt.Errorf("failed to list interfaces: %w", err) } @@ -35,17 +80,17 @@ func ListInterfaces() ([]InterfaceInfo, error) { // ConfigureInterface sets the MTU, brings the interface up, assigns an IP address, // and configures the default route. -func ConfigureInterface(name, address string, mtu int) error { - link, err := netlink.LinkByName(name) +func (nm *NetworkManager) ConfigureInterface(name, address string, mtu int) error { + link, err := nm.Ops.LinkByName(name) if err != nil { return fmt.Errorf("failed to find link %s: %w", name, err) } - if err := netlink.LinkSetMTU(link, mtu); err != nil { + if err := nm.Ops.LinkSetMTU(link, mtu); err != nil { return fmt.Errorf("failed to set MTU %d on link %s: %w", mtu, name, err) } - if err := netlink.LinkSetUp(link); err != nil { + if err := nm.Ops.LinkSetUp(link); err != nil { return fmt.Errorf("failed to bring up link %s: %w", name, err) } @@ -53,7 +98,7 @@ func ConfigureInterface(name, address string, mtu int) error { if err != nil { return fmt.Errorf("invalid IP address %s: %w", address, err) } - if err := netlink.AddrAdd(link, addr); err != nil { + if err := nm.Ops.AddrAdd(link, addr); err != nil { if !strings.Contains(err.Error(), "file exists") { return fmt.Errorf("failed to add address %s to link %s: %w", address, name, err) } @@ -72,11 +117,22 @@ func ConfigureInterface(name, address string, mtu int) error { Dst: dst, } - if err := netlink.RouteAdd(route); err != nil { - if err := netlink.RouteReplace(route); err != nil { + if err := nm.Ops.RouteAdd(route); err != nil { + if err := nm.Ops.RouteReplace(route); err != nil { return fmt.Errorf("failed to configure default route via %s: %w", name, err) } } return nil } + +// ListInterfaces returns a list of all network interfaces present in the current namespace. +func ListInterfaces() ([]InterfaceInfo, error) { + return NewNetworkManager().ListInterfaces() +} + +// ConfigureInterface sets the MTU, brings the interface up, assigns an IP address, +// and configures the default route. +func ConfigureInterface(name, address string, mtu int) error { + return NewNetworkManager().ConfigureInterface(name, address, mtu) +} diff --git a/internal/network/network_test.go b/internal/network/network_test.go new file mode 100644 index 0000000..b598484 --- /dev/null +++ b/internal/network/network_test.go @@ -0,0 +1,163 @@ +//go:build linux + +package network + +import ( + "errors" + "fmt" + "strings" + "testing" + + "github.com/vishvananda/netlink" +) + +// mockNetworkOps allows us to control the behavior of netlink calls. +type mockNetworkOps struct { + linkByNameFunc func(name string) (netlink.Link, error) + linkSetMTUFunc func(link netlink.Link, mtu int) error + linkSetUpFunc func(link netlink.Link) error + addrAddFunc func(link netlink.Link, addr *netlink.Addr) error + routeAddFunc func(route *netlink.Route) error + routeReplaceFunc func(route *netlink.Route) error +} + +func (m *mockNetworkOps) LinkList() ([]netlink.Link, error) { return nil, nil } +func (m *mockNetworkOps) LinkByName(name string) (netlink.Link, error) { + if m.linkByNameFunc != nil { + return m.linkByNameFunc(name) + } + return nil, fmt.Errorf("not implemented") +} +func (m *mockNetworkOps) LinkSetMTU(link netlink.Link, mtu int) error { + if m.linkSetMTUFunc != nil { + return m.linkSetMTUFunc(link, mtu) + } + return nil +} +func (m *mockNetworkOps) LinkSetUp(link netlink.Link) error { + if m.linkSetUpFunc != nil { + return m.linkSetUpFunc(link) + } + return nil +} +func (m *mockNetworkOps) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + if m.addrAddFunc != nil { + return m.addrAddFunc(link, addr) + } + return nil +} +func (m *mockNetworkOps) RouteAdd(route *netlink.Route) error { + if m.routeAddFunc != nil { + return m.routeAddFunc(route) + } + return nil +} +func (m *mockNetworkOps) RouteReplace(route *netlink.Route) error { + if m.routeReplaceFunc != nil { + return m.routeReplaceFunc(route) + } + return nil +} + +// mockLink implements netlink.Link. +type mockLink struct { + name string + idx int +} + +func (m *mockLink) Type() string { + return "mock" +} + +func (m *mockLink) Attrs() *netlink.LinkAttrs { + return &netlink.LinkAttrs{Name: m.name, Index: m.idx} +} + +func TestConfigureInterface_Success(t *testing.T) { + t.Parallel() + mock := &mockNetworkOps{ + linkByNameFunc: func(name string) (netlink.Link, error) { + return &mockLink{name: name, idx: 1}, nil + }, + } + nm := &NetworkManager{Ops: mock} + + err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420) + if err != nil { + t.Errorf("expected success, got %v", err) + } +} + +func TestConfigureInterface_RouteFallback(t *testing.T) { + t.Parallel() + routeAddCalled := false + routeReplaceCalled := false + + mock := &mockNetworkOps{ + linkByNameFunc: func(name string) (netlink.Link, error) { + return &mockLink{name: name, idx: 1}, nil + }, + routeAddFunc: func(route *netlink.Route) error { + routeAddCalled = true + return errors.New("file exists") // Simulate EEXIST + }, + routeReplaceFunc: func(route *netlink.Route) error { + routeReplaceCalled = true + return nil + }, + } + nm := &NetworkManager{Ops: mock} + + err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420) + if err != nil { + t.Errorf("expected success after fallback, got %v", err) + } + if !routeAddCalled { + t.Error("expected RouteAdd to be called first") + } + if !routeReplaceCalled { + t.Error("expected RouteReplace to be called after RouteAdd fails with 'file exists'") + } +} + +func TestConfigureInterface_RouteFailure(t *testing.T) { + t.Parallel() + mock := &mockNetworkOps{ + linkByNameFunc: func(name string) (netlink.Link, error) { + return &mockLink{name: name, idx: 1}, nil + }, + routeAddFunc: func(route *netlink.Route) error { + return errors.New("critical network failure") + }, + routeReplaceFunc: func(route *netlink.Route) error { + return errors.New("critical network failure") + }, + } + nm := &NetworkManager{Ops: mock} + + err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420) + if err == nil { + t.Error("expected error when both RouteAdd and RouteReplace fail, got nil") + } + if !strings.Contains(err.Error(), "failed to configure default route") { + t.Errorf("expected route error, got: %v", err) + } +} + +func TestConfigureInterface_LinkNotFound(t *testing.T) { + t.Parallel() + mock := &mockNetworkOps{ + linkByNameFunc: func(name string) (netlink.Link, error) { + return nil, errors.New("no such device") + }, + } + nm := &NetworkManager{Ops: mock} + + err := nm.ConfigureInterface("nonexistent", "10.0.0.1/24", 1420) + if err == nil { + t.Error("expected error when link is not found, got nil") + } + if !strings.Contains(err.Error(), "failed to find link") { + t.Errorf("expected link not found error, got: %v", err) + } +} diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 8ffe794..67cc211 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -51,8 +51,24 @@ type Tunnel struct { dnsFile string } +// TunnelManager coordinates the creation and management of a WireGuard tunnel. +type TunnelManager struct { + MountOps namespace.MountOps + FS namespace.FileSystem + Net *network.NetworkManager +} + +// NewTunnelManager creates a new TunnelManager with production defaults. +func NewTunnelManager() *TunnelManager { + return &TunnelManager{ + MountOps: namespace.DefaultMountOps, + FS: namespace.DefaultFS, + Net: network.NewNetworkManager(), + } +} + // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. -func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { +func (tm *TunnelManager) StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { var cleanups []func() defer func() { if err != nil { @@ -66,11 +82,11 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS tunName := "tun0" mtu := 1420 - if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { + if err := tm.MountOps.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make mount namespace private: %v\n", err) } - if err := BlockHostServices(pm, profile); err != nil { + if err := tm.BlockHostServices(pm, profile); err != nil { fmt.Printf("warning: failed to block host services: %v\n", err) } @@ -118,18 +134,18 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS } // 4. Configure network interface - if err := network.ConfigureInterface(tunName, cfg.Address, mtu); err != nil { + if err := tm.Net.ConfigureInterface(tunName, cfg.Address, mtu); err != nil { return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } var dnsFile string profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile) - if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil { + if path, err := tm.ConfigureResolvConf(dnsServer, profileDir); err != nil { fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) } else { dnsFile = path cleanups = append(cleanups, func() { - if err := UnmountResolvConf(dnsFile); err != nil { + if err := tm.UnmountResolvConf(dnsFile); err != nil { fmt.Printf("warning: failed to unmount resolv.conf during cleanup: %v\n", err) } }) @@ -143,12 +159,12 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS } // Close shuts down the userspace WireGuard device and closes the TUN interface. -func (t *Tunnel) Close() { +func (t *Tunnel) Close(tm *TunnelManager) { if t.Device != nil { t.Device.Close() } if t.dnsFile != "" { - if err := UnmountResolvConf(t.dnsFile); err != nil { + if err := tm.UnmountResolvConf(t.dnsFile); err != nil { fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err) } } @@ -216,14 +232,12 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { } // ConfigureResolvConf creates a temporary resolv.conf file and bind-mounts it to /etc/resolv.conf. -func ConfigureResolvConf(dns string, profileDir string) (string, error) { +func (tm *TunnelManager) ConfigureResolvConf(dns string, profileDir string) (string, error) { if dns == "" { return "", nil } - // Create the temporary resolv.conf file within the profile directory - // so it can be cleaned up during namespace unpinning. - tmpFile, err := os.CreateTemp(profileDir, "resolvconf") + tmpFile, err := tm.FS.CreateTemp(profileDir, "resolvconf") if err != nil { return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err) } @@ -235,11 +249,11 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) { } _ = tmpFile.Close() - if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { + if err := tm.MountOps.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err) } - if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { + if err := tm.MountOps.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } @@ -247,16 +261,14 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) { } // UnmountResolvConf unmounts the bind-mounted resolv.conf and removes the temporary file. -func UnmountResolvConf(path string) error { +func (tm *TunnelManager) UnmountResolvConf(path string) error { if path == "" { return nil } - // Attempt to unmount. If it fails, it might already be unmounted - // or the namespace might be gone. - _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH) + _ = tm.MountOps.Unmount("/etc/resolv.conf", unix.MNT_DETACH) - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + if err := tm.FS.Remove(path); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err) } return nil @@ -264,18 +276,18 @@ func UnmountResolvConf(path string) error { // BlockHostServices bind-mounts empty files/directories over sensitive host services // to prevent access from within the isolated namespace. -func BlockHostServices(pm *paths.PathManager, profile string) error { +func (tm *TunnelManager) BlockHostServices(pm *paths.PathManager, profile string) error { blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block") - if err := os.MkdirAll(blockDirBase, 0755); err != nil { + if err := tm.FS.MkdirAll(blockDirBase, 0755); err != nil { return fmt.Errorf("failed to create block base directory: %w", err) } - tmpDir, err := os.MkdirTemp(blockDirBase, "dir-") + tmpDir, err := tm.FS.MkdirTemp(blockDirBase, "dir-") if err != nil { return fmt.Errorf("failed to create temp block dir: %w", err) } - tmpFile, err := os.CreateTemp(blockDirBase, "file-") + tmpFile, err := tm.FS.CreateTemp(blockDirBase, "file-") if err != nil { return fmt.Errorf("failed to create temp block file: %w", err) } @@ -283,16 +295,16 @@ func BlockHostServices(pm *paths.PathManager, profile string) error { _ = tmpFile.Close() for _, p := range namespace.GetBlockPaths() { - stat, err := os.Stat(p) + stat, err := tm.FS.Stat(p) if err == nil { source := tmpFileName if stat.IsDir() { source = tmpDir } - if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil { + if err := tm.MountOps.Mount(source, p, "", unix.MS_BIND, ""); err != nil { fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err) } else { - _ = unix.Mount("", p, "", unix.MS_PRIVATE, "") + _ = tm.MountOps.Mount("", p, "", unix.MS_PRIVATE, "") } } } diff --git a/internal/wireguard/wireguard_unit_test.go b/internal/wireguard/wireguard_unit_test.go new file mode 100644 index 0000000..1ad7f65 --- /dev/null +++ b/internal/wireguard/wireguard_unit_test.go @@ -0,0 +1,195 @@ +//go:build linux + +package wireguard + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "git.theodohertyfamily.com/wg-wrap/internal/namespace" + "git.theodohertyfamily.com/wg-wrap/internal/paths" + "golang.org/x/sys/unix" +) + +// mockMountOps records mount calls for verification. +type mockMountOps struct { + mounts []mountCall + unmounts []string +} + +type mountCall struct { + source string + target string + flags uintptr +} + +func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error { + m.mounts = append(m.mounts, mountCall{source, target, flags}) + return nil +} + +func (m *mockMountOps) Unmount(target string, flags int) error { + m.unmounts = append(m.unmounts, target) + return nil +} + +// mockFS mimics a filesystem using a temporary directory. +type mockFS struct { + root string +} + +func (m *mockFS) fullPath(path string) string { + if filepath.IsAbs(path) { + path = strings.TrimPrefix(path, "/") + } + return filepath.Join(m.root, path) +} + +func (m *mockFS) Stat(name string) (os.FileInfo, error) { + return os.Stat(m.fullPath(name)) +} + +func (m *mockFS) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(m.fullPath(path), perm) +} + +func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) { + return os.CreateTemp(m.fullPath(dir), pattern) +} + +func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) { + dirPath := m.fullPath(dir) + res, err := os.MkdirTemp(dirPath, pattern) + if err != nil { + return "", err + } + // Return path relative to root for consistency if needed, but OS calls return absolute. + // The TunnelManager expects a path it can pass to Remove() later. + // If we return absolute, mockFS.Remove must handle it. + return res, nil +} + +func (m *mockFS) Remove(name string) error { + // If the path is absolute and starts with our root, we can remove it directly. + // Otherwise, we use fullPath to ensure it's within root. + if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) { + return os.Remove(name) + } + return os.Remove(m.fullPath(name)) +} + +// setupParallelEnv creates a clean, isolated test environment for a single test. +func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) { + t.Helper() + + tmpDir := t.TempDir() + hostRoot := t.TempDir() + + pm := paths.NewPathManager("", tmpDir) + mockMounts := &mockMountOps{} + mockFS := &mockFS{root: hostRoot} + + tm := &TunnelManager{ + MountOps: mockMounts, + FS: mockFS, + } + + return pm, mockMounts, mockFS, tm +} + +func TestBlockHostServices_Logic(t *testing.T) { + t.Parallel() + + pm, mockMounts, mockFS, tm := setupParallelEnv(t) + profile := "test-profile" + + blockPaths := namespace.GetBlockPaths() + for _, p := range blockPaths { + fullPath := mockFS.fullPath(p) + _ = os.MkdirAll(filepath.Dir(fullPath), 0755) + _ = os.WriteFile(fullPath, []byte("data"), 0644) + } + + err := tm.BlockHostServices(pm, profile) + if err != nil { + t.Fatalf("BlockHostServices failed: %v", err) + } + + for _, p := range blockPaths { + found := false + for _, m := range mockMounts.mounts { + if m.target == p { + found = true + break + } + } + if !found { + t.Errorf("path %s was not blocked", p) + } + } +} + +func TestConfigureResolvConf_Logic(t *testing.T) { + t.Parallel() + + _, mockMounts, mockFS, tm := setupParallelEnv(t) + + // Use a directory relative to mockFS.root + profileDir := "profile-dir" + _ = mockFS.MkdirAll(profileDir, 0755) + dnsServer := "1.1.1.1" + + path, err := tm.ConfigureResolvConf(dnsServer, profileDir) + if err != nil { + t.Fatalf("ConfigureResolvConf failed: %v", err) + } + + // path is returned from os.CreateTemp, which is absolute. + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("resolv.conf file was not created at %s", path) + } + + if len(mockMounts.mounts) < 2 { + t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts)) + } + + if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND { + t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0]) + } + + if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE { + t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1]) + } +} + +func TestUnmountResolvConf_Logic(t *testing.T) { + t.Parallel() + + _, mockMounts, mockFS, tm := setupParallelEnv(t) + + // Create a file within the mock FS root + profileDir := "profile-dir" + _ = mockFS.MkdirAll(profileDir, 0755) + tmpFile := "resolv.conf.tmp" + fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile)) + + err := os.WriteFile(fullPath, []byte("test"), 0644) + if err != nil { + t.Fatalf("failed to write temp resolv.conf: %v", err) + } + + err = tm.UnmountResolvConf(fullPath) + if err != nil { + t.Errorf("UnmountResolvConf failed: %v", err) + } + + if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" { + t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts) + } + + if _, err := os.Stat(fullPath); !os.IsNotExist(err) { + t.Error("temporary resolv.conf file was not removed") + } +} |
