From 14fffd095a8d5ac54a43b8d060d0d1d8b907d578 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 21 Dec 2023 16:55:37 +0100 Subject: [PATCH] Move common DNS monitor code to own class. --- dnsmonitor.go | 213 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 dnsmonitor.go diff --git a/dnsmonitor.go b/dnsmonitor.go new file mode 100644 index 00000000..51619d72 --- /dev/null +++ b/dnsmonitor.go @@ -0,0 +1,213 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "log" + "net" + "sync" + "time" +) + +var ( + lookupDnsMonitorIP = net.LookupIP +) + +const ( + defaultDnsMonitorInterval = time.Second +) + +type DnsMonitorCallback = func(entry *DnsMonitorEntry, added []net.IP, keep []net.IP, removed []net.IP) + +type DnsMonitorEntry struct { + entry *dnsMonitorEntry + callback DnsMonitorCallback +} + +func (e *DnsMonitorEntry) Hostname() string { + return e.entry.hostname +} + +type dnsMonitorEntry struct { + hostname string + + ips []net.IP + entries map[*DnsMonitorEntry]bool +} + +func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) { + if empty := len(e.ips) == 0; fromIP || empty { + if empty { + e.ips = ips + e.runCallbacks(ips, nil, nil) + } + return + } + + var newIPs []net.IP + var addedIPs []net.IP + var removedIPs []net.IP + var keepIPs []net.IP + for _, oldIP := range e.ips { + found := false + for idx, newIP := range ips { + if oldIP.Equal(newIP) { + ips = append(ips[:idx], ips[idx+1:]...) + found = true + keepIPs = append(keepIPs, oldIP) + newIPs = append(newIPs, oldIP) + break + } + } + + if !found { + removedIPs = append(removedIPs, oldIP) + } + } + + if len(ips) > 0 { + addedIPs = append(addedIPs, ips...) + newIPs = append(newIPs, ips...) + } + e.ips = newIPs + + if len(addedIPs) > 0 || len(removedIPs) > 0 { + e.runCallbacks(addedIPs, keepIPs, removedIPs) + } +} + +func (e *dnsMonitorEntry) runCallbacks(added []net.IP, keep []net.IP, removed []net.IP) { + for entry := range e.entries { + entry.callback(entry, added, keep, removed) + } +} + +type DnsMonitor struct { + interval time.Duration + + stopCtx context.Context + stopFunc func() + + mu sync.RWMutex + hostnames map[string]*dnsMonitorEntry +} + +func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) { + if interval < 0 { + interval = defaultDnsMonitorInterval + } + + stopCtx, stopFunc := context.WithCancel(context.Background()) + monitor := &DnsMonitor{ + interval: interval, + + stopCtx: stopCtx, + stopFunc: stopFunc, + + hostnames: make(map[string]*dnsMonitorEntry), + } + return monitor, nil +} + +func (m *DnsMonitor) Start() error { + go m.run() + return nil +} + +func (m *DnsMonitor) Stop() { + m.stopFunc() +} + +func (m *DnsMonitor) Add(hostname string, callback DnsMonitorCallback) *DnsMonitorEntry { + m.mu.Lock() + defer m.mu.Unlock() + + e := &DnsMonitorEntry{ + callback: callback, + } + + entry, found := m.hostnames[hostname] + if !found { + entry = &dnsMonitorEntry{ + hostname: hostname, + entries: make(map[*DnsMonitorEntry]bool), + } + m.hostnames[hostname] = entry + } + e.entry = entry + entry.entries[e] = true + return e +} + +func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { + m.mu.Lock() + defer m.mu.Unlock() + + if entry.entry == nil { + return + } + + e, found := m.hostnames[entry.entry.hostname] + if !found { + return + } + + entry.entry = nil + delete(e.entries, entry) +} + +func (m *DnsMonitor) run() { + ticker := time.NewTicker(m.interval) + for { + select { + case <-m.stopCtx.Done(): + return + case <-ticker.C: + m.checkHostnames() + } + } +} + +func (m *DnsMonitor) checkHostnames() { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, entry := range m.hostnames { + m.checkHostname(entry) + } +} + +func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) { + if ip := net.ParseIP(entry.hostname); ip != nil { + entry.setIPs([]net.IP{ip}, true) + return + } + + ips, err := lookupDnsMonitorIP(entry.hostname) + if err != nil { + log.Printf("Could not lookup %s: %s", entry.hostname, err) + return + } + + entry.setIPs(ips, false) +}