summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/cli/cli.go31
-rw-r--r--internal/cli/profile_test.go12
-rw-r--r--internal/wireguard/wireguard.go90
3 files changed, 103 insertions, 30 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 0876d08..11914b1 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -178,7 +178,15 @@ func (a *App) ExecuteCommand(cfg *config.Config) error {
}
// Start the WireGuard userspace device & routing table setup
- tunnel, err := wireguard.StartTunnel(wgCfg)
+ dnsServer := cfg.DNSServer
+ if dnsServer == "" {
+ dnsServer = wgCfg.DNS
+ }
+ if dnsServer == "" {
+ dnsServer = "1.1.1.1" // Fallback to safe public DNS to prevent leaks
+ }
+
+ tunnel, err := wireguard.StartTunnel(wgCfg, dnsServer)
if err != nil {
return fmt.Errorf("failed to start WireGuard tunnel: %w", err)
}
@@ -256,15 +264,23 @@ func (a *App) handleProfileConfigure(name string) error {
return fmt.Errorf("profile '%s' not found", name)
}
- cfg, err := wgconf.Parse(profilePath)
- if err != nil {
- return fmt.Errorf("failed to parse profile %s: %w", name, err)
+ editor := os.Getenv("EDITOR")
+ if editor == "" {
+ editor = "vi" // Sensible fallback
}
- fmt.Printf("Editing profile %s...\n", name)
- fmt.Println("DNS server (current: '" + cfg.DNS + "'):")
+ fmt.Printf("Opening profile %s in default editor (%s)...\n", name, editor)
+
+ cmd := exec.Command(editor, profilePath)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Run(); err != nil {
+ return fmt.Errorf("editor failed: %w", err)
+ }
- return fmt.Errorf("interactive configuration not supported in this environment, use a config file")
+ return nil
}
func (a *App) handleProfileList() error {
@@ -354,6 +370,7 @@ func (a *App) showConfig() error {
fmt.Printf("Configuration:\n")
fmt.Printf(" Profile: %s\n", cfg.Profile)
fmt.Printf(" DNS Server: %s\n", cfg.DNSServer)
+ fmt.Printf(" Config Dir: %s\n", pm.ConfigDir())
fmt.Printf(" Runtime Base: %s\n", pm.RuntimeBaseDir())
fmt.Printf(" Profile Path: %s\n", profilePath)
fmt.Printf(" PIDs Path: %s\n", pidsPath)
diff --git a/internal/cli/profile_test.go b/internal/cli/profile_test.go
index d256cb0..17a5bc6 100644
--- a/internal/cli/profile_test.go
+++ b/internal/cli/profile_test.go
@@ -96,10 +96,6 @@ func TestProfileDeleteNotFound(t *testing.T) {
}
func TestProfileConfigure(t *testing.T) {
- // profile configure is intended to modify existing configs.
- // For now, we just want to ensure it doesn't crash and we can
- // eventually implement it.
-
tmpDir := t.TempDir()
profilesDir := filepath.Join(tmpDir, "profiles")
err := os.MkdirAll(profilesDir, 0755)
@@ -117,9 +113,11 @@ func TestProfileConfigure(t *testing.T) {
app := NewApp([]string{"wg-wrap", "profile", "configure", profileName})
app.ConfigDir = profilesDir
+ // Use "true" as the mock editor to ensure it exits successfully immediately
+ t.Setenv("EDITOR", "true")
+
err = app.Route()
- // This will currently return "not yet implemented" error, which is expected for now.
- if err == nil {
- t.Errorf("expected 'not yet implemented' error, got nil")
+ if err != nil {
+ t.Errorf("expected successful configuration, got error: %v", err)
}
}
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 42e095d..a45401c 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -28,12 +28,20 @@ type Tunnel struct {
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
+func StartTunnel(cfg *wgconf.Config, dnsServer string) (*Tunnel, error) {
// 1. Create the TUN device inside the current (isolated) namespace
// We use the default name 'tun0'
tunName := "tun0"
mtu := 1420
+ // Ensure the mount namespace is private to prevent mount propagation to the host.
+ // This is critical for the bind-mount of /etc/resolv.conf to work in rootless environments.
+ if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
+ // We log this as a warning because some environments might not allow this,
+ // but we can still try to proceed.
+ fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
+ }
+
tunDev, err := tun.CreateTUN(tunName, mtu)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err)
@@ -91,6 +99,13 @@ func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
}
+ // Configure DNS resolver inside the namespace
+ if err := ConfigureResolvConf(dnsServer); err != nil {
+ // We treat DNS failure as a warning rather than a fatal error to allow
+ // the tunnel to function even if /etc/resolv.conf is read-only.
+ fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
+ }
+
return &Tunnel{
Device: wgDev,
Tun: tunDev,
@@ -210,14 +225,34 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-// ConfigureResolvConf sets up the DNS inside the namespace's /etc/resolv.conf.
-// Because the namespace is completely isolated, writing to /etc/resolv.conf inside
-// the container/namespaces context won't affect the host, but since we are mapped to root
-// inside a mount namespace, we may want to bind-mount a custom resolv.conf.
-// To keep it simple and clean without requiring complex host mount setup, we can write
-// directly to /etc/resolv.conf inside our user namespace. Since /etc/resolv.conf is usually
-// writable inside user namespaces, we try to modify it directly.
func ConfigureResolvConf(dns string) error {
+ if dns == "" {
+ return nil
+ }
+
+ // To avoid modifying the host's /etc/resolv.conf, we use the private mount namespace.
+ tmpFile, err := os.CreateTemp("", "resolvconf")
+ if err != nil {
+ return fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ }
+ defer func() { _ = tmpFile.Close() }()
+
+ content := fmt.Sprintf("nameserver %s\n", dns)
+ if _, err := tmpFile.WriteString(content); err != nil {
+ return fmt.Errorf("failed to write to temp resolv.conf: %w", err)
+ }
+
+ // 1. Bind-mount the temp file over /etc/resolv.conf
+ if err := unix.Mount(tmpFile.Name(), "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
+ return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", tmpFile.Name(), err)
+ }
+
+ // 2. Make the mount private to ensure it doesn't propagate back to the host
+ // and to satisfy kernel requirements for mount transitions in some environments.
+ if err := unix.Mount("/etc/resolv.conf", "/etc/resolv.conf", "", unix.MS_REMOUNT|unix.MS_BIND|unix.MS_PRIVATE, ""); err != nil {
+ return fmt.Errorf("failed to make /etc/resolv.conf mount private: %w", err)
+ }
+
return nil
}
@@ -289,7 +324,8 @@ func (h *HostBind) BatchSize() int {
// host UDP socket file descriptor. This allows unprivileged processes inside
// network namespaces to communicate over the host network loop.
type FDBind struct {
- conn *net.UDPConn
+ originalFd int
+ conn *net.UDPConn
}
type FDEndpoint struct {
@@ -323,20 +359,31 @@ func (e *FDEndpoint) SrcIfidx() int32 {
}
func NewFDBind(fd int) (*FDBind, error) {
- file := os.NewFile(uintptr(fd), "host-udp-socket")
+ return &FDBind{originalFd: fd}, nil
+}
+
+func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ // Duplicate the original fd so we can close the duplicated socket during
+ // transitions or shutdown, while preserving the ability to re-open/re-bind it later.
+ dupFd, err := unix.Dup(b.originalFd)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err)
+ }
+
+ file := os.NewFile(uintptr(dupFd), "host-udp-socket")
pconn, err := net.FilePacketConn(file)
if err != nil {
- return nil, fmt.Errorf("failed to wrap fd %d as packet conn: %w", fd, err)
+ _ = file.Close()
+ return nil, 0, fmt.Errorf("failed to wrap fd %d as packet conn: %w", dupFd, err)
}
+
udpConn, ok := pconn.(*net.UDPConn)
if !ok {
_ = pconn.Close()
- return nil, fmt.Errorf("fd %d is not a UDP socket", fd)
+ return nil, 0, fmt.Errorf("fd %d is not a UDP socket", dupFd)
}
- return &FDBind{conn: udpConn}, nil
-}
+ b.conn = udpConn
-func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
laddr, ok := b.conn.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, 0, fmt.Errorf("local address is not a UDP address")
@@ -347,6 +394,9 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e
if len(packets) == 0 {
return 0, nil
}
+ if b.conn == nil {
+ return 0, net.ErrClosed
+ }
nBytes, addr, err := b.conn.ReadFromUDP(packets[0])
if err != nil {
return 0, err
@@ -361,7 +411,12 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e
}
func (b *FDBind) Close() error {
- return b.conn.Close()
+ if b.conn != nil {
+ err := b.conn.Close()
+ b.conn = nil
+ return err
+ }
+ return nil
}
func (b *FDBind) SetMark(mark uint32) error {
@@ -369,6 +424,9 @@ func (b *FDBind) SetMark(mark uint32) error {
}
func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
+ if b.conn == nil {
+ return net.ErrClosed
+ }
addrPort, err := netip.ParseAddrPort(endpoint.DstToString())
if err != nil {
return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err)