//go:build linux package wireguard import ( "os" "path/filepath" "strings" "testing" "git.theodohertyfamily.com/wg-wrap/internal/namespace" "git.theodohertyfamily.com/wg-wrap/internal/paths" "golang.org/x/sys/unix" ) // mockMountOps records mount calls for verification. type mockMountOps struct { mounts []mountCall unmounts []string } type mountCall struct { source string target string flags uintptr } func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error { m.mounts = append(m.mounts, mountCall{source, target, flags}) return nil } func (m *mockMountOps) Unmount(target string, flags int) error { m.unmounts = append(m.unmounts, target) return nil } // mockFS mimics a filesystem using a temporary directory. type mockFS struct { root string } func (m *mockFS) fullPath(path string) string { if filepath.IsAbs(path) { path = strings.TrimPrefix(path, "/") } return filepath.Join(m.root, path) } func (m *mockFS) Stat(name string) (os.FileInfo, error) { return os.Stat(m.fullPath(name)) } func (m *mockFS) MkdirAll(path string, perm os.FileMode) error { return os.MkdirAll(m.fullPath(path), perm) } func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) { return os.CreateTemp(m.fullPath(dir), pattern) } func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) { dirPath := m.fullPath(dir) res, err := os.MkdirTemp(dirPath, pattern) if err != nil { return "", err } // Return path relative to root for consistency if needed, but OS calls return absolute. // The TunnelManager expects a path it can pass to Remove() later. // If we return absolute, mockFS.Remove must handle it. return res, nil } func (m *mockFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(m.fullPath(name)) } func (m *mockFS) Open(name string) (*os.File, error) { return os.Open(m.fullPath(name)) } func (m *mockFS) Remove(name string) error { // If the path is absolute and starts with our root, we can remove it directly. // Otherwise, we use fullPath to ensure it's within root. if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) { return os.Remove(name) } return os.Remove(m.fullPath(name)) } // setupParallelEnv creates a clean, isolated test environment for a single test. func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) { t.Helper() tmpDir := t.TempDir() hostRoot := t.TempDir() pm := paths.NewPathManager("", tmpDir) mockMounts := &mockMountOps{} mockFS := &mockFS{root: hostRoot} tm := &TunnelManager{ MountOps: mockMounts, FS: mockFS, } return pm, mockMounts, mockFS, tm } func TestBlockHostServices_Logic(t *testing.T) { t.Parallel() pm, mockMounts, mockFS, tm := setupParallelEnv(t) profile := "test-profile" blockPaths := namespace.GetBlockPaths() for _, p := range blockPaths { fullPath := mockFS.fullPath(p) _ = os.MkdirAll(filepath.Dir(fullPath), 0755) _ = os.WriteFile(fullPath, []byte("data"), 0644) } err := tm.BlockHostServices(pm, profile) if err != nil { t.Fatalf("BlockHostServices failed: %v", err) } for _, p := range blockPaths { found := false for _, m := range mockMounts.mounts { if m.target == p { found = true break } } if !found { t.Errorf("path %s was not blocked", p) } } } func TestConfigureResolvConf_Logic(t *testing.T) { t.Parallel() _, mockMounts, mockFS, tm := setupParallelEnv(t) // Use a directory relative to mockFS.root profileDir := "profile-dir" _ = mockFS.MkdirAll(profileDir, 0755) dnsServer := "1.1.1.1" path, err := tm.ConfigureResolvConf(dnsServer, profileDir) if err != nil { t.Fatalf("ConfigureResolvConf failed: %v", err) } // path is returned from os.CreateTemp, which is absolute. if _, err := os.Stat(path); os.IsNotExist(err) { t.Errorf("resolv.conf file was not created at %s", path) } if len(mockMounts.mounts) < 2 { t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts)) } if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND { t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0]) } if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE { t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1]) } } func TestUnmountResolvConf_Logic(t *testing.T) { t.Parallel() _, mockMounts, mockFS, tm := setupParallelEnv(t) // Create a file within the mock FS root profileDir := "profile-dir" _ = mockFS.MkdirAll(profileDir, 0755) tmpFile := "resolv.conf.tmp" fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile)) err := os.WriteFile(fullPath, []byte("test"), 0644) if err != nil { t.Fatalf("failed to write temp resolv.conf: %v", err) } err = tm.UnmountResolvConf(fullPath) if err != nil { t.Errorf("UnmountResolvConf failed: %v", err) } if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" { t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts) } if _, err := os.Stat(fullPath); !os.IsNotExist(err) { t.Error("temporary resolv.conf file was not removed") } }