summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames O'Doherty <james@theodohertyfamily.com>2026-06-13 11:51:04 -0400
committerJames O'Doherty <james@theodohertyfamily.com>2026-06-13 11:51:04 -0400
commit29621ecbd1e77e6e1a70b6b3ea8fbe3a56e47df3 (patch)
treefa54976bbb0c4e9db59c983e7cb4e60c5119d18b
parentf8afb7d5889f5c8b6ea256fd078fa8426d21c7be (diff)
refactor: implement dependency injection and enable parallel testing
This commit refactors the core system operations to use a manager-based dependency injection pattern, eliminating global state and resolving data races in the test suite. Architecture: - Introduced NetworkManager and NetworkOps interface in internal/network to abstract netlink calls. - Introduced MountOps and FileSystem interfaces in internal/namespace to abstract mount and filesystem operations. - Introduced TunnelManager in internal/wireguard to coordinate tunnel lifecycle using the new abstractions. - Updated internal/cli and internal/manager to use these managers. Testing: - Restored t.Parallel() to unit tests in internal/network and internal/wireguard. - Implemented setupParallelEnv and an enhanced mockFS in wireguard_unit_test.go to ensure complete test isolation. - Added bootstrap_test.go to verify launcher preparation logic in internal/namespace without requiring syscall.Exec. - Resolved data races in internal/network tests. CLI: - Added support for -h, --help, and -help flags for the main command. Verification: - Passed all tests (unit, integration, E2E). - Verified zero data races with 'go test -race'. - Passed golangci-lint and go vet.
-rw-r--r--internal/cli/cli.go10
-rw-r--r--internal/manager/manager.go4
-rw-r--r--internal/namespace/bootstrap_test.go111
-rw-r--r--internal/namespace/namespace.go152
-rw-r--r--internal/network/network.go74
-rw-r--r--internal/network/network_test.go163
-rw-r--r--internal/wireguard/wireguard.go64
-rw-r--r--internal/wireguard/wireguard_unit_test.go195
8 files changed, 698 insertions, 75 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 5beb989..b38d0d9 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -56,7 +56,13 @@ func (a *App) Route() error {
return fmt.Errorf("no command provided")
}
- switch a.Args[1] {
+ firstArg := a.Args[1]
+ if firstArg == "-h" || firstArg == "--help" || firstArg == "-help" {
+ a.printUsage()
+ return nil
+ }
+
+ switch firstArg {
case "show-config":
return a.showConfig()
case "profile":
@@ -71,7 +77,7 @@ func (a *App) Route() error {
return a.testLifecycle()
default:
a.printUsage()
- return fmt.Errorf("unknown command: %s", a.Args[1])
+ return fmt.Errorf("unknown command: %s", firstArg)
}
}
diff --git a/internal/manager/manager.go b/internal/manager/manager.go
index ffb02c0..270b99e 100644
--- a/internal/manager/manager.go
+++ b/internal/manager/manager.go
@@ -185,11 +185,11 @@ func (m *Manager) Execute(cfg *config.Config, verbose bool) error {
}
}
- tunnel, err := wireguard.StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer)
+ tunnel, err := wireguard.NewTunnelManager().StartTunnel(m.PM, cfg.Profile, wgCfg, dnsServer)
if err != nil {
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
- defer tunnel.Close()
+ defer tunnel.Close(wireguard.NewTunnelManager())
if err := m.NS.PinNamespace(m.PM, cfg.Profile); err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to pin namespace: %v\n", err)
diff --git a/internal/namespace/bootstrap_test.go b/internal/namespace/bootstrap_test.go
new file mode 100644
index 0000000..eff04e2
--- /dev/null
+++ b/internal/namespace/bootstrap_test.go
@@ -0,0 +1,111 @@
+//go:build linux
+
+package namespace
+
+import (
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestPrepareBootstrap(t *testing.T) {
+ t.Parallel()
+
+ // We are testing the environment and argument preparation.
+ // This will actually open some FDs (memfd, netns, socket),
+ // which is fine for an integration-style unit test.
+ config, err := PrepareBootstrap()
+ if err != nil {
+ t.Fatalf("PrepareBootstrap failed: %v", err)
+ }
+
+ // 1. Verify ExecPath is correct (should be a fd in /proc/self/fd)
+ if !strings.HasPrefix(config.ExecPath, "/proc/self/fd/") {
+ t.Errorf("expected ExecPath to be in /proc/self/fd/, got %s", config.ExecPath)
+ }
+
+ // 2. Verify Arguments are preserved and expanded
+ if len(config.Args) == 0 || config.Args[0] == "" {
+ t.Error("args should not be empty")
+ }
+
+ // 3. Verify Host NetNS FD is present in Env
+ foundNetNs := false
+ for _, env := range config.Env {
+ if strings.HasPrefix(env, "WG_WRAP_HOST_NETNS_FD=") {
+ foundNetNs = true
+ break
+ }
+ }
+ if !foundNetNs {
+ t.Error("WG_WRAP_HOST_NETNS_FD missing from environment")
+ }
+
+ // 4. Verify Host Socket FD is present in Env
+ foundSocket := false
+ for _, env := range config.Env {
+ if strings.HasPrefix(env, "WG_WRAP_HOST_SOCKET_FD=") {
+ foundSocket = true
+ break
+ }
+ }
+ if !foundSocket {
+ t.Error("WG_WRAP_HOST_SOCKET_FD missing from environment")
+ }
+
+ // 5. Verify FDs are tracked for cleanup
+ if len(config.Fds) < 2 {
+ t.Errorf("expected at least 2 FDs (launcher, netns), got %d", len(config.Fds))
+ }
+}
+
+func TestPrepareBootstrapJoin(t *testing.T) {
+ t.Parallel()
+
+ targetPid := 1234
+ config, err := PrepareBootstrapJoin(targetPid)
+ if err != nil {
+ t.Fatalf("PrepareBootstrapJoin failed: %v", err)
+ }
+
+ // 1. Verify Join PID is present in Env
+ foundJoinPid := false
+ expectedPidEnv := "WG_WRAP_JOIN_PID=1234"
+ for _, env := range config.Env {
+ if env == expectedPidEnv {
+ foundJoinPid = true
+ break
+ }
+ }
+ if !foundJoinPid {
+ t.Errorf("expected %s in environment", expectedPidEnv)
+ }
+
+ // 2. Verify Joined flag is present
+ foundJoined := false
+ for _, env := range config.Env {
+ if env == "WG_WRAP_JOINED=1" {
+ foundJoined = true
+ break
+ }
+ }
+ if !foundJoined {
+ t.Error("WG_WRAP_JOINED=1 missing from environment")
+ }
+}
+
+func TestPrepareBootstrap_NullByteValidation(t *testing.T) {
+ // Temporarily inject a null byte into os.Args to test validation
+ // Note: os.Args is a slice, so we can modify it.
+ oldArgs := os.Args
+ defer func() { os.Args = oldArgs }()
+
+ os.Args = append(os.Args, "bad\x00arg")
+
+ _, err := PrepareBootstrap()
+ if err == nil {
+ t.Error("expected error when argument contains null byte, got nil")
+ } else if !strings.Contains(err.Error(), "contains null byte") {
+ t.Errorf("expected null byte error, got: %v", err)
+ }
+}
diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go
index 45eba73..368775f 100644
--- a/internal/namespace/namespace.go
+++ b/internal/namespace/namespace.go
@@ -32,6 +32,66 @@ import (
"golang.org/x/sys/unix"
)
+// MountOps abstracts the filesystem mount operations.
+type MountOps interface {
+ Mount(source, target, fstype string, flags uintptr, data string) error
+ Unmount(target string, flags int) error
+}
+
+// realMountOps is the production implementation using unix.Mount.
+type realMountOps struct{}
+
+func (r *realMountOps) Mount(source, target, fstype string, flags uintptr, data string) error {
+ return unix.Mount(source, target, fstype, flags, data)
+}
+func (r *realMountOps) Unmount(target string, flags int) error {
+ return unix.Unmount(target, flags)
+}
+
+// DefaultMountOps is the global instance used by the package functions.
+var DefaultMountOps MountOps = &realMountOps{}
+
+// FileSystem abstracts the basic filesystem operations used for isolation.
+type FileSystem interface {
+ Stat(name string) (os.FileInfo, error)
+ MkdirAll(path string, perm os.FileMode) error
+ CreateTemp(dir, pattern string) (*os.File, error)
+ MkdirTemp(dir, pattern string) (string, error)
+ Remove(name string) error
+}
+
+// realFS is the production implementation using the os package.
+type realFS struct{}
+
+func (r *realFS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }
+func (r *realFS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(path, perm)
+}
+func (r *realFS) CreateTemp(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(dir, pattern)
+}
+func (r *realFS) MkdirTemp(dir, pattern string) (string, error) {
+ return os.MkdirTemp(dir, pattern)
+}
+func (r *realFS) Remove(name string) error { return os.Remove(name) }
+
+// DefaultFS is the global instance used by the package functions.
+var DefaultFS FileSystem = &realFS{}
+
+// ResetDefaults restores the default implementations of MountOps and FileSystem.
+func ResetDefaults() {
+ DefaultMountOps = &realMountOps{}
+ DefaultFS = &realFS{}
+}
+
+// BootstrapConfig contains the environment and arguments needed to execute the bootstrap launcher.
+type BootstrapConfig struct {
+ Args []string
+ Env []string
+ Fds []int
+ ExecPath string
+}
+
//go:embed launcher.bin
var launcherBytes []byte
@@ -104,48 +164,59 @@ func Bootstrap() (err error) {
return nil
}
- var fdsToClose []int
+ config, err := PrepareBootstrap()
+ if err != nil {
+ return err
+ }
+
defer func() {
- if err != nil {
- for _, fd := range fdsToClose {
- _ = syscall.Close(fd)
- }
+ for _, fd := range config.Fds {
+ _ = syscall.Close(fd)
}
}()
+ err = syscall.Exec(config.ExecPath, config.Args, config.Env)
+ if err != nil {
+ return fmt.Errorf("launcher exec failed: %w", err)
+ }
+
+ return nil
+}
+
+// PrepareBootstrap calculates the environment and arguments needed for the bootstrap launcher.
+func PrepareBootstrap() (*BootstrapConfig, error) {
// 0. Validate current arguments for null bytes before proceeding.
for i, arg := range os.Args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j)
}
}
}
self, err := os.Executable()
if err != nil {
- return fmt.Errorf("failed to get executable path: %w", err)
+ return nil, fmt.Errorf("failed to get executable path: %w", err)
}
execFd, err := prepareLauncher()
if err != nil {
- return err
+ return nil, err
}
- fdsToClose = append(fdsToClose, execFd)
// Clear close-on-exec
if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil {
_, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
}
- // 3. Prepare arguments for the launcher.
+ // Prepare arguments for the launcher.
args := []string{self}
args = append(args, os.Args[1:]...)
for i, arg := range args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
}
}
}
@@ -153,9 +224,8 @@ func Bootstrap() (err error) {
// Open the host network namespace file descriptor before unsharing.
hostNetFd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0)
if err != nil {
- return fmt.Errorf("failed to open host netns: %w", err)
+ return nil, fmt.Errorf("failed to open host netns: %w", err)
}
- fdsToClose = append(fdsToClose, hostNetFd)
// Clear close-on-exec
if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil {
@@ -174,18 +244,17 @@ func Bootstrap() (err error) {
_, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC)
}
env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd))
- fdsToClose = append(fdsToClose, int(hostSocketFd))
_ = conn.Close()
}
}
}
- err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env)
- if err != nil {
- return fmt.Errorf("launcher exec failed: %w", err)
- }
-
- return nil
+ return &BootstrapConfig{
+ Args: args,
+ Env: env,
+ Fds: []int{execFd, hostNetFd},
+ ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd),
+ }, nil
}
// BootstrapJoin joins the namespaces of the target PID and replaces the current process.
@@ -194,33 +263,44 @@ func BootstrapJoin(targetPid int) (err error) {
return nil
}
- var fdsToClose []int
+ config, err := PrepareBootstrapJoin(targetPid)
+ if err != nil {
+ return err
+ }
+
defer func() {
- if err != nil {
- for _, fd := range fdsToClose {
- _ = syscall.Close(fd)
- }
+ for _, fd := range config.Fds {
+ _ = syscall.Close(fd)
}
}()
+ err = syscall.Exec(config.ExecPath, config.Args, config.Env)
+ if err != nil {
+ return fmt.Errorf("launcher exec failed: %w", err)
+ }
+
+ return nil
+}
+
+// PrepareBootstrapJoin calculates the environment and arguments needed to join a namespace.
+func PrepareBootstrapJoin(targetPid int) (*BootstrapConfig, error) {
for i, arg := range os.Args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j)
}
}
}
self, err := os.Executable()
if err != nil {
- return fmt.Errorf("failed to get executable path: %w", err)
+ return nil, fmt.Errorf("failed to get executable path: %w", err)
}
execFd, err := prepareLauncher()
if err != nil {
- return err
+ return nil, err
}
- fdsToClose = append(fdsToClose, execFd)
if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil {
_, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
@@ -232,7 +312,7 @@ func BootstrapJoin(targetPid int) (err error) {
for i, arg := range args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
}
}
}
@@ -242,12 +322,12 @@ func BootstrapJoin(targetPid int) (err error) {
"WG_WRAP_JOINED=1",
)
- err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env)
- if err != nil {
- return fmt.Errorf("launcher exec failed: %w", err)
- }
-
- return nil
+ return &BootstrapConfig{
+ Args: args,
+ Env: env,
+ Fds: []int{execFd},
+ ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd),
+ }, nil
}
func prepareLauncher() (int, error) {
diff --git a/internal/network/network.go b/internal/network/network.go
index 6afcf5e..e9dce77 100644
--- a/internal/network/network.go
+++ b/internal/network/network.go
@@ -16,9 +16,54 @@ type InterfaceInfo struct {
Index int
}
+// NetworkOps abstracts the low-level netlink operations.
+type NetworkOps interface {
+ LinkList() ([]netlink.Link, error)
+ LinkByName(name string) (netlink.Link, error)
+ LinkSetMTU(link netlink.Link, mtu int) error
+ LinkSetUp(link netlink.Link) error
+ AddrAdd(link netlink.Link, addr *netlink.Addr) error
+ RouteAdd(route *netlink.Route) error
+ RouteReplace(route *netlink.Route) error
+}
+
+// realNetworkOps is the production implementation using netlink.
+type realNetworkOps struct{}
+
+func (r *realNetworkOps) LinkList() ([]netlink.Link, error) { return netlink.LinkList() }
+func (r *realNetworkOps) LinkByName(name string) (netlink.Link, error) {
+ return netlink.LinkByName(name)
+}
+func (r *realNetworkOps) LinkSetMTU(link netlink.Link, mtu int) error {
+ return netlink.LinkSetMTU(link, mtu)
+}
+func (r *realNetworkOps) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) }
+func (r *realNetworkOps) AddrAdd(link netlink.Link, addr *netlink.Addr) error {
+ return netlink.AddrAdd(link, addr)
+}
+
+func (r *realNetworkOps) RouteAdd(route *netlink.Route) error { return netlink.RouteAdd(route) }
+func (r *realNetworkOps) RouteReplace(route *netlink.Route) error { return netlink.RouteReplace(route) }
+
+// DefaultNetworkOps is the global instance used by the package functions.
+// It can be replaced during tests.
+var DefaultNetworkOps NetworkOps = &realNetworkOps{}
+
+// NetworkManager coordinates network configuration within a namespace.
+type NetworkManager struct {
+ Ops NetworkOps
+}
+
+// NewNetworkManager creates a new NetworkManager with production defaults.
+func NewNetworkManager() *NetworkManager {
+ return &NetworkManager{
+ Ops: DefaultNetworkOps,
+ }
+}
+
// ListInterfaces returns a list of all network interfaces present in the current namespace.
-func ListInterfaces() ([]InterfaceInfo, error) {
- links, err := netlink.LinkList()
+func (nm *NetworkManager) ListInterfaces() ([]InterfaceInfo, error) {
+ links, err := nm.Ops.LinkList()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %w", err)
}
@@ -35,17 +80,17 @@ func ListInterfaces() ([]InterfaceInfo, error) {
// ConfigureInterface sets the MTU, brings the interface up, assigns an IP address,
// and configures the default route.
-func ConfigureInterface(name, address string, mtu int) error {
- link, err := netlink.LinkByName(name)
+func (nm *NetworkManager) ConfigureInterface(name, address string, mtu int) error {
+ link, err := nm.Ops.LinkByName(name)
if err != nil {
return fmt.Errorf("failed to find link %s: %w", name, err)
}
- if err := netlink.LinkSetMTU(link, mtu); err != nil {
+ if err := nm.Ops.LinkSetMTU(link, mtu); err != nil {
return fmt.Errorf("failed to set MTU %d on link %s: %w", mtu, name, err)
}
- if err := netlink.LinkSetUp(link); err != nil {
+ if err := nm.Ops.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up link %s: %w", name, err)
}
@@ -53,7 +98,7 @@ func ConfigureInterface(name, address string, mtu int) error {
if err != nil {
return fmt.Errorf("invalid IP address %s: %w", address, err)
}
- if err := netlink.AddrAdd(link, addr); err != nil {
+ if err := nm.Ops.AddrAdd(link, addr); err != nil {
if !strings.Contains(err.Error(), "file exists") {
return fmt.Errorf("failed to add address %s to link %s: %w", address, name, err)
}
@@ -72,11 +117,22 @@ func ConfigureInterface(name, address string, mtu int) error {
Dst: dst,
}
- if err := netlink.RouteAdd(route); err != nil {
- if err := netlink.RouteReplace(route); err != nil {
+ if err := nm.Ops.RouteAdd(route); err != nil {
+ if err := nm.Ops.RouteReplace(route); err != nil {
return fmt.Errorf("failed to configure default route via %s: %w", name, err)
}
}
return nil
}
+
+// ListInterfaces returns a list of all network interfaces present in the current namespace.
+func ListInterfaces() ([]InterfaceInfo, error) {
+ return NewNetworkManager().ListInterfaces()
+}
+
+// ConfigureInterface sets the MTU, brings the interface up, assigns an IP address,
+// and configures the default route.
+func ConfigureInterface(name, address string, mtu int) error {
+ return NewNetworkManager().ConfigureInterface(name, address, mtu)
+}
diff --git a/internal/network/network_test.go b/internal/network/network_test.go
new file mode 100644
index 0000000..b598484
--- /dev/null
+++ b/internal/network/network_test.go
@@ -0,0 +1,163 @@
+//go:build linux
+
+package network
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/vishvananda/netlink"
+)
+
+// mockNetworkOps allows us to control the behavior of netlink calls.
+type mockNetworkOps struct {
+ linkByNameFunc func(name string) (netlink.Link, error)
+ linkSetMTUFunc func(link netlink.Link, mtu int) error
+ linkSetUpFunc func(link netlink.Link) error
+ addrAddFunc func(link netlink.Link, addr *netlink.Addr) error
+ routeAddFunc func(route *netlink.Route) error
+ routeReplaceFunc func(route *netlink.Route) error
+}
+
+func (m *mockNetworkOps) LinkList() ([]netlink.Link, error) { return nil, nil }
+func (m *mockNetworkOps) LinkByName(name string) (netlink.Link, error) {
+ if m.linkByNameFunc != nil {
+ return m.linkByNameFunc(name)
+ }
+ return nil, fmt.Errorf("not implemented")
+}
+func (m *mockNetworkOps) LinkSetMTU(link netlink.Link, mtu int) error {
+ if m.linkSetMTUFunc != nil {
+ return m.linkSetMTUFunc(link, mtu)
+ }
+ return nil
+}
+func (m *mockNetworkOps) LinkSetUp(link netlink.Link) error {
+ if m.linkSetUpFunc != nil {
+ return m.linkSetUpFunc(link)
+ }
+ return nil
+}
+func (m *mockNetworkOps) AddrAdd(link netlink.Link, addr *netlink.Addr) error {
+ if m.addrAddFunc != nil {
+ return m.addrAddFunc(link, addr)
+ }
+ return nil
+}
+func (m *mockNetworkOps) RouteAdd(route *netlink.Route) error {
+ if m.routeAddFunc != nil {
+ return m.routeAddFunc(route)
+ }
+ return nil
+}
+func (m *mockNetworkOps) RouteReplace(route *netlink.Route) error {
+ if m.routeReplaceFunc != nil {
+ return m.routeReplaceFunc(route)
+ }
+ return nil
+}
+
+// mockLink implements netlink.Link.
+type mockLink struct {
+ name string
+ idx int
+}
+
+func (m *mockLink) Type() string {
+ return "mock"
+}
+
+func (m *mockLink) Attrs() *netlink.LinkAttrs {
+ return &netlink.LinkAttrs{Name: m.name, Index: m.idx}
+}
+
+func TestConfigureInterface_Success(t *testing.T) {
+ t.Parallel()
+ mock := &mockNetworkOps{
+ linkByNameFunc: func(name string) (netlink.Link, error) {
+ return &mockLink{name: name, idx: 1}, nil
+ },
+ }
+ nm := &NetworkManager{Ops: mock}
+
+ err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420)
+ if err != nil {
+ t.Errorf("expected success, got %v", err)
+ }
+}
+
+func TestConfigureInterface_RouteFallback(t *testing.T) {
+ t.Parallel()
+ routeAddCalled := false
+ routeReplaceCalled := false
+
+ mock := &mockNetworkOps{
+ linkByNameFunc: func(name string) (netlink.Link, error) {
+ return &mockLink{name: name, idx: 1}, nil
+ },
+ routeAddFunc: func(route *netlink.Route) error {
+ routeAddCalled = true
+ return errors.New("file exists") // Simulate EEXIST
+ },
+ routeReplaceFunc: func(route *netlink.Route) error {
+ routeReplaceCalled = true
+ return nil
+ },
+ }
+ nm := &NetworkManager{Ops: mock}
+
+ err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420)
+ if err != nil {
+ t.Errorf("expected success after fallback, got %v", err)
+ }
+ if !routeAddCalled {
+ t.Error("expected RouteAdd to be called first")
+ }
+ if !routeReplaceCalled {
+ t.Error("expected RouteReplace to be called after RouteAdd fails with 'file exists'")
+ }
+}
+
+func TestConfigureInterface_RouteFailure(t *testing.T) {
+ t.Parallel()
+ mock := &mockNetworkOps{
+ linkByNameFunc: func(name string) (netlink.Link, error) {
+ return &mockLink{name: name, idx: 1}, nil
+ },
+ routeAddFunc: func(route *netlink.Route) error {
+ return errors.New("critical network failure")
+ },
+ routeReplaceFunc: func(route *netlink.Route) error {
+ return errors.New("critical network failure")
+ },
+ }
+ nm := &NetworkManager{Ops: mock}
+
+ err := nm.ConfigureInterface("tun0", "10.0.0.1/24", 1420)
+ if err == nil {
+ t.Error("expected error when both RouteAdd and RouteReplace fail, got nil")
+ }
+ if !strings.Contains(err.Error(), "failed to configure default route") {
+ t.Errorf("expected route error, got: %v", err)
+ }
+}
+
+func TestConfigureInterface_LinkNotFound(t *testing.T) {
+ t.Parallel()
+ mock := &mockNetworkOps{
+ linkByNameFunc: func(name string) (netlink.Link, error) {
+ return nil, errors.New("no such device")
+ },
+ }
+ nm := &NetworkManager{Ops: mock}
+
+ err := nm.ConfigureInterface("nonexistent", "10.0.0.1/24", 1420)
+ if err == nil {
+ t.Error("expected error when link is not found, got nil")
+ }
+ if !strings.Contains(err.Error(), "failed to find link") {
+ t.Errorf("expected link not found error, got: %v", err)
+ }
+}
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 8ffe794..67cc211 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -51,8 +51,24 @@ type Tunnel struct {
dnsFile string
}
+// TunnelManager coordinates the creation and management of a WireGuard tunnel.
+type TunnelManager struct {
+ MountOps namespace.MountOps
+ FS namespace.FileSystem
+ Net *network.NetworkManager
+}
+
+// NewTunnelManager creates a new TunnelManager with production defaults.
+func NewTunnelManager() *TunnelManager {
+ return &TunnelManager{
+ MountOps: namespace.DefaultMountOps,
+ FS: namespace.DefaultFS,
+ Net: network.NewNetworkManager(),
+ }
+}
+
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
+func (tm *TunnelManager) StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
var cleanups []func()
defer func() {
if err != nil {
@@ -66,11 +82,11 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
tunName := "tun0"
mtu := 1420
- if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
+ if err := tm.MountOps.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
}
- if err := BlockHostServices(pm, profile); err != nil {
+ if err := tm.BlockHostServices(pm, profile); err != nil {
fmt.Printf("warning: failed to block host services: %v\n", err)
}
@@ -118,18 +134,18 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
}
// 4. Configure network interface
- if err := network.ConfigureInterface(tunName, cfg.Address, mtu); err != nil {
+ if err := tm.Net.ConfigureInterface(tunName, cfg.Address, mtu); err != nil {
return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
}
var dnsFile string
profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile)
- if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil {
+ if path, err := tm.ConfigureResolvConf(dnsServer, profileDir); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
} else {
dnsFile = path
cleanups = append(cleanups, func() {
- if err := UnmountResolvConf(dnsFile); err != nil {
+ if err := tm.UnmountResolvConf(dnsFile); err != nil {
fmt.Printf("warning: failed to unmount resolv.conf during cleanup: %v\n", err)
}
})
@@ -143,12 +159,12 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
}
// Close shuts down the userspace WireGuard device and closes the TUN interface.
-func (t *Tunnel) Close() {
+func (t *Tunnel) Close(tm *TunnelManager) {
if t.Device != nil {
t.Device.Close()
}
if t.dnsFile != "" {
- if err := UnmountResolvConf(t.dnsFile); err != nil {
+ if err := tm.UnmountResolvConf(t.dnsFile); err != nil {
fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err)
}
}
@@ -216,14 +232,12 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
}
// ConfigureResolvConf creates a temporary resolv.conf file and bind-mounts it to /etc/resolv.conf.
-func ConfigureResolvConf(dns string, profileDir string) (string, error) {
+func (tm *TunnelManager) ConfigureResolvConf(dns string, profileDir string) (string, error) {
if dns == "" {
return "", nil
}
- // Create the temporary resolv.conf file within the profile directory
- // so it can be cleaned up during namespace unpinning.
- tmpFile, err := os.CreateTemp(profileDir, "resolvconf")
+ tmpFile, err := tm.FS.CreateTemp(profileDir, "resolvconf")
if err != nil {
return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err)
}
@@ -235,11 +249,11 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) {
}
_ = tmpFile.Close()
- if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
+ if err := tm.MountOps.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err)
}
- if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil {
+ if err := tm.MountOps.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil {
fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err)
}
@@ -247,16 +261,14 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) {
}
// UnmountResolvConf unmounts the bind-mounted resolv.conf and removes the temporary file.
-func UnmountResolvConf(path string) error {
+func (tm *TunnelManager) UnmountResolvConf(path string) error {
if path == "" {
return nil
}
- // Attempt to unmount. If it fails, it might already be unmounted
- // or the namespace might be gone.
- _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
+ _ = tm.MountOps.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
- if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+ if err := tm.FS.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
}
return nil
@@ -264,18 +276,18 @@ func UnmountResolvConf(path string) error {
// BlockHostServices bind-mounts empty files/directories over sensitive host services
// to prevent access from within the isolated namespace.
-func BlockHostServices(pm *paths.PathManager, profile string) error {
+func (tm *TunnelManager) BlockHostServices(pm *paths.PathManager, profile string) error {
blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
- if err := os.MkdirAll(blockDirBase, 0755); err != nil {
+ if err := tm.FS.MkdirAll(blockDirBase, 0755); err != nil {
return fmt.Errorf("failed to create block base directory: %w", err)
}
- tmpDir, err := os.MkdirTemp(blockDirBase, "dir-")
+ tmpDir, err := tm.FS.MkdirTemp(blockDirBase, "dir-")
if err != nil {
return fmt.Errorf("failed to create temp block dir: %w", err)
}
- tmpFile, err := os.CreateTemp(blockDirBase, "file-")
+ tmpFile, err := tm.FS.CreateTemp(blockDirBase, "file-")
if err != nil {
return fmt.Errorf("failed to create temp block file: %w", err)
}
@@ -283,16 +295,16 @@ func BlockHostServices(pm *paths.PathManager, profile string) error {
_ = tmpFile.Close()
for _, p := range namespace.GetBlockPaths() {
- stat, err := os.Stat(p)
+ stat, err := tm.FS.Stat(p)
if err == nil {
source := tmpFileName
if stat.IsDir() {
source = tmpDir
}
- if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil {
+ if err := tm.MountOps.Mount(source, p, "", unix.MS_BIND, ""); err != nil {
fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err)
} else {
- _ = unix.Mount("", p, "", unix.MS_PRIVATE, "")
+ _ = tm.MountOps.Mount("", p, "", unix.MS_PRIVATE, "")
}
}
}
diff --git a/internal/wireguard/wireguard_unit_test.go b/internal/wireguard/wireguard_unit_test.go
new file mode 100644
index 0000000..1ad7f65
--- /dev/null
+++ b/internal/wireguard/wireguard_unit_test.go
@@ -0,0 +1,195 @@
+//go:build linux
+
+package wireguard
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
+)
+
+// mockMountOps records mount calls for verification.
+type mockMountOps struct {
+ mounts []mountCall
+ unmounts []string
+}
+
+type mountCall struct {
+ source string
+ target string
+ flags uintptr
+}
+
+func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error {
+ m.mounts = append(m.mounts, mountCall{source, target, flags})
+ return nil
+}
+
+func (m *mockMountOps) Unmount(target string, flags int) error {
+ m.unmounts = append(m.unmounts, target)
+ return nil
+}
+
+// mockFS mimics a filesystem using a temporary directory.
+type mockFS struct {
+ root string
+}
+
+func (m *mockFS) fullPath(path string) string {
+ if filepath.IsAbs(path) {
+ path = strings.TrimPrefix(path, "/")
+ }
+ return filepath.Join(m.root, path)
+}
+
+func (m *mockFS) Stat(name string) (os.FileInfo, error) {
+ return os.Stat(m.fullPath(name))
+}
+
+func (m *mockFS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(m.fullPath(path), perm)
+}
+
+func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(m.fullPath(dir), pattern)
+}
+
+func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) {
+ dirPath := m.fullPath(dir)
+ res, err := os.MkdirTemp(dirPath, pattern)
+ if err != nil {
+ return "", err
+ }
+ // Return path relative to root for consistency if needed, but OS calls return absolute.
+ // The TunnelManager expects a path it can pass to Remove() later.
+ // If we return absolute, mockFS.Remove must handle it.
+ return res, nil
+}
+
+func (m *mockFS) Remove(name string) error {
+ // If the path is absolute and starts with our root, we can remove it directly.
+ // Otherwise, we use fullPath to ensure it's within root.
+ if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) {
+ return os.Remove(name)
+ }
+ return os.Remove(m.fullPath(name))
+}
+
+// setupParallelEnv creates a clean, isolated test environment for a single test.
+func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) {
+ t.Helper()
+
+ tmpDir := t.TempDir()
+ hostRoot := t.TempDir()
+
+ pm := paths.NewPathManager("", tmpDir)
+ mockMounts := &mockMountOps{}
+ mockFS := &mockFS{root: hostRoot}
+
+ tm := &TunnelManager{
+ MountOps: mockMounts,
+ FS: mockFS,
+ }
+
+ return pm, mockMounts, mockFS, tm
+}
+
+func TestBlockHostServices_Logic(t *testing.T) {
+ t.Parallel()
+
+ pm, mockMounts, mockFS, tm := setupParallelEnv(t)
+ profile := "test-profile"
+
+ blockPaths := namespace.GetBlockPaths()
+ for _, p := range blockPaths {
+ fullPath := mockFS.fullPath(p)
+ _ = os.MkdirAll(filepath.Dir(fullPath), 0755)
+ _ = os.WriteFile(fullPath, []byte("data"), 0644)
+ }
+
+ err := tm.BlockHostServices(pm, profile)
+ if err != nil {
+ t.Fatalf("BlockHostServices failed: %v", err)
+ }
+
+ for _, p := range blockPaths {
+ found := false
+ for _, m := range mockMounts.mounts {
+ if m.target == p {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("path %s was not blocked", p)
+ }
+ }
+}
+
+func TestConfigureResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Use a directory relative to mockFS.root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ dnsServer := "1.1.1.1"
+
+ path, err := tm.ConfigureResolvConf(dnsServer, profileDir)
+ if err != nil {
+ t.Fatalf("ConfigureResolvConf failed: %v", err)
+ }
+
+ // path is returned from os.CreateTemp, which is absolute.
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ t.Errorf("resolv.conf file was not created at %s", path)
+ }
+
+ if len(mockMounts.mounts) < 2 {
+ t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts))
+ }
+
+ if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND {
+ t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0])
+ }
+
+ if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE {
+ t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1])
+ }
+}
+
+func TestUnmountResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Create a file within the mock FS root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ tmpFile := "resolv.conf.tmp"
+ fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile))
+
+ err := os.WriteFile(fullPath, []byte("test"), 0644)
+ if err != nil {
+ t.Fatalf("failed to write temp resolv.conf: %v", err)
+ }
+
+ err = tm.UnmountResolvConf(fullPath)
+ if err != nil {
+ t.Errorf("UnmountResolvConf failed: %v", err)
+ }
+
+ if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" {
+ t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts)
+ }
+
+ if _, err := os.Stat(fullPath); !os.IsNotExist(err) {
+ t.Error("temporary resolv.conf file was not removed")
+ }
+}