summaryrefslogtreecommitdiff
path: root/internal/wireguard/wireguard_unit_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/wireguard/wireguard_unit_test.go')
-rw-r--r--internal/wireguard/wireguard_unit_test.go195
1 files changed, 195 insertions, 0 deletions
diff --git a/internal/wireguard/wireguard_unit_test.go b/internal/wireguard/wireguard_unit_test.go
new file mode 100644
index 0000000..1ad7f65
--- /dev/null
+++ b/internal/wireguard/wireguard_unit_test.go
@@ -0,0 +1,195 @@
+//go:build linux
+
+package wireguard
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "git.theodohertyfamily.com/wg-wrap/internal/namespace"
+ "git.theodohertyfamily.com/wg-wrap/internal/paths"
+ "golang.org/x/sys/unix"
+)
+
+// mockMountOps records mount calls for verification.
+type mockMountOps struct {
+ mounts []mountCall
+ unmounts []string
+}
+
+type mountCall struct {
+ source string
+ target string
+ flags uintptr
+}
+
+func (m *mockMountOps) Mount(source, target, fstype string, flags uintptr, data string) error {
+ m.mounts = append(m.mounts, mountCall{source, target, flags})
+ return nil
+}
+
+func (m *mockMountOps) Unmount(target string, flags int) error {
+ m.unmounts = append(m.unmounts, target)
+ return nil
+}
+
+// mockFS mimics a filesystem using a temporary directory.
+type mockFS struct {
+ root string
+}
+
+func (m *mockFS) fullPath(path string) string {
+ if filepath.IsAbs(path) {
+ path = strings.TrimPrefix(path, "/")
+ }
+ return filepath.Join(m.root, path)
+}
+
+func (m *mockFS) Stat(name string) (os.FileInfo, error) {
+ return os.Stat(m.fullPath(name))
+}
+
+func (m *mockFS) MkdirAll(path string, perm os.FileMode) error {
+ return os.MkdirAll(m.fullPath(path), perm)
+}
+
+func (m *mockFS) CreateTemp(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(m.fullPath(dir), pattern)
+}
+
+func (m *mockFS) MkdirTemp(dir, pattern string) (string, error) {
+ dirPath := m.fullPath(dir)
+ res, err := os.MkdirTemp(dirPath, pattern)
+ if err != nil {
+ return "", err
+ }
+ // Return path relative to root for consistency if needed, but OS calls return absolute.
+ // The TunnelManager expects a path it can pass to Remove() later.
+ // If we return absolute, mockFS.Remove must handle it.
+ return res, nil
+}
+
+func (m *mockFS) Remove(name string) error {
+ // If the path is absolute and starts with our root, we can remove it directly.
+ // Otherwise, we use fullPath to ensure it's within root.
+ if filepath.IsAbs(name) && strings.HasPrefix(name, m.root) {
+ return os.Remove(name)
+ }
+ return os.Remove(m.fullPath(name))
+}
+
+// setupParallelEnv creates a clean, isolated test environment for a single test.
+func setupParallelEnv(t *testing.T) (*paths.PathManager, *mockMountOps, *mockFS, *TunnelManager) {
+ t.Helper()
+
+ tmpDir := t.TempDir()
+ hostRoot := t.TempDir()
+
+ pm := paths.NewPathManager("", tmpDir)
+ mockMounts := &mockMountOps{}
+ mockFS := &mockFS{root: hostRoot}
+
+ tm := &TunnelManager{
+ MountOps: mockMounts,
+ FS: mockFS,
+ }
+
+ return pm, mockMounts, mockFS, tm
+}
+
+func TestBlockHostServices_Logic(t *testing.T) {
+ t.Parallel()
+
+ pm, mockMounts, mockFS, tm := setupParallelEnv(t)
+ profile := "test-profile"
+
+ blockPaths := namespace.GetBlockPaths()
+ for _, p := range blockPaths {
+ fullPath := mockFS.fullPath(p)
+ _ = os.MkdirAll(filepath.Dir(fullPath), 0755)
+ _ = os.WriteFile(fullPath, []byte("data"), 0644)
+ }
+
+ err := tm.BlockHostServices(pm, profile)
+ if err != nil {
+ t.Fatalf("BlockHostServices failed: %v", err)
+ }
+
+ for _, p := range blockPaths {
+ found := false
+ for _, m := range mockMounts.mounts {
+ if m.target == p {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("path %s was not blocked", p)
+ }
+ }
+}
+
+func TestConfigureResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Use a directory relative to mockFS.root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ dnsServer := "1.1.1.1"
+
+ path, err := tm.ConfigureResolvConf(dnsServer, profileDir)
+ if err != nil {
+ t.Fatalf("ConfigureResolvConf failed: %v", err)
+ }
+
+ // path is returned from os.CreateTemp, which is absolute.
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ t.Errorf("resolv.conf file was not created at %s", path)
+ }
+
+ if len(mockMounts.mounts) < 2 {
+ t.Fatalf("expected 2 mounts, got %d", len(mockMounts.mounts))
+ }
+
+ if mockMounts.mounts[0].target != "/etc/resolv.conf" || mockMounts.mounts[0].flags != unix.MS_BIND {
+ t.Errorf("expected BIND mount to /etc/resolv.conf, got %+v", mockMounts.mounts[0])
+ }
+
+ if mockMounts.mounts[1].target != "/etc/resolv.conf" || mockMounts.mounts[1].flags != unix.MS_PRIVATE {
+ t.Errorf("expected PRIVATE mount to /etc/resolv.conf, got %+v", mockMounts.mounts[1])
+ }
+}
+
+func TestUnmountResolvConf_Logic(t *testing.T) {
+ t.Parallel()
+
+ _, mockMounts, mockFS, tm := setupParallelEnv(t)
+
+ // Create a file within the mock FS root
+ profileDir := "profile-dir"
+ _ = mockFS.MkdirAll(profileDir, 0755)
+ tmpFile := "resolv.conf.tmp"
+ fullPath := mockFS.fullPath(filepath.Join(profileDir, tmpFile))
+
+ err := os.WriteFile(fullPath, []byte("test"), 0644)
+ if err != nil {
+ t.Fatalf("failed to write temp resolv.conf: %v", err)
+ }
+
+ err = tm.UnmountResolvConf(fullPath)
+ if err != nil {
+ t.Errorf("UnmountResolvConf failed: %v", err)
+ }
+
+ if len(mockMounts.unmounts) == 0 || mockMounts.unmounts[0] != "/etc/resolv.conf" {
+ t.Errorf("expected unmount of /etc/resolv.conf, got %+v", mockMounts.unmounts)
+ }
+
+ if _, err := os.Stat(fullPath); !os.IsNotExist(err) {
+ t.Error("temporary resolv.conf file was not removed")
+ }
+}