Skip to content

Commit

Permalink
a bit of refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mosajjal committed Sep 7, 2024
1 parent 60bbcd7 commit 742be28
Show file tree
Hide file tree
Showing 10 changed files with 395 additions and 267 deletions.
13 changes: 7 additions & 6 deletions cmd/sniproxy/config.defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ general:
tls_cert:
# Path to the certificate key for DoH, DoT and DoQ. eg: /tmp/mycert.key
tls_key:
# HTTP Port to listen on. Should remain 80 in most cases
# HTTP Port to listen on. Should remain 80 in most cases. use :80 to listen on both IPv4 and IPv6
bind_http: "0.0.0.0:80"
# HTTPS Port to listen on. Should remain 443 in most cases
bind_https: "0.0.0.0:443"
# Enable prometheus endpoint on IP:PORT. example: 127.0.0.1:8080. Always exposes /metrics and only supports HTTP
bind_prometheus:
# Interface used for outbound TLS connections. uses OS prefered one if empty
interface:
# Preferred ip version for outgoing connections. default is 0 (random choice)
# Preferred ip version for outgoing connections. choises: ipv4 (or 4), ipv6 (or 6), ipv4only, ipv6only, any. empty (or 0) means any.
# numeric values are kept for backward compatibility
prefered_version:
# Public IPv4 of the server, reply address of DNS A queries
public_ipv4:
Expand Down Expand Up @@ -59,11 +60,11 @@ acl:
# CA will be rejected.
geoip:
enabled: false
# priority of the geoip filter. lower priority means it's checked first, meaning it can be ovveriden by other ACLs with higehr priority number.
# priority of the geoip filter. lower priority means it's checked first, meaning it can be ovveriden by other ACLs with higehr priority number.
priority: 10
# strictly blocked countries
blocked:
# allowed countries
# allowed countries
allowed:
# Path to the MMDB file. eg: /tmp/Country.mmdb, https://raw.githubusercontent.com/Loyalsoldier/geoip/release/Country.mmdb
path:
Expand All @@ -88,7 +89,7 @@ acl:
# Interval to re-fetch the domain list
refresh_interval: 1h0m0s
# FQDN override. This ACL is used to override the destination IP to not be the one resolved by the upstream DNS or the proxy itself, rather a custom IP and port
# if the destination is HTTP, it uses tls_cert and tls_key certificate to terminate the original connection.
# if the destination is HTTP, it uses tls_cert and tls_key certificate to terminate the original connection.
override:
enabled: false
# priority of the override filter. lower priority means it's checked first. if multiple filters have the same priority, they're checked in random order
Expand All @@ -101,7 +102,7 @@ acl:
"one.one.one.one": "1.1.1.1:443"
"google.com": "8.8.8.8:443"
# enable listening on DoH on a specific SNI. example: "myawesomedoh.example.com". empty disables it. If you need DoH to be enabled and don't want
# any other overrides, enable this ACL with empty rules. DoH SNI will add a default rule and start.
# any other overrides, enable this ACL with empty rules. DoH SNI will add a default rule and start.
doh_sni: "myawesomedoh.example.com"
# Path to the certificate for handling tls decryption. eg: /tmp/mycert.pem
tls_cert:
Expand Down
199 changes: 54 additions & 145 deletions cmd/sniproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@ package main

import (
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/google/uuid"
"github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/providers/rawbytes"
"github.com/rs/zerolog"
"github.com/txthinking/socks5"

prometheusmetrics "github.com/deathowl/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -30,9 +28,6 @@ import (
_ "embed"
stdlog "log"

"github.com/miekg/dns"
"golang.org/x/net/proxy"

sniproxy "github.com/mosajjal/sniproxy/v2/pkg"
"github.com/mosajjal/sniproxy/v2/pkg/acl"
"github.com/mosajjal/sniproxy/v2/pkg/doh"
Expand All @@ -53,84 +48,6 @@ var defaultConfig []byte
var nocolorLog = strings.ToLower(os.Getenv("NO_COLOR")) == "true"
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339, NoColor: nocolorLog})

func getPublicIPv4() (string, error) {
conn, err := net.Dial("udp", "8.8.8.8:53")
if err != nil {
return "", err
}
defer conn.Close()
localAddr := conn.LocalAddr().String()
idx := strings.LastIndex(localAddr, ":")
ipaddr := localAddr[0:idx]
if !net.ParseIP(ipaddr).IsPrivate() {
return ipaddr, nil
}
externalIP := ""
// trying to get the public IP from multiple sources to see if they match.
resp, err := http.Get("https://myexternalip.com/raw")
if err == nil {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err == nil {
externalIP = string(body)
}

if externalIP != "" {
return externalIP, nil
}
logger.Error().Msg("Could not automatically find the public IPv4 address. Please specify it in the configuration.")

}
return "", nil
}

