diff options
Diffstat (limited to 'internal/wireguard/wireguard.go')
| -rw-r--r-- | internal/wireguard/wireguard.go | 90 |
1 files changed, 74 insertions, 16 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 42e095d..a45401c 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -28,12 +28,20 @@ type Tunnel struct { } // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. -func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) { +func StartTunnel(cfg *wgconf.Config, dnsServer string) (*Tunnel, error) { // 1. Create the TUN device inside the current (isolated) namespace // We use the default name 'tun0' tunName := "tun0" mtu := 1420 + // Ensure the mount namespace is private to prevent mount propagation to the host. + // This is critical for the bind-mount of /etc/resolv.conf to work in rootless environments. + if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil { + // We log this as a warning because some environments might not allow this, + // but we can still try to proceed. + fmt.Printf("warning: failed to make mount namespace private: %v\n", err) + } + tunDev, err := tun.CreateTUN(tunName, mtu) if err != nil { return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) @@ -91,6 +99,13 @@ func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) { return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } + // Configure DNS resolver inside the namespace + if err := ConfigureResolvConf(dnsServer); err != nil { + // We treat DNS failure as a warning rather than a fatal error to allow + // the tunnel to function even if /etc/resolv.conf is read-only. + fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) + } + return &Tunnel{ Device: wgDev, Tun: tunDev, @@ -210,14 +225,34 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { return ip.String(), nil } -// ConfigureResolvConf sets up the DNS inside the namespace's /etc/resolv.conf. -// Because the namespace is completely isolated, writing to /etc/resolv.conf inside -// the container/namespaces context won't affect the host, but since we are mapped to root -// inside a mount namespace, we may want to bind-mount a custom resolv.conf. -// To keep it simple and clean without requiring complex host mount setup, we can write -// directly to /etc/resolv.conf inside our user namespace. Since /etc/resolv.conf is usually -// writable inside user namespaces, we try to modify it directly. func ConfigureResolvConf(dns string) error { + if dns == "" { + return nil + } + + // To avoid modifying the host's /etc/resolv.conf, we use the private mount namespace. + tmpFile, err := os.CreateTemp("", "resolvconf") + if err != nil { + return fmt.Errorf("failed to create temp resolv.conf: %w", err) + } + defer func() { _ = tmpFile.Close() }() + + content := fmt.Sprintf("nameserver %s\n", dns) + if _, err := tmpFile.WriteString(content); err != nil { + return fmt.Errorf("failed to write to temp resolv.conf: %w", err) + } + + // 1. Bind-mount the temp file over /etc/resolv.conf + if err := unix.Mount(tmpFile.Name(), "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", tmpFile.Name(), err) + } + + // 2. Make the mount private to ensure it doesn't propagate back to the host + // and to satisfy kernel requirements for mount transitions in some environments. + if err := unix.Mount("/etc/resolv.conf", "/etc/resolv.conf", "", unix.MS_REMOUNT|unix.MS_BIND|unix.MS_PRIVATE, ""); err != nil { + return fmt.Errorf("failed to make /etc/resolv.conf mount private: %w", err) + } + return nil } @@ -289,7 +324,8 @@ func (h *HostBind) BatchSize() int { // host UDP socket file descriptor. This allows unprivileged processes inside // network namespaces to communicate over the host network loop. type FDBind struct { - conn *net.UDPConn + originalFd int + conn *net.UDPConn } type FDEndpoint struct { @@ -323,20 +359,31 @@ func (e *FDEndpoint) SrcIfidx() int32 { } func NewFDBind(fd int) (*FDBind, error) { - file := os.NewFile(uintptr(fd), "host-udp-socket") + return &FDBind{originalFd: fd}, nil +} + +func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + // Duplicate the original fd so we can close the duplicated socket during + // transitions or shutdown, while preserving the ability to re-open/re-bind it later. + dupFd, err := unix.Dup(b.originalFd) + if err != nil { + return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err) + } + + file := os.NewFile(uintptr(dupFd), "host-udp-socket") pconn, err := net.FilePacketConn(file) if err != nil { - return nil, fmt.Errorf("failed to wrap fd %d as packet conn: %w", fd, err) + _ = file.Close() + return nil, 0, fmt.Errorf("failed to wrap fd %d as packet conn: %w", dupFd, err) } + udpConn, ok := pconn.(*net.UDPConn) if !ok { _ = pconn.Close() - return nil, fmt.Errorf("fd %d is not a UDP socket", fd) + return nil, 0, fmt.Errorf("fd %d is not a UDP socket", dupFd) } - return &FDBind{conn: udpConn}, nil -} + b.conn = udpConn -func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { laddr, ok := b.conn.LocalAddr().(*net.UDPAddr) if !ok { return nil, 0, fmt.Errorf("local address is not a UDP address") @@ -347,6 +394,9 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e if len(packets) == 0 { return 0, nil } + if b.conn == nil { + return 0, net.ErrClosed + } nBytes, addr, err := b.conn.ReadFromUDP(packets[0]) if err != nil { return 0, err @@ -361,7 +411,12 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e } func (b *FDBind) Close() error { - return b.conn.Close() + if b.conn != nil { + err := b.conn.Close() + b.conn = nil + return err + } + return nil } func (b *FDBind) SetMark(mark uint32) error { @@ -369,6 +424,9 @@ func (b *FDBind) SetMark(mark uint32) error { } func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { + if b.conn == nil { + return net.ErrClosed + } addrPort, err := netip.ParseAddrPort(endpoint.DstToString()) if err != nil { return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err) |
