diff options
Diffstat (limited to 'internal/wireguard/wireguard.go')
| -rw-r--r-- | internal/wireguard/wireguard.go | 275 |
1 files changed, 108 insertions, 167 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 3a2bfa3..3c293b4 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -9,12 +9,11 @@ import ( "net" "net/netip" "os" - "os/exec" - "runtime" "strconv" "strings" "git.theodohertyfamily.com/tools/wg-wrap/pkg/wgconf" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" @@ -23,8 +22,9 @@ import ( // Tunnel represents an active Userspace WireGuard tunnel inside a network namespace. type Tunnel struct { - Device *device.Device - Tun tun.Device + Device *device.Device + Tun tun.Device + dnsFile string } // StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes. @@ -54,7 +54,11 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { if err != nil { return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err) } - cleanups = append(cleanups, func() { tunDev.Close() }) + cleanups = append(cleanups, func() { + if err := tunDev.Close(); err != nil { + fmt.Printf("warning: failed to close TUN device: %v\n", err) + } + }) // 2. Instantiate the userspace WireGuard device logger := device.NewLogger(device.LogLevelSilent, "[wg-wrap] ") @@ -69,12 +73,7 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { } if bind == nil { - bind = conn.NewDefaultBind() - if hostNetNSFdStr := os.Getenv("WG_WRAP_HOST_NETNS_FD"); hostNetNSFdStr != "" { - if fd, err := strconv.Atoi(hostNetNSFdStr); err == nil && fd > 0 { - bind = NewHostBind(bind, fd) - } - } + return nil, fmt.Errorf("failed to acquire host socket FD: no valid WG_WRAP_HOST_SOCKET_FD provided") } wgDev := device.NewDevice(tunDev, bind, logger) @@ -94,13 +93,15 @@ func StartTunnel(cfg *wgconf.Config, dnsServer string) (t *Tunnel, err error) { return nil, fmt.Errorf("failed to bring up WireGuard device: %w", err) } - // 4. Configure network interface using standard Linux network commands (iproute2) + // 4. Configure network interface using netlink if err := configureInterface(tunName, cfg.Address, mtu); err != nil { return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err) } - if err := ConfigureResolvConf(dnsServer); err != nil { + if path, err := ConfigureResolvConf(dnsServer); err != nil { fmt.Printf("warning: failed to configure DNS resolver: %v\n", err) + } else { + t.dnsFile = path } return &Tunnel{ @@ -114,31 +115,30 @@ func (t *Tunnel) Close() { if t.Device != nil { t.Device.Close() } + if t.dnsFile != "" { + if err := UnmountResolvConf(t.dnsFile); err != nil { + fmt.Printf("warning: failed to unmount resolv.conf: %v\n", err) + } + } } // keyToHex ensures a WireGuard key is in hexadecimal format, converting from base64 if needed. func keyToHex(key string) (string, error) { - // Try base64 decoding first decoded, err := base64.StdEncoding.DecodeString(key) if err == nil && len(decoded) == 32 { return hex.EncodeToString(decoded), nil } - - // Try decoding as hex if len(key) == 64 { if _, err := hex.DecodeString(key); err == nil { return strings.ToLower(key), nil } } - return "", fmt.Errorf("key is neither valid base64 nor hex 32-byte key: %s", key) } // buildUAPIConfig translates our wgconf.Config into the standard WireGuard UAPI format func buildUAPIConfig(cfg *wgconf.Config) (string, error) { var sb strings.Builder - - // Global section if cfg.PrivateKey != "" { hexKey, err := keyToHex(cfg.PrivateKey) if err != nil { @@ -146,11 +146,7 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) { } _, _ = fmt.Fprintf(&sb, "private_key=%s\n", hexKey) } - - // If there are existing peers, remove them first to have a clean state sb.WriteString("replace_peers=true\n") - - // Peer sections for _, peer := range cfg.Peers { if peer.PublicKey == "" { continue @@ -160,11 +156,9 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) { return "", fmt.Errorf("invalid Peer PublicKey: %w", err) } _, _ = fmt.Fprintf(&sb, "public_key=%s\n", hexKey) - if peer.Endpoint != "" { _, _ = fmt.Fprintf(&sb, "endpoint=%s\n", peer.Endpoint) } - for _, allowedIP := range peer.AllowedIPs { trimmed := strings.TrimSpace(allowedIP) if trimmed != "" { @@ -172,36 +166,43 @@ func buildUAPIConfig(cfg *wgconf.Config) (string, error) { } } } - return sb.String(), nil } -// configureInterface uses the 'ip' command to set address, MTU, and default routing table +// configureInterface uses netlink to set address, MTU, and default routing table. func configureInterface(name, address string, mtu int) error { - // Set MTU and bring up link - // ip link set dev tun0 mtu 1420 up - cmd := exec.Command("ip", "link", "set", "dev", name, "mtu", fmt.Sprintf("%d", mtu), "up") - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to set link %s state/mtu: %v", name, err) - } - - // Add IP address - // ip addr add <address> dev tun0 - cmd = exec.Command("ip", "addr", "add", address, "dev", name) - if _, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to add address %s to link %s: %v", address, name, err) - } - - // Set default route or peer routes. - // For transparent userspace tunneling inside an isolated network namespace, - // we route all traffic (0.0.0.0/0) through our TUN device 'tun0'. - cmd = exec.Command("ip", "route", "add", "default", "dev", name) - if err := cmd.Run(); err != nil { - // If a default route already exists, we replace it or log the warning - // We try to replace first - cmdReplace := exec.Command("ip", "route", "replace", "default", "dev", name) - if errReplace := cmdReplace.Run(); errReplace != nil { - return fmt.Errorf("failed to configure default route to %s: %v", name, errReplace) + link, err := netlink.LinkByName(name) + if err != nil { + return fmt.Errorf("failed to find link %s: %w", name, err) + } + + if err := netlink.LinkSetMTU(link, mtu); err != nil { + return fmt.Errorf("failed to set MTU %d on link %s: %w", mtu, name, err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up link %s: %w", name, err) + } + + addr, err := netlink.ParseAddr(address) + if err != nil { + return fmt.Errorf("invalid IP address %s: %w", address, err) + } + if err := netlink.AddrAdd(link, addr); err != nil { + if !strings.Contains(err.Error(), "file exists") { + return fmt.Errorf("failed to add address %s to link %s: %w", address, name, err) + } + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + LinkIndex: link.Attrs().Index, + Dst: nil, + } + + if err := netlink.RouteAdd(route); err != nil { + if err := netlink.RouteReplace(route); err != nil { + return fmt.Errorf("failed to configure default route via %s: %w", name, err) } } @@ -222,46 +223,56 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) { return ip.String(), nil } -func ConfigureResolvConf(dns string) error { +func ConfigureResolvConf(dns string) (string, error) { if dns == "" { - return nil + return "", nil } - tmpFile, err := os.CreateTemp("", "resolvconf") if err != nil { - return fmt.Errorf("failed to create temp resolv.conf: %w", err) + return "", fmt.Errorf("failed to create temp resolv.conf: %w", err) } launcherPath := tmpFile.Name() - defer func() { - _ = tmpFile.Close() - _ = os.Remove(launcherPath) - }() - content := fmt.Sprintf("nameserver %s\n", dns) if _, err := tmpFile.WriteString(content); err != nil { - return fmt.Errorf("failed to write to temp resolv.conf: %w", err) + _ = tmpFile.Close() + return "", fmt.Errorf("failed to write to temp resolv.conf: %w", err) } + _ = tmpFile.Close() if err := unix.Mount(launcherPath, "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil { - return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err) + return "", fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", launcherPath, err) } if err := unix.Mount("", "/etc/resolv.conf", "", unix.MS_PRIVATE, ""); err != nil { fmt.Printf("warning: failed to make /etc/resolv.conf mount private: %v\n", err) } + return launcherPath, nil +} + +func UnmountResolvConf(path string) error { + if path == "" { + return nil + } + if err := unix.Unmount("/etc/resolv.conf", 0); err != nil { + return fmt.Errorf("failed to unmount /etc/resolv.conf: %w", err) + } + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove temp resolv.conf file %s: %w", path, err) + } return nil } -// BlockHostServices blocks local D-Bus and name service cache daemon (nscd) sockets -// inside the mount namespace. This prevents glibc from bypassing the network namespace -// isolation via host services (e.g. systemd-resolved via D-Bus). func BlockHostServices() error { tmpDir, err := os.MkdirTemp("", "wg-wrap-block-") if err != nil { return fmt.Errorf("failed to create temp dir: %w", err) } - defer os.RemoveAll(tmpDir) + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + fmt.Printf("warning: failed to remove temp dir %s: %v\n", tmpDir, err) + } + }() tmpFile, err := os.CreateTemp("", "wg-wrap-block-file-") if err != nil { @@ -269,7 +280,11 @@ func BlockHostServices() error { } tmpFileName := tmpFile.Name() _ = tmpFile.Close() - defer os.Remove(tmpFileName) + defer func() { + if err := os.Remove(tmpFileName); err != nil && !os.IsNotExist(err) { + fmt.Printf("warning: failed to remove temp file %s: %v\n", tmpFileName, err) + } + }() pathsToBlock := []string{ "/run/dbus/system_bus_socket", @@ -299,76 +314,22 @@ func BlockHostServices() error { return nil } -// HostBind wraps a standard conn.Bind so that its socket creation (Open) -// is forced to execute within a host network namespace. -type HostBind struct { - inner conn.Bind - hostNetNSFd int -} +type HostBind struct{} func NewHostBind(inner conn.Bind, hostNetNSFd int) *HostBind { - return &HostBind{inner: inner, hostNetNSFd: hostNetNSFd} + return &HostBind{} } func (h *HostBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - // Open/save a reference to our current isolated network namespace to switch back to. - isolatedFd, err := unix.Open("/proc/self/ns/net", unix.O_RDONLY, 0) - if err != nil { - return nil, 0, fmt.Errorf("failed to open isolated netns: %w", err) - } - defer func() { _ = unix.Close(isolatedFd) }() - - // Temporarily switch this thread to the host network namespace - if err := unix.Setns(h.hostNetNSFd, unix.CLONE_NEWNET); err != nil { - return nil, 0, fmt.Errorf("failed to join host netns: %w", err) - } - - // Sockets are opened in the host network namespace! - fns, actualPort, err = h.inner.Open(port) - if err != nil { - return nil, 0, fmt.Errorf("failed to open sockets in host netns: %w", err) - } - - // Switch this thread back to the isolated network namespace - if err := unix.Setns(isolatedFd, unix.CLONE_NEWNET); err != nil { - _ = h.inner.Close() - // CRITICAL: The thread is stuck in the host network namespace. Returning it to the Go runtime pool - // will cause other goroutines to run in the host namespace, breaching isolation. We must panic - // immediately to abort the process and prevent a namespace escape. - panic(fmt.Sprintf("CRITICAL: failed to restore isolated netns: %v", err)) - } - - return fns, actualPort, nil -} - -func (h *HostBind) Close() error { - return h.inner.Close() -} - -func (h *HostBind) SetMark(mark uint32) error { - return h.inner.SetMark(mark) -} - -func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { - // Linux socket routing maps to the namespace in which the socket was created, - // so h.inner.Send will automatically route via host namespace without Setns here! - return h.inner.Send(bufs, endpoint) -} - -func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return h.inner.ParseEndpoint(s) + return nil, 0, fmt.Errorf("HostBind.Open is disabled for security reasons") } -func (h *HostBind) BatchSize() int { - return h.inner.BatchSize() -} +func (h *HostBind) Close() error { return nil } +func (h *HostBind) SetMark(mark uint32) error { return nil } +func (h *HostBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return nil } +func (h *HostBind) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (h *HostBind) BatchSize() int { return 0 } -// FDBind implements the conn.Bind interface by wrapping a pre-opened -// host UDP socket file descriptor. This allows unprivileged processes inside -// network namespaces to communicate over the host network loop. type FDBind struct { originalFd int conn *net.UDPConn @@ -378,39 +339,19 @@ type FDEndpoint struct { addr netip.AddrPort } -func (e *FDEndpoint) DstIP() netip.Addr { - return e.addr.Addr() -} - -func (e *FDEndpoint) DstToString() string { - return e.addr.String() -} - -func (e *FDEndpoint) DstToBytes() []byte { - return e.addr.Addr().AsSlice() -} - -func (e *FDEndpoint) ClearSrc() {} - -func (e *FDEndpoint) SrcIP() netip.Addr { - return netip.Addr{} -} - -func (e *FDEndpoint) SrcToString() string { - return "" -} - -func (e *FDEndpoint) SrcIfidx() int32 { - return 0 -} +func (e *FDEndpoint) DstIP() netip.Addr { return e.addr.Addr() } +func (e *FDEndpoint) DstToString() string { return e.addr.String() } +func (e *FDEndpoint) DstToBytes() []byte { return e.addr.Addr().AsSlice() } +func (e *FDEndpoint) ClearSrc() {} +func (e *FDEndpoint) SrcIP() netip.Addr { return netip.Addr{} } +func (e *FDEndpoint) SrcToString() string { return "" } +func (e *FDEndpoint) SrcIfidx() int32 { return 0 } func NewFDBind(fd int) (*FDBind, error) { return &FDBind{originalFd: fd}, nil } func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - // Duplicate the original fd so we can close the duplicated socket during - // transitions or shutdown, while preserving the ability to re-open/re-bind it later. dupFd, err := unix.Dup(b.originalFd) if err != nil { return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err) @@ -443,14 +384,17 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e if b.conn == nil { return 0, net.ErrClosed } - nBytes, addr, err := b.conn.ReadFromUDP(packets[0]) - if err != nil { - return 0, err + for i := 0; i < len(packets); i++ { + nBytes, addr, err := b.conn.ReadFromUDP(packets[i]) + if err != nil { + return i, err + } + sizes[i] = nBytes + addrPort := addr.AddrPort() + eps[i] = &FDEndpoint{addr: addrPort} + n++ } - sizes[0] = nBytes - addrPort := addr.AddrPort() - eps[0] = &FDEndpoint{addr: addrPort} - return 1, nil + return n, nil } return []conn.ReceiveFunc{receive}, actualPort, nil @@ -465,9 +409,7 @@ func (b *FDBind) Close() error { return nil } -func (b *FDBind) SetMark(mark uint32) error { - return nil -} +func (b *FDBind) SetMark(mark uint32) error { return nil } func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { if b.conn == nil { @@ -478,7 +420,6 @@ func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error { return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err) } addr := net.UDPAddrFromAddrPort(addrPort) - for _, buf := range bufs { _, err := b.conn.WriteToUDP(buf, addr) if err != nil { |