func cleanIPv6(ip string) string {
ip = strings.TrimPrefix(ip, "[")
ip = strings.TrimSuffix(ip, "]")
return ip
}

func getPublicIPv6() (string, error) {
conn, err := net.Dial("udp6", "[2001:4860:4860::8888]:53")
if err != nil {
return "", err
}
defer conn.Close()
localAddr := conn.LocalAddr().String()
idx := strings.LastIndex(localAddr, ":")
ipaddr := localAddr[0:idx]
if !net.ParseIP(ipaddr).IsPrivate() {
return cleanIPv6(ipaddr), nil
}
externalIP := ""
// trying to get the public IP from multiple sources to see if they match.
resp, err := http.Get("https://6.ident.me")
if err == nil {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err == nil {
externalIP = string(body)
}

// backup method of getting a public IP
if externalIP == "" {
// dig +short -6 myip.opendns.com aaaa @2620:0:ccc::2
dnsRes, err := c.DnsClient.PerformExternalAQuery("myip.opendns.com.", dns.TypeAAAA)
if err != nil {
return "", err
}
externalIP = dnsRes[0].(*dns.AAAA).AAAA.String()
}

if externalIP != "" {
return cleanIPv6(externalIP), nil
}
logger.Error().Msg("Could not automatically find the public IPv6 address. Please specify it in the configuration.")

}
return "", nil
}

