summaryrefslogtreecommitdiff
path: root/internal/wireguard/wireguard.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/wireguard/wireguard.go')
-rw-r--r--internal/wireguard/wireguard.go90
1 files changed, 74 insertions, 16 deletions
diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go
index 42e095d..a45401c 100644
--- a/internal/wireguard/wireguard.go
+++ b/internal/wireguard/wireguard.go
@@ -28,12 +28,20 @@ type Tunnel struct {
}
// StartTunnel creates a TUN device, launches wireguard-go over it, and configures IPs/routes.
-func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
+func StartTunnel(cfg *wgconf.Config, dnsServer string) (*Tunnel, error) {
// 1. Create the TUN device inside the current (isolated) namespace
// We use the default name 'tun0'
tunName := "tun0"
mtu := 1420
+ // Ensure the mount namespace is private to prevent mount propagation to the host.
+ // This is critical for the bind-mount of /etc/resolv.conf to work in rootless environments.
+ if err := unix.Mount("", "/", "", unix.MS_REC|unix.MS_PRIVATE, ""); err != nil {
+ // We log this as a warning because some environments might not allow this,
+ // but we can still try to proceed.
+ fmt.Printf("warning: failed to make mount namespace private: %v\n", err)
+ }
+
tunDev, err := tun.CreateTUN(tunName, mtu)
if err != nil {
return nil, fmt.Errorf("failed to create TUN device %s: %w", tunName, err)
@@ -91,6 +99,13 @@ func StartTunnel(cfg *wgconf.Config) (*Tunnel, error) {
return nil, fmt.Errorf("failed to configure network interface %s: %w", tunName, err)
}
+ // Configure DNS resolver inside the namespace
+ if err := ConfigureResolvConf(dnsServer); err != nil {
+ // We treat DNS failure as a warning rather than a fatal error to allow
+ // the tunnel to function even if /etc/resolv.conf is read-only.
+ fmt.Printf("warning: failed to configure DNS resolver: %v\n", err)
+ }
+
return &Tunnel{
Device: wgDev,
Tun: tunDev,
@@ -210,14 +225,34 @@ func GetTunnelLocalIP(cfg *wgconf.Config) (string, error) {
return ip.String(), nil
}
-// ConfigureResolvConf sets up the DNS inside the namespace's /etc/resolv.conf.
-// Because the namespace is completely isolated, writing to /etc/resolv.conf inside
-// the container/namespaces context won't affect the host, but since we are mapped to root
-// inside a mount namespace, we may want to bind-mount a custom resolv.conf.
-// To keep it simple and clean without requiring complex host mount setup, we can write
-// directly to /etc/resolv.conf inside our user namespace. Since /etc/resolv.conf is usually
-// writable inside user namespaces, we try to modify it directly.
func ConfigureResolvConf(dns string) error {
+ if dns == "" {
+ return nil
+ }
+
+ // To avoid modifying the host's /etc/resolv.conf, we use the private mount namespace.
+ tmpFile, err := os.CreateTemp("", "resolvconf")
+ if err != nil {
+ return fmt.Errorf("failed to create temp resolv.conf: %w", err)
+ }
+ defer func() { _ = tmpFile.Close() }()
+
+ content := fmt.Sprintf("nameserver %s\n", dns)
+ if _, err := tmpFile.WriteString(content); err != nil {
+ return fmt.Errorf("failed to write to temp resolv.conf: %w", err)
+ }
+
+ // 1. Bind-mount the temp file over /etc/resolv.conf
+ if err := unix.Mount(tmpFile.Name(), "/etc/resolv.conf", "", unix.MS_BIND, ""); err != nil {
+ return fmt.Errorf("failed to bind-mount %s to /etc/resolv.conf: %w", tmpFile.Name(), err)
+ }
+
+ // 2. Make the mount private to ensure it doesn't propagate back to the host
+ // and to satisfy kernel requirements for mount transitions in some environments.
+ if err := unix.Mount("/etc/resolv.conf", "/etc/resolv.conf", "", unix.MS_REMOUNT|unix.MS_BIND|unix.MS_PRIVATE, ""); err != nil {
+ return fmt.Errorf("failed to make /etc/resolv.conf mount private: %w", err)
+ }
+
return nil
}
@@ -289,7 +324,8 @@ func (h *HostBind) BatchSize() int {
// host UDP socket file descriptor. This allows unprivileged processes inside
// network namespaces to communicate over the host network loop.
type FDBind struct {
- conn *net.UDPConn
+ originalFd int
+ conn *net.UDPConn
}
type FDEndpoint struct {
@@ -323,20 +359,31 @@ func (e *FDEndpoint) SrcIfidx() int32 {
}
func NewFDBind(fd int) (*FDBind, error) {
- file := os.NewFile(uintptr(fd), "host-udp-socket")
+ return &FDBind{originalFd: fd}, nil
+}
+
+func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ // Duplicate the original fd so we can close the duplicated socket during
+ // transitions or shutdown, while preserving the ability to re-open/re-bind it later.
+ dupFd, err := unix.Dup(b.originalFd)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to duplicate host socket fd: %w", err)
+ }
+
+ file := os.NewFile(uintptr(dupFd), "host-udp-socket")
pconn, err := net.FilePacketConn(file)
if err != nil {
- return nil, fmt.Errorf("failed to wrap fd %d as packet conn: %w", fd, err)
+ _ = file.Close()
+ return nil, 0, fmt.Errorf("failed to wrap fd %d as packet conn: %w", dupFd, err)
}
+
udpConn, ok := pconn.(*net.UDPConn)
if !ok {
_ = pconn.Close()
- return nil, fmt.Errorf("fd %d is not a UDP socket", fd)
+ return nil, 0, fmt.Errorf("fd %d is not a UDP socket", dupFd)
}
- return &FDBind{conn: udpConn}, nil
-}
+ b.conn = udpConn
-func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
laddr, ok := b.conn.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, 0, fmt.Errorf("local address is not a UDP address")
@@ -347,6 +394,9 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e
if len(packets) == 0 {
return 0, nil
}
+ if b.conn == nil {
+ return 0, net.ErrClosed
+ }
nBytes, addr, err := b.conn.ReadFromUDP(packets[0])
if err != nil {
return 0, err
@@ -361,7 +411,12 @@ func (b *FDBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, e
}
func (b *FDBind) Close() error {
- return b.conn.Close()
+ if b.conn != nil {
+ err := b.conn.Close()
+ b.conn = nil
+ return err
+ }
+ return nil
}
func (b *FDBind) SetMark(mark uint32) error {
@@ -369,6 +424,9 @@ func (b *FDBind) SetMark(mark uint32) error {
}
func (b *FDBind) Send(bufs [][]byte, endpoint conn.Endpoint) error {
+ if b.conn == nil {
+ return net.ErrClosed
+ }
addrPort, err := netip.ParseAddrPort(endpoint.DstToString())
if err != nil {
return fmt.Errorf("failed to parse destination endpoint %s: %w", endpoint.DstToString(), err)