diff options
Diffstat (limited to 'internal/namespace')
| -rw-r--r-- | internal/namespace/bootstrap_test.go | 111 | ||||
| -rw-r--r-- | internal/namespace/namespace.go | 152 |
2 files changed, 227 insertions, 36 deletions
diff --git a/internal/namespace/bootstrap_test.go b/internal/namespace/bootstrap_test.go new file mode 100644 index 0000000..eff04e2 --- /dev/null +++ b/internal/namespace/bootstrap_test.go @@ -0,0 +1,111 @@ +//go:build linux + +package namespace + +import ( + "os" + "strings" + "testing" +) + +func TestPrepareBootstrap(t *testing.T) { + t.Parallel() + + // We are testing the environment and argument preparation. + // This will actually open some FDs (memfd, netns, socket), + // which is fine for an integration-style unit test. + config, err := PrepareBootstrap() + if err != nil { + t.Fatalf("PrepareBootstrap failed: %v", err) + } + + // 1. Verify ExecPath is correct (should be a fd in /proc/self/fd) + if !strings.HasPrefix(config.ExecPath, "/proc/self/fd/") { + t.Errorf("expected ExecPath to be in /proc/self/fd/, got %s", config.ExecPath) + } + + // 2. Verify Arguments are preserved and expanded + if len(config.Args) == 0 || config.Args[0] == "" { + t.Error("args should not be empty") + } + + // 3. Verify Host NetNS FD is present in Env + foundNetNs := false + for _, env := range config.Env { + if strings.HasPrefix(env, "WG_WRAP_HOST_NETNS_FD=") { + foundNetNs = true + break + } + } + if !foundNetNs { + t.Error("WG_WRAP_HOST_NETNS_FD missing from environment") + } + + // 4. Verify Host Socket FD is present in Env + foundSocket := false + for _, env := range config.Env { + if strings.HasPrefix(env, "WG_WRAP_HOST_SOCKET_FD=") { + foundSocket = true + break + } + } + if !foundSocket { + t.Error("WG_WRAP_HOST_SOCKET_FD missing from environment") + } + + // 5. Verify FDs are tracked for cleanup + if len(config.Fds) < 2 { + t.Errorf("expected at least 2 FDs (launcher, netns), got %d", len(config.Fds)) + } +} + +func TestPrepareBootstrapJoin(t *testing.T) { + t.Parallel() + + targetPid := 1234 + config, err := PrepareBootstrapJoin(targetPid) + if err != nil { + t.Fatalf("PrepareBootstrapJoin failed: %v", err) + } + + // 1. Verify Join PID is present in Env + foundJoinPid := false + expectedPidEnv := "WG_WRAP_JOIN_PID=1234" + for _, env := range config.Env { + if env == expectedPidEnv { + foundJoinPid = true + break + } + } + if !foundJoinPid { + t.Errorf("expected %s in environment", expectedPidEnv) + } + + // 2. Verify Joined flag is present + foundJoined := false + for _, env := range config.Env { + if env == "WG_WRAP_JOINED=1" { + foundJoined = true + break + } + } + if !foundJoined { + t.Error("WG_WRAP_JOINED=1 missing from environment") + } +} + +func TestPrepareBootstrap_NullByteValidation(t *testing.T) { + // Temporarily inject a null byte into os.Args to test validation + // Note: os.Args is a slice, so we can modify it. + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = append(os.Args, "bad\x00arg") + + _, err := PrepareBootstrap() + if err == nil { + t.Error("expected error when argument contains null byte, got nil") + } else if !strings.Contains(err.Error(), "contains null byte") { + t.Errorf("expected null byte error, got: %v", err) + } +} diff --git a/internal/namespace/namespace.go b/internal/namespace/namespace.go index 45eba73..368775f 100644 --- a/internal/namespace/namespace.go +++ b/internal/namespace/namespace.go @@ -32,6 +32,66 @@ import ( "golang.org/x/sys/unix" ) +// MountOps abstracts the filesystem mount operations. +type MountOps interface { + Mount(source, target, fstype string, flags uintptr, data string) error + Unmount(target string, flags int) error +} + +// realMountOps is the production implementation using unix.Mount. +type realMountOps struct{} + +func (r *realMountOps) Mount(source, target, fstype string, flags uintptr, data string) error { + return unix.Mount(source, target, fstype, flags, data) +} +func (r *realMountOps) Unmount(target string, flags int) error { + return unix.Unmount(target, flags) +} + +// DefaultMountOps is the global instance used by the package functions. +var DefaultMountOps MountOps = &realMountOps{} + +// FileSystem abstracts the basic filesystem operations used for isolation. +type FileSystem interface { + Stat(name string) (os.FileInfo, error) + MkdirAll(path string, perm os.FileMode) error + CreateTemp(dir, pattern string) (*os.File, error) + MkdirTemp(dir, pattern string) (string, error) + Remove(name string) error +} + +// realFS is the production implementation using the os package. +type realFS struct{} + +func (r *realFS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) } +func (r *realFS) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} +func (r *realFS) CreateTemp(dir, pattern string) (*os.File, error) { + return os.CreateTemp(dir, pattern) +} +func (r *realFS) MkdirTemp(dir, pattern string) (string, error) { + return os.MkdirTemp(dir, pattern) +} +func (r *realFS) Remove(name string) error { return os.Remove(name) } + +// DefaultFS is the global instance used by the package functions. +var DefaultFS FileSystem = &realFS{} + +// ResetDefaults restores the default implementations of MountOps and FileSystem. +func ResetDefaults() { + DefaultMountOps = &realMountOps{} + DefaultFS = &realFS{} +} + +// BootstrapConfig contains the environment and arguments needed to execute the bootstrap launcher. +type BootstrapConfig struct { + Args []string + Env []string + Fds []int + ExecPath string +} + //go:embed launcher.bin var launcherBytes []byte @@ -104,48 +164,59 @@ func Bootstrap() (err error) { return nil } - var fdsToClose []int + config, err := PrepareBootstrap() + if err != nil { + return err + } + defer func() { - if err != nil { - for _, fd := range fdsToClose { - _ = syscall.Close(fd) - } + for _, fd := range config.Fds { + _ = syscall.Close(fd) } }() + err = syscall.Exec(config.ExecPath, config.Args, config.Env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} + +// PrepareBootstrap calculates the environment and arguments needed for the bootstrap launcher. +func PrepareBootstrap() (*BootstrapConfig, error) { // 0. Validate current arguments for null bytes before proceeding. for i, arg := range os.Args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j) } } } self, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) + return nil, fmt.Errorf("failed to get executable path: %w", err) } execFd, err := prepareLauncher() if err != nil { - return err + return nil, err } - fdsToClose = append(fdsToClose, execFd) // Clear close-on-exec if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) } - // 3. Prepare arguments for the launcher. + // Prepare arguments for the launcher. args := []string{self} args = append(args, os.Args[1:]...) for i, arg := range args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) } } } @@ -153,9 +224,8 @@ func Bootstrap() (err error) { // Open the host network namespace file descriptor before unsharing. hostNetFd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0) if err != nil { - return fmt.Errorf("failed to open host netns: %w", err) + return nil, fmt.Errorf("failed to open host netns: %w", err) } - fdsToClose = append(fdsToClose, hostNetFd) // Clear close-on-exec if flags, err := unix.FcntlInt(uintptr(hostNetFd), unix.F_GETFD, 0); err == nil { @@ -174,18 +244,17 @@ func Bootstrap() (err error) { _, _ = unix.FcntlInt(hostSocketFd, unix.F_SETFD, flags&^unix.FD_CLOEXEC) } env = append(env, fmt.Sprintf("WG_WRAP_HOST_SOCKET_FD=%d", hostSocketFd)) - fdsToClose = append(fdsToClose, int(hostSocketFd)) _ = conn.Close() } } } - err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) - if err != nil { - return fmt.Errorf("launcher exec failed: %w", err) - } - - return nil + return &BootstrapConfig{ + Args: args, + Env: env, + Fds: []int{execFd, hostNetFd}, + ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd), + }, nil } // BootstrapJoin joins the namespaces of the target PID and replaces the current process. @@ -194,33 +263,44 @@ func BootstrapJoin(targetPid int) (err error) { return nil } - var fdsToClose []int + config, err := PrepareBootstrapJoin(targetPid) + if err != nil { + return err + } + defer func() { - if err != nil { - for _, fd := range fdsToClose { - _ = syscall.Close(fd) - } + for _, fd := range config.Fds { + _ = syscall.Close(fd) } }() + err = syscall.Exec(config.ExecPath, config.Args, config.Env) + if err != nil { + return fmt.Errorf("launcher exec failed: %w", err) + } + + return nil +} + +// PrepareBootstrapJoin calculates the environment and arguments needed to join a namespace. +func PrepareBootstrapJoin(targetPid int) (*BootstrapConfig, error) { for i, arg := range os.Args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("argument %d contains null byte at position %d", i, j) } } } self, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) + return nil, fmt.Errorf("failed to get executable path: %w", err) } execFd, err := prepareLauncher() if err != nil { - return err + return nil, err } - fdsToClose = append(fdsToClose, execFd) if flags, err := unix.FcntlInt(uintptr(execFd), unix.F_GETFD, 0); err == nil { _, _ = unix.FcntlInt(uintptr(execFd), unix.F_SETFD, flags&^unix.FD_CLOEXEC) @@ -232,7 +312,7 @@ func BootstrapJoin(targetPid int) (err error) { for i, arg := range args { for j := 0; j < len(arg); j++ { if arg[j] == 0 { - return fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) + return nil, fmt.Errorf("launcher argument %d contains null byte at position %d", i, j) } } } @@ -242,12 +322,12 @@ func BootstrapJoin(targetPid int) (err error) { "WG_WRAP_JOINED=1", ) - err = syscall.Exec(fmt.Sprintf("/proc/self/fd/%d", execFd), args, env) - if err != nil { - return fmt.Errorf("launcher exec failed: %w", err) - } - - return nil + return &BootstrapConfig{ + Args: args, + Env: env, + Fds: []int{execFd}, + ExecPath: fmt.Sprintf("/proc/self/fd/%d", execFd), + }, nil } func prepareLauncher() (int, error) { |