func main() {

cmd := &cobra.Command{
Expand Down Expand Up @@ -211,23 +128,47 @@ func main() {
c.BindHTTP = generalConfig.String("bind_http")
c.BindHTTPS = generalConfig.String("bind_https")
c.Interface = generalConfig.String("interface")
c.PublicIPv4 = generalConfig.String("public_ipv4")
if c.PublicIPv4 == "" {
c.PublicIPv4, _ = getPublicIPv4()
c.PreferredVersion = generalConfig.String("preferred_version")

// if preferred version is ipv6only, we don't need to check for ipv4 public ip
if c.PreferredVersion != "ipv6only" {
c.PublicIPv4 = generalConfig.String("public_ipv4")
if c.PublicIPv4 == "" {
var err error
c.PublicIPv4, err = sniproxy.GetPublicIPv4()
if err != nil {
logger.Fatal().Msgf("failed to get public IPv4, while ipv4 is enabled in preferred_version: %s", err)
}
logger.Info().Msgf("public IPv4 (automatically determined): %s", c.PublicIPv4)
} else {
logger.Info().Msgf("public IPv4 (manually provided): %s", c.PublicIPv4)
}
}
c.PublicIPv6 = generalConfig.String("public_ipv6")
if c.PublicIPv6 == "" {
c.PublicIPv6, _ = getPublicIPv6()
// if preferred version is ipv4only, we don't need to check for ipv6 public ip
if c.PreferredVersion != "ipv4only" {
c.PublicIPv6 = generalConfig.String("public_ipv6")
if c.PublicIPv6 == "" {
var err error
c.PublicIPv6, err = sniproxy.GetPublicIPv6()
if err != nil {
logger.Fatal().Msgf("failed to get public IPv6, while ipv6 is enabled in preferred_version: %s", err)
}
logger.Info().Msgf("public IPv6 (automatically determined): %s", c.PublicIPv6)
} else {
logger.Info().Msgf("public IPv6 (manually provided): %s", c.PublicIPv6)
}
}

// in any case, at least one public IP address is needed to run the server. if both are empty, we can't proceed
if c.PublicIPv4 == "" && c.PublicIPv6 == "" {
logger.Error().Msg("Could not automatically determine public IP. you should provide it manually using --publicIPv4 and --publicIPv6")
logger.Error().Msg("Could not automatically determine any public IP. you should provide it manually using --publicIPv4 or --publicIPv6 or both.")
logger.Error().Msg("if your environment is single-stack, you can use --preferredVersion to specify the version as ipv4only or ipv6only.")
return
}

c.BindPrometheus = generalConfig.String("prometheus")
c.AllowConnToLocal = generalConfig.Bool("allow_conn_to_local")

c.PreferredVersion = uint(generalConfig.Int("preferred_version"))

var err error
c.Acl, err = acl.StartACLs(&logger, k)
if err != nil {
Expand All @@ -236,6 +177,7 @@ func main() {
}

// set up metrics
// TODO: add ipv6 vs ipv4 metrics
c.RecievedDNS = metrics.GetOrRegisterCounter("dns.requests.recieved", metrics.DefaultRegistry)
c.ProxiedDNS = metrics.GetOrRegisterCounter("dns.requests.proxied", metrics.DefaultRegistry)
c.RecievedHTTP = metrics.GetOrRegisterCounter("http.requests.recieved", metrics.DefaultRegistry)
Expand All @@ -257,30 +199,21 @@ func main() {
}()
}

if c.PublicIPv4 != "" {
logger.Info().Str("public_ip", c.PublicIPv4).Msg("server info")
} else {
logger.Error().Msg("Could not automatically determine public IPv4. you should provide it manually using --publicIPv4")
}

if c.PublicIPv6 != "" {
logger.Info().Str("public_ip", c.PublicIPv6).Msg("server info")
} else {
logger.Error().Msg("Could not automatically determine public IPv6. you should provide it manually using --publicIPv6")
}

// generate self-signed certificate if not provided
// generate self-signed certificate if not provided.
if c.TLSCert == "" && c.TLSKey == "" {
_, _, err := doh.GenerateSelfSignedCertKey(c.PublicIPv4, nil, nil, os.TempDir())
// generate a random 16 char string as hostname
hostname := uuid.NewString()[:16]
logger.Info().Msg("certificate was not provided, generating a self signed cert in temp directory")
_, _, err := doh.GenerateSelfSignedCertKey(hostname, nil, nil, os.TempDir())
if err != nil {
logger.Error().Msgf("error while generating self-signed cert: %s", err)
}
c.TLSCert = filepath.Join(os.TempDir(), c.PublicIPv4+".crt")
c.TLSKey = filepath.Join(os.TempDir(), c.PublicIPv4+".key")
c.TLSCert = filepath.Join(os.TempDir(), hostname+".crt")
c.TLSKey = filepath.Join(os.TempDir(), hostname+".key")
}

// Finds source addr for outbound connections if interface is not empty
// if the "interface" configuration is provided, sniproxy must translate the interface name to the IP addresses
// and add them to the source address list
if c.Interface != "" {
logger.Info().Msgf("Using interface %s", c.Interface)
ief, err := net.InterfaceByName(c.Interface)
Expand All @@ -297,46 +230,22 @@ func main() {
}
}

if c.UpstreamSOCKS5 != "" {
uri, err := url.Parse(c.UpstreamSOCKS5)
if err != nil {
logger.Error().Msg(err.Error())
}
if uri.Scheme != "socks5" {
logger.Error().Msg("only SOCKS5 is supported")
return
}

logger.Info().Msgf("Using an upstream SOCKS5 proxy: %s", uri.Host)
socksAuth := new(proxy.Auth)
socksAuth.User = uri.User.Username()
socksAuth.Password, _ = uri.User.Password()
c.Dialer, err = socks5.NewClient(uri.Host, socksAuth.User, socksAuth.Password, 60, 60)
if err != nil {
logger.Error().Msg(err.Error())
}
} else {
c.Dialer = proxy.Direct
// set up dialer based on SOCKS5 configuration
if err := c.SetDialer(logger); err != nil {
logger.Error().Msgf("error setting up dialer: %v", err)
return
}

dnsProxy := c.UpstreamSOCKS5
if c.UpstreamSOCKS5 != "" && !c.UpstreamDNSOverSocks5 {
logger.Debug().Msg("disabling socks5 for dns")
dnsProxy = ""
}
tmp, err := sniproxy.NewDNSClient(&c, c.UpstreamDNS, true, dnsProxy)
if err != nil {
logger.Error().Msgf("error setting up dns client, removing proxy if provided: %v", err)
tmp, err = sniproxy.NewDNSClient(&c, c.UpstreamDNS, false, "")
if err != nil {
logger.Error().Msgf("error setting up dns client: %v", err)
return
}
// set up the DNS Client based on the configuration
if err := c.SetDNSClient(logger); err != nil {
logger.Error().Msgf("error setting up DNS client: %v", err)
return
}
c.DnsClient = *tmp
go sniproxy.RunHTTP(&c, logger)

go sniproxy.RunHTTP(&c, logger.With().Str("service", "http").Logger())
go sniproxy.RunHTTPS(&c, logger.With().Str("service", "https").Logger())
go sniproxy.RunDNS(&c, logger.With().Str("service", "dns").Logger())

// wait forever. TODO: add signal handling here
select {}
}
Loading

0 comments on commit 742be28

Please sign in to comment.