summaryrefslogtreecommitdiff
path: root/internal/namespace
diff options
context:
space:
mode:
authorJames O'Doherty <james@theodohertyfamily.com>2026-06-13 11:51:04 -0400
committerJames O'Doherty <james@theodohertyfamily.com>2026-06-13 11:51:04 -0400
commit29621ecbd1e77e6e1a70b6b3ea8fbe3a56e47df3 (patch)
treefa54976bbb0c4e9db59c983e7cb4e60c5119d18b /internal/namespace
parentf8afb7d5889f5c8b6ea256fd078fa8426d21c7be (diff)
refactor: implement dependency injection and enable parallel testing
This commit refactors the core system operations to use a manager-based dependency injection pattern, eliminating global state and resolving data races in the test suite. Architecture: - Introduced NetworkManager and NetworkOps interface in internal/network to abstract netlink calls. - Introduced MountOps and FileSystem interfaces in internal/namespace to abstract mount and filesystem operations. - Introduced TunnelManager in internal/wireguard to coordinate tunnel lifecycle using the new abstractions. - Updated internal/cli and internal/manager to use these managers. Testing: - Restored t.Parallel() to unit tests in internal/network and internal/wireguard. - Implemented setupParallelEnv and an enhanced mockFS in wireguard_unit_test.go to ensure complete test isolation. - Added bootstrap_test.go to verify launcher preparation logic in internal/namespace without requiring syscall.Exec. - Resolved data races in internal/network tests. CLI: - Added support for -h, --help, and -help flags for the main command. Verification: - Passed all tests (unit, integration, E2E). - Verified zero data races with 'go test -race'. - Passed golangci-lint and go vet.
Diffstat (limited to 'internal/namespace')
-rw-r--r--internal/namespace/bootstrap_test.go111
-rw-r--r--internal/namespace/namespace.go152
2 files changed, 227 insertions, 36 deletions
diff --git a/internal/namespace/bootstrap_test.go b/internal/namespace/bootstrap_test.go
new file mode 100644
index 0000000..eff04e2
--- /dev/null
+++ b/internal/namespace/bootstrap_test.go
@@ -0,0 +1,111 @@
+//go:build linux
+
+package namespace
+
+import (
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestPrepareBootstrap(t *testing.T) {
+ t.Parallel()
+
+ // We are testing the environment and argument preparation.
+ // This will actually open some FDs (memfd, netns, socket),
+ // which is fine for an integration-style unit test.
+ config, err := PrepareBootstrap()
+ if err != nil {
+ t.Fatalf("PrepareBootstrap failed: %v", err)
+ }
+
+ // 1. Verify ExecPath is correct (should be a fd in /proc/self/fd)
+ if !strings.HasPrefix(config.ExecPath, "/proc/self/fd/") {
+ t.Errorf("expected ExecPath to be in /proc/self/fd/, got %s", config.ExecPath)
+ }
+
+ // 2. Verify Arguments are preserved and expanded
+ if len(config.Args) == 0 || config.Args[0] == "" {
+ t.Error("args should not be empty")
+ }
+
+ // 3. Verify Host NetNS FD is present in Env
+ foundNetNs := false
+ for _, env := range config.Env {
+ if strings.HasPrefix(env, "WG_WRAP_HOST_NETNS_FD=") {
+ foundNetNs = true
+ break
+ }
+ }
+ if !foundNetNs {
+ t.Error("WG_WRAP_HOST_NETNS_FD missing from environment")
+ }
+
+ // 4. Verify Host Socket FD is present in Env
+ foundSocket := false
+ for _, env := range config.Env {
+ if strings.HasPrefix(env, "WG_WRAP_HOST_SOCKET_FD=") {
+ foundSocket = true
+ break
+ }
+ }
+ if !foundSocket {
+ t.Error("WG_WRAP_HOST_SOCKET_FD missing from environment")
+ }
+
+ // 5. Verify FDs are tracked for cleanup
+ if len(config.Fds) < 2 {
+ t.Errorf("expected at least 2 FDs (launcher, netns), got %d", len(config.Fds))
+ }
+}
+
+func TestPrepareBootstrapJoin(t *testing.T) {
+ t.Parallel()
+
+ targetPid := 1234
+ config, err := PrepareBootstrapJoin(targetPid)
+ if err != nil {
+ t.Fatalf("PrepareBootstrapJoin failed: %v", err)
+ }
+
+ // 1. Verify Join PID is present in Env
+ foundJoinPid := false
+ expectedPidEnv := "WG_WRAP_JOIN_PID=1234"
+ for _, env := range config.Env {
+ if env == expectedPidEnv {
+ foundJoinPid = true
+ break
+ }
+ }
+ if !foundJoinPid {
+ t.Errorf("expected %s in environment", expectedPidEnv)
+ }
+
+ // 2. Verify Joined flag is present
+ foundJoined := false
+ for _, env := range config.Env {
+ if env == "WG_WRAP_JOINED=1" {
+ foundJoined = true
+ break
+ }
+ }
+ if !foundJoined {
+ t.Error("WG_WRAP_JOINED=1 missing from environment")
+ }
+}
+
+func TestPrepareBootstrap_NullByteValidation(t *testing.T) {
+ // Temporarily inject a null byte into os.Args to test validation
+ // Note: os.Args is a slice, so we can modify it.
+ oldArgs := os.Args
+ defer func() { os.Args = oldArgs }()
+
+ os.Args = append(os.Args, "bad\x00arg")
+
+ _, err := PrepareBootstrap()
+ if err == nil {
+ t.Error("expected error when argument contains null byte, got nil")
+ } else if !strings.Contains(err.Error(), "contains null byte") {
+ t.Errorf("expected null byte error, got: %v", err)
+ }
+}
diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go
index 45eba73..368775f 100644
--- a/internal/namespace/namespace.go
+++ b/internal/namespace/namespace.go
@@ -32,6 +32,66 @@ import (
"golang.org/x/sys/unix"
)
+// MountOps abstracts the filesystem mount operations.
+type MountOps interface {
+ Mount(source, target, fstype string, flags uintptr, data string) error
+ Unmount(target string, flags int) error
+}
+
+// realMountOps is the production implementation using unix.Mount.
+type realMountOps struct{}
+
+func (r *realMountOps) Mount(source, target, fstype string, flags uintptr, data string) error {
+ return unix.Mount(source, target, fstype, flags, data)
+}
+func (r *realMountOps) Unmount(target string, flags int) error {
+ return unix.Unmount(target, flags)
+}
+
+// DefaultMountOps is the global instance used by the package functions.
+var DefaultMountOps MountOps = &realMountOps{}
+
+// FileSystem abstracts the basic filesystem operations used for isolation.
+type FileSystem interface {
+ Stat(name string) (os.FileInfo, error)
+ MkdirAll(path string, perm os.FileMode) error
+ CreateTemp(dir, pattern string) (*os.File, error)
+ MkdirTemp(dir, pattern string) (string, error)
+ Remove(name string) error
+}
+
+// realFS is the production implementation using the os package.
+type realFS struct{}
+
+func (r *realFS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }
+func (r *realFS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(path, perm)
+}
+func (r *realFS) CreateTemp(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(dir, pattern)
+}
+func (r *realFS) MkdirTemp(dir, pattern string) (string, error) {
+ return os.MkdirTemp(dir, pattern)
+}
+func (r *realFS) Remove(name string) error { return os.Remove(name) }
+
+// DefaultFS is the global instance used by the package functions.
+var DefaultFS FileSystem = &realFS{}
+
+// ResetDefaults restores the default implementations of MountOps and FileSystem.
+func ResetDefaults() {
+ DefaultMountOps = &realMountOps{}
+ DefaultFS = &realFS{}
+}
+
+// BootstrapConfig contains the environment and arguments needed to execute the bootstrap launcher.
+type BootstrapConfig struct {
+ Args []string
+ Env []string
+ Fds []int
+ ExecPath string
+}
+
//go:embed launcher.bin
var launcherBytes []byte
@@ -104,48 +164,59 @@ func Bootstrap() (err error) {
return nil
}
- var fdsToClose []int
+ config, err := PrepareBootstrap()
+ if err != nil {
+ return err
+ }
+
defer func() {
- if err != nil {
- for _, fd := range fdsToClose {
- _ = syscall.Close(fd)
- }
+ for _, fd := range config.Fds {
+ _ = syscall.Close(fd)
}
}()
+ err = syscall.Exec(config.ExecPath, config.Args, config.Env)
+ if err != nil {
+ return fmt.Errorf("launcher exec failed: %w", err)
+ }
+
+ return nil
+}
+
+// PrepareBootstrap calculates the environment and arguments needed for the bootstrap launcher.
+func PrepareBootstrap() (*BootstrapConfig, error) {
// 0. Validate current arguments for null bytes before proceeding.
for i, arg := range os.Args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j)
}
}
}
self, err := os.Executable()
if err != nil {
- return fmt.Errorf("failed to get executable path: %w", err)
+ return nil, fmt.Errorf("failed to get executable path: %w", err)
}
execFd, err := prepareLauncher()
if err != nil {
- return err
+ return nil, err
}
- fdsToClose = append(fdsToClose, execFd)
// Clear close-on-exec
if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil {
_, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
}
- // 3. Prepare arguments for the launcher.
+ // Prepare arguments for the launcher.
args := []string{self}
args = append(args, os.Args[1:]...)
for i, arg := range args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
}
}
}
@@ -153,9 +224,8 @@ func Bootstrap() (err error) {
// Open the host network namespace file descriptor before unsharing.
hostNetFd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0)
if err != nil {
- return fmt.Errorf("failed to open host netns: %w", err)
+ return nil, fmt.Errorf("failed to open host netns: %w", err)
}
- fdsToClose = append(fdsToClose, hostNetFd)
// Clear close-on-exec
if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil {
@@ -174,18 +244,17 @@ func Bootstrap() (err error) {
_, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC)
}
env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd))
- fdsToClose = append(fdsToClose, int(hostSocketFd))
_ = conn.Close()
}
}
}
- err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env)
- if err != nil {
- return fmt.Errorf("launcher exec failed: %w", err)
- }
-
- return nil
+ return &BootstrapConfig{
+ Args: args,
+ Env: env,
+ Fds: []int{execFd, hostNetFd},
+ ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd),
+ }, nil
}
// BootstrapJoin joins the namespaces of the target PID and replaces the current process.
@@ -194,33 +263,44 @@ func BootstrapJoin(targetPid int) (err error) {
return nil
}
- var fdsToClose []int
+ config, err := PrepareBootstrapJoin(targetPid)
+ if err != nil {
+ return err
+ }
+
defer func() {
- if err != nil {
- for _, fd := range fdsToClose {
- _ = syscall.Close(fd)
- }
+ for _, fd := range config.Fds {
+ _ = syscall.Close(fd)
}
}()
+ err = syscall.Exec(config.ExecPath, config.Args, config.Env)
+ if err != nil {
+ return fmt.Errorf("launcher exec failed: %w", err)
+ }
+
+ return nil
+}
+
+// PrepareBootstrapJoin calculates the environment and arguments needed to join a namespace.
+func PrepareBootstrapJoin(targetPid int) (*BootstrapConfig, error) {
for i, arg := range os.Args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j)
}
}
}
self, err := os.Executable()
if err != nil {
- return fmt.Errorf("failed to get executable path: %w", err)
+ return nil, fmt.Errorf("failed to get executable path: %w", err)
}
execFd, err := prepareLauncher()
if err != nil {
- return err
+ return nil, err
}
- fdsToClose = append(fdsToClose, execFd)
if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil {
_, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC)
@@ -232,7 +312,7 @@ func BootstrapJoin(targetPid int) (err error) {
for i, arg := range args {
for j := 0; j < len(arg); j++ {
if arg[j] == 0 {
- return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
+ return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j)
}
}
}
@@ -242,12 +322,12 @@ func BootstrapJoin(targetPid int) (err error) {
"WG_WRAP_JOINED=1",
)
- err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env)
- if err != nil {
- return fmt.Errorf("launcher exec failed: %w", err)
- }
-
- return nil
+ return &BootstrapConfig{
+ Args: args,
+ Env: env,
+ Fds: []int{execFd},
+ ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd),
+ }, nil
}
func prepareLauncher() (int, error) {