summaryrefslogtreecommitdiff
path: root/internal/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'internal/wireguard')
-rw-r--r--internal/wireguard/wireguard.go64
-rw-r--r--internal/wireguard/wireguard_unit_test.go195
2 files changed, 233 insertions, 26 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 8ffe794..67cc211 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -51,8 +51,24 @@ type Tunnel struct {
dnsFile string
}
+// TunnelManager coordinates the creation and management of a WireGuard tunnel.
+type TunnelManager struct {
+ MountOps namespace.MountOps
+ FS namespace.FileSystem
+ Net *network.NetworkManager
+}
+
+// NewTunnelManager creates a new TunnelManager with production defaults.
+func NewTunnelManager() *TunnelManager {
+ return &TunnelManager{
+ MountOps: namespace.DefaultMountOps,
+ FS: namespace.DefaultFS,
+ Net: network.NewNetworkManager(),
+ }
+}
+
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
+func (tm *TunnelManager) StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) {
var cleanups []func()
defer func() {
if err != nil {
@@ -66,11 +82,11 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
tunName := "tun0"
mtu := 1420
- if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
+ if err := tm.MountOps.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
}
- if err := BlockHostServices(pm, profile); err != nil {
+ if err := tm.BlockHostServices(pm, profile); err != nil {
fmt.Printf("warning: failed to block host services: %v\n", err)
}
@@ -118,18 +134,18 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
}
// 4. Configure network interface
- if err := network.ConfigureInterface(tunName, cfg.Address, mtu); err != nil {
+ if err := tm.Net.ConfigureInterface(tunName, cfg.Address, mtu); err != nil {
return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
}
var dnsFile string
profileDir := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile)
- if path, err := ConfigureResolvConf(dnsServer, profileDir); err != nil {
+ if path, err := tm.ConfigureResolvConf(dnsServer, profileDir); err != nil {
fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
} else {
dnsFile = path
cleanups = append(cleanups, func() {
- if err := UnmountResolvConf(dnsFile); err != nil {
+ if err := tm.UnmountResolvConf(dnsFile); err != nil {
fmt.Printf("warning: failed to unmount resolv.conf during cleanup: %v\n", err)
}
})
@@ -143,12 +159,12 @@ func StartTunnel(pm *paths.PathManager, profile string, cfg *wgconf.Config, dnsS
}
// Close shuts down the userspace WireGuard device and closes the TUN interface.
-func (t *Tunnel) Close() {
+func (t *Tunnel) Close(tm *TunnelManager) {
if t.Device != nil {
t.Device.Close()
}
if t.dnsFile != "" {
- if err := UnmountResolvConf(t.dnsFile); err != nil {
+ if err := tm.UnmountResolvConf(t.dnsFile); err != nil {
fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err)
}
}
@@ -216,14 +232,12 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
}
// ConfigureResolvConf creates a temporary resolv.conf file and bind-mounts it to /etc/resolv.conf.
-func ConfigureResolvConf(dns string, profileDir string) (string, error) {
+func (tm *TunnelManager) ConfigureResolvConf(dns string, profileDir string) (string, error) {
if dns == "" {
return "", nil
}
- // Create the temporary resolv.conf file within the profile directory
- // so it can be cleaned up during namespace unpinning.
- tmpFile, err := os.CreateTemp(profileDir, "resolvconf")
+ tmpFile, err := tm.FS.CreateTemp(profileDir, "resolvconf")
if err != nil {
return "", fmt.Errorf("failed to create temp resolv.conf in %s: %w", profileDir, err)
}
@@ -235,11 +249,11 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) {
}
_ = tmpFile.Close()
- if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
+ if err := tm.MountOps.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err)
}
- if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil {
+ if err := tm.MountOps.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil {
fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err)
}
@@ -247,16 +261,14 @@ func ConfigureResolvConf(dns string, profileDir string) (string, error) {
}
// UnmountResolvConf unmounts the bind-mounted resolv.conf and removes the temporary file.
-func UnmountResolvConf(path string) error {
+func (tm *TunnelManager) UnmountResolvConf(path string) error {
if path == "" {
return nil
}
- // Attempt to unmount. If it fails, it might already be unmounted
- // or the namespace might be gone.
- _ = unix.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
+ _ = tm.MountOps.Unmount("/etc/resolv.conf", unix.MNT_DETACH)
- if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+ if err := tm.FS.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err)
}
return nil
@@ -264,18 +276,18 @@ func UnmountResolvConf(path string) error {
// BlockHostServices bind-mounts empty files/directories over sensitive host services
// to prevent access from within the isolated namespace.
-func BlockHostServices(pm *paths.PathManager, profile string) error {
+func (tm *TunnelManager) BlockHostServices(pm *paths.PathManager, profile string) error {
blockDirBase := filepath.Join(pm.RuntimeBaseDir(), "profiles", profile, "block")
- if err := os.MkdirAll(blockDirBase, 0755); err != nil {
+ if err := tm.FS.MkdirAll(blockDirBase, 0755); err != nil {
return fmt.Errorf("failed to create block base directory: %w", err)
}
- tmpDir, err := os.MkdirTemp(blockDirBase, "dir-")
+ tmpDir, err := tm.FS.MkdirTemp(blockDirBase, "dir-")
if err != nil {
return fmt.Errorf("failed to create temp block dir: %w", err)
}
- tmpFile, err := os.CreateTemp(blockDirBase, "file-")
+ tmpFile, err := tm.FS.CreateTemp(blockDirBase, "file-")
if err != nil {
return fmt.Errorf("failed to create temp block file: %w", err)
}
@@ -283,16 +295,16 @@ func BlockHostServices(pm *paths.PathManager, profile string) error {
_ = tmpFile.Close()
for _, p := range namespace.GetBlockPaths() {
- stat, err := os.Stat(p)
+ stat, err := tm.FS.Stat(p)
if err == nil {
source := tmpFileName
if stat.IsDir() {
source = tmpDir
}
- if err := unix.Mount(source, p, "", unix.MS_BIND, ""); err != nil {
+ if err := tm.MountOps.Mount(source, p, "", unix.MS_BIND, ""); err != nil {
fmt.Printf("warning: failed to bind-mount block over %s: %v\n", p, err)
} else {
- _ = unix.Mount("", p, "", unix.MS_PRIVATE, "")
+ _ = tm.MountOps.Mount("", p, "", unix.MS_PRIVATE, "")
}
}
}
diff --git a/internal/wireguard/wireguard_unit_test.go b/internal/wireguard/wireguard_unit_test.go
new file mode 100644
index 0000000..1ad7f65
--- /dev/null
+++ b/internal/wireguard/wireguard_unit_test.go
@@ -0,0 +1,195 @@
+//go:build linux
+
+package wireguard
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
+)
+
+// mockMountOps records mount calls for verification.
+type mockMountOps struct {
+ mounts []mountCall
+ unmounts []string
+}
+
+type mountCall struct {
+ source string
+ target string
+ flags uintptr
+}
+
+func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error {
+ m.mounts = append(m.mounts, mountCall{source, target, flags})
+ return nil
+}
+
+func (m *mockMountOps) Unmount(target string, flags int) error {
+ m.unmounts = append(m.unmounts, target)
+ return nil
+}
+
+// mockFS mimics a filesystem using a temporary directory.
+type mockFS struct {
+ root string
+}
+
+func (m *mockFS) fullPath(path string) string {
+ if filepath.IsAbs(path) {
+ path = strings.TrimPrefix(path, "/")
+ }
+ return filepath.Join(m.root, path)
+}
+
+func (m *mockFS) Stat(name string) (os.FileInfo, error) {
+ return os.Stat(m.fullPath(name))
+}
+
+func (m *mockFS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(m.fullPath(path), perm)
+}
+
+func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(m.fullPath(dir), pattern)
+}
+
+func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) {
+ dirPath := m.fullPath(dir)
+ res, err := os.MkdirTemp(dirPath, pattern)
+ if err != nil {
+ return "", err
+ }
+ // Return path relative to root for consistency if needed, but OS calls return absolute.
+ // The TunnelManager expects a path it can pass to Remove() later.
+ // If we return absolute, mockFS.Remove must handle it.
+ return res, nil
+}
+
+func (m *mockFS) Remove(name string) error {
+ // If the path is absolute and starts with our root, we can remove it directly.
+ // Otherwise, we use fullPath to ensure it's within root.
+ if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) {
+ return os.Remove(name)
+ }
+ return os.Remove(m.fullPath(name))
+}
+
+// setupParallelEnv creates a clean, isolated test environment for a single test.
+func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) {
+ t.Helper()
+
+ tmpDir := t.TempDir()
+ hostRoot := t.TempDir()
+
+ pm := paths.NewPathManager("", tmpDir)
+ mockMounts := &mockMountOps{}
+ mockFS := &mockFS{root: hostRoot}
+
+ tm := &TunnelManager{
+ MountOps: mockMounts,
+ FS: mockFS,
+ }
+
+ return pm, mockMounts, mockFS, tm
+}
+
+func TestBlockHostServices_Logic(t *testing.T) {
+ t.Parallel()
+
+ pm, mockMounts, mockFS, tm := setupParallelEnv(t)
+ profile := "test-profile"
+
+ blockPaths := namespace.GetBlockPaths()
+ for _, p := range blockPaths {
+ fullPath := mockFS.fullPath(p)
+ _ = os.MkdirAll(filepath.Dir(fullPath), 0755)
+ _ = os.WriteFile(fullPath, []byte("data"), 0644)
+ }
+
+ err := tm.BlockHostServices(pm, profile)
+ if err != nil {
+ t.Fatalf("BlockHostServices failed: %v", err)
+ }
+
+ for _, p := range blockPaths {
+ found := false
+ for _, m := range mockMounts.mounts {
+ if m.target == p {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("path %s was not blocked", p)
+ }
+ }
+}
+
+func TestConfigureResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Use a directory relative to mockFS.root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ dnsServer := "1.1.1.1"
+
+ path, err := tm.ConfigureResolvConf(dnsServer, profileDir)
+ if err != nil {
+ t.Fatalf("ConfigureResolvConf failed: %v", err)
+ }
+
+ // path is returned from os.CreateTemp, which is absolute.
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ t.Errorf("resolv.conf file was not created at %s", path)
+ }
+
+ if len(mockMounts.mounts) < 2 {
+ t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts))
+ }
+
+ if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND {
+ t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0])
+ }
+
+ if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE {
+ t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1])
+ }
+}
+
+func TestUnmountResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Create a file within the mock FS root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ tmpFile := "resolv.conf.tmp"
+ fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile))
+
+ err := os.WriteFile(fullPath, []byte("test"), 0644)
+ if err != nil {
+ t.Fatalf("failed to write temp resolv.conf: %v", err)
+ }
+
+ err = tm.UnmountResolvConf(fullPath)
+ if err != nil {
+ t.Errorf("UnmountResolvConf failed: %v", err)
+ }
+
+ if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" {
+ t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts)
+ }
+
+ if _, err := os.Stat(fullPath); !os.IsNotExist(err) {
+ t.Error("temporary resolv.conf file was not removed")
+ }
+}