Skip to content

Commit

Permalink
a bit of refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mosajjal committed Oct 13, 2024
1 parent 511b0ca commit 3b37c31
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cmd/sniproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func main() {
c.AllowConnToLocal = generalConfig.Bool("allow_conn_to_local")

var err error
c.Acl, err = acl.StartACLs(&logger, k)
c.ACL, err = acl.StartACLs(&logger, k)
if err != nil {
logger.Error().Msgf("failed to start ACLs: %s", err)
return
Expand Down
12 changes: 7 additions & 5 deletions pkg/acl/acl.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package acl contains the logic for Access Control Lists. It provides a way to make decisions based on the connection information.
package acl

import (
Expand Down Expand Up @@ -35,12 +36,13 @@ type ConnInfo struct {
Decision
}

type ByPriority []ACL
type byPriority []ACL

func (a ByPriority) Len() int { return len(a) }
func (a ByPriority) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByPriority) Less(i, j int) bool { return a[i].Priority() < a[j].Priority() }
func (a byPriority) Len() int { return len(a) }
func (a byPriority) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byPriority) Less(i, j int) bool { return a[i].Priority() < a[j].Priority() }

// ACL is the interface that each ACL should implement
type ACL interface {
Decide(*ConnInfo) error
Name() string
Expand Down Expand Up @@ -74,7 +76,7 @@ func StartACLs(log *zerolog.Logger, k *koanf.Koanf) ([]ACL, error) {

// MakeDecision loops through all the ACLs and makes a decision for the connection
func MakeDecision(c *ConnInfo, a []ACL) error {
sort.Sort(ByPriority(a))
sort.Sort(byPriority(a))
for _, acl := range a {
if err := acl.Decide(c); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/acl/cidr.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (d *cidr) LoadCIDRCSV(path string) error {
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
d.logger.Info().Msg("CIDR list is a URL, trying to fetch")
client := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/acl/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (d *domain) LoadDomainsCsv(Filename string) error {
if strings.HasPrefix(Filename, "http://") || strings.HasPrefix(Filename, "https://") {
d.logger.Info().Msg("domain list is a URL, trying to fetch")
client := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
CheckRedirect: func(r *http.Request, _ []*http.Request) error {
r.URL.Opaque = r.URL.Path
return nil
},
Expand Down
11 changes: 6 additions & 5 deletions pkg/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"golang.org/x/net/proxy"
)

// Config is the main runtime configuration for the proxy
type Config struct {
PublicIPv4 string `yaml:"public_ipv4"`
PublicIPv6 string `yaml:"public_ipv6"`
Expand All @@ -36,9 +37,9 @@ type Config struct {
BindPrometheus string `yaml:"bind_prometheus"`
AllowConnToLocal bool `yaml:"allow_conn_to_local"`

Acl []acl.ACL `yaml:"-"`
ACL []acl.ACL `yaml:"-"`

DnsClient DNSClient `yaml:"-"`
DNSClient DNSClient `yaml:"-"`
Dialer proxy.Dialer `yaml:"-"`
// list of interface source IPs; used to rotate source IPs when initializing connections
SourceAddr []netip.Addr `yaml:"-"`
Expand Down Expand Up @@ -108,7 +109,7 @@ func (c *Config) SetDNSClient(logger zerolog.Logger) error {
return fmt.Errorf("error setting up dns client: %v", err)
}
}
c.DnsClient = *dnsClient
c.DNSClient = *dnsClient
return nil
}

Expand Down Expand Up @@ -170,7 +171,7 @@ func parseBinders(bind string, additional []string) ([]string, error) {
// and the additional bind addresses from bind_http_additional as a list of ports or port ranges
// such as 8080, 8081-8083, 8085
// when this function is called, it will compile the list of bind addresses and store it in BindHTTPListeners
func (c *Config) SetBindHTTPListeners(logger zerolog.Logger) error {
func (c *Config) SetBindHTTPListeners(_ zerolog.Logger) error {
bindAddresses, err := parseBinders(c.BindHTTP, c.BindHTTPAdditional)
if err != nil {
return fmt.Errorf("error parsing bind addresses for HTTP: %v", err)
Expand All @@ -180,7 +181,7 @@ func (c *Config) SetBindHTTPListeners(logger zerolog.Logger) error {
}

// SetBindHTTPSListeners sets up a list of bind addresses for HTTPS
func (c *Config) SetBindHTTPSListeners(logger zerolog.Logger) error {
func (c *Config) SetBindHTTPSListeners(_ zerolog.Logger) error {
bindAddresses, err := parseBinders(c.BindHTTPS, c.BindHTTPSAdditional)
if err != nil {
return fmt.Errorf("error parsing bind addresses for HTTPS: %v", err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func (c *Config) pickSrcAddr(version string) net.IP {
return nil
}

// PerformExternalAQuery performs an external DNS query for the given domain name.
func (dnsc *DNSClient) PerformExternalAQuery(fqdn string, QType uint16) ([]dns.RR, error) {
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
Expand Down Expand Up @@ -129,7 +130,7 @@ func processQuestion(c *Config, l zerolog.Logger, q dns.Question, decision acl.D
// Otherwise do an upstream query and use that answer.
default:
l.Debug().Msgf("perform external query for domain %s", q.Name)
resp, err := c.DnsClient.PerformExternalAQuery(q.Name, q.Qtype)
resp, err := c.DNSClient.PerformExternalAQuery(q.Name, q.Qtype)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -226,7 +227,7 @@ func handleDNS(c *Config, l zerolog.Logger) dns.HandlerFunc {
SrcIP: w.RemoteAddr(),
Domain: q.Name,
}
acl.MakeDecision(&connInfo, c.Acl)
acl.MakeDecision(&connInfo, c.ACL)
answers, err := processQuestion(c, l, q, connInfo.Decision)
if err != nil {
continue
Expand All @@ -238,6 +239,7 @@ func handleDNS(c *Config, l zerolog.Logger) dns.HandlerFunc {
}
}

// RunDNS starts DNS servers based on the provided configuration.
func RunDNS(c *Config, l zerolog.Logger) {
dns.HandleFunc(".", handleDNS(c, l))
// start DNS UDP serverUdp
Expand Down Expand Up @@ -296,8 +298,6 @@ func RunDNS(c *Config, l zerolog.Logger) {
if err != nil {
l.Error().Msg(err.Error())
}
tlsConfig := &tls.Config{}
tlsConfig.Certificates = []tls.Certificate{crt}

// Create the QUIC listener
doqConf := doqserver.Config{
Expand Down
2 changes: 2 additions & 0 deletions pkg/doc.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*
Package sniproxy is a simple SNI proxy server that allows you to serve multiple SSL-enabled websites from a single IP address.
Continuation of [byosh] and [SimpleSNIProxy] projects.
# pre-requisites
Expand Down
1 change: 1 addition & 0 deletions pkg/doh/certtools.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package doh contains the logic for DNS over HTTPS. It provides a way to make decisions based on the connection information.
package doh

import (
Expand Down
6 changes: 4 additions & 2 deletions pkg/doh/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import (
"golang.org/x/net/idna"
)

func (s *Server) parseRequestGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest {
func (s *Server) parseRequestGoogle(_ context.Context, _ http.ResponseWriter, r *http.Request) *DNSRequest {
name := r.FormValue("name")
if name == "" {
return &DNSRequest{
Expand All @@ -59,6 +59,7 @@ func (s *Server) parseRequestGoogle(ctx context.Context, w http.ResponseWriter,
rrTypeStr := r.FormValue("type")
rrType := uint16(1)
if rrTypeStr == "" {
// Do nothing and continue
} else if v, err := strconv.ParseUint(rrTypeStr, 10, 16); err == nil {
rrType = uint16(v)
} else if v, ok := dns.StringToType[strings.ToUpper(rrTypeStr)]; ok {
Expand All @@ -75,6 +76,7 @@ func (s *Server) parseRequestGoogle(ctx context.Context, w http.ResponseWriter,
if cdStr == "1" || strings.EqualFold(cdStr, "true") {
cd = true
} else if cdStr == "0" || strings.EqualFold(cdStr, "false") || cdStr == "" {
// Do nothing and continue
} else {
return &DNSRequest{
errcode: 400,
Expand Down Expand Up @@ -177,7 +179,7 @@ func parseSubnet(ednsClientSubnet string) (ednsClientFamily uint16, ednsClientAd
return
}

func (s *Server) generateResponseGoogle(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) {
func (s *Server) generateResponseGoogle(_ context.Context, w http.ResponseWriter, _ *http.Request, req *DNSRequest) {
respJSON := jsondns.Marshal(req.response)
respStr, err := json.Marshal(respJSON)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/doh/ietf.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
"github.com/miekg/dns"
)

func (s *Server) parseRequestIETF(ctx context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest {
func (s *Server) parseRequestIETF(_ context.Context, w http.ResponseWriter, r *http.Request) *DNSRequest {
requestBase64 := r.FormValue("dns")
requestBinary, err := base64.RawURLEncoding.DecodeString(requestBase64)
if err != nil {
Expand Down Expand Up @@ -95,7 +95,7 @@ func (s *Server) parseRequestIETF(ctx context.Context, w http.ResponseWriter, r
} else {
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
}
var clientip net.IP = nil
var clientip net.IP
if s.conf.LogGuessedIP {
clientip = s.findClientIP(r)
}
Expand Down Expand Up @@ -166,7 +166,7 @@ func (s *Server) parseRequestIETF(ctx context.Context, w http.ResponseWriter, r
}
}

func (s *Server) generateResponseIETF(ctx context.Context, w http.ResponseWriter, r *http.Request, req *DNSRequest) {
func (s *Server) generateResponseIETF(_ context.Context, w http.ResponseWriter, _ *http.Request, req *DNSRequest) {
respJSON := jsondns.Marshal(req.response)
req.response.Id = req.transactionID
respBytes, err := req.response.Pack()
Expand Down
11 changes: 8 additions & 3 deletions pkg/doh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/miekg/dns"
)

// Server is a DNS-over-HTTPS server runtime
type Server struct {
conf *config
udpClient *dns.Client
Expand All @@ -49,6 +50,7 @@ type Server struct {
servemux *http.ServeMux
}

// DNSRequest is a DNS request
type DNSRequest struct {
request *dns.Msg
response *dns.Msg
Expand All @@ -59,6 +61,7 @@ type DNSRequest struct {
errtext string
}

// NewDefaultConfig creates a new default config
func NewDefaultConfig() *config {
conf := &config{}
if len(conf.Listen) == 0 {
Expand All @@ -80,6 +83,7 @@ func NewDefaultConfig() *config {
return conf
}

// NewServer creates a new Server
func NewServer(conf *config) (*Server, error) {
timeout := time.Duration(conf.Timeout) * time.Second
s := &Server{
Expand Down Expand Up @@ -125,6 +129,7 @@ func NewServer(conf *config) (*Server, error) {
return s, nil
}

// Start starts the server
func (s *Server) Start() error {
servemux := http.Handler(s.servemux)
if s.conf.Verbose {
Expand Down Expand Up @@ -158,7 +163,7 @@ func (s *Server) Start() error {
TLSConfig: &tls.Config{
ClientCAs: clientCAPool,
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
GetCertificate: func(_ *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
c, err := tls.LoadX509KeyPair(s.conf.Cert, s.conf.Key)
if err != nil {
fmt.Printf("Error loading server certificate key pair: %v\n", err)
Expand Down Expand Up @@ -211,8 +216,8 @@ func (s *Server) handlerFunc(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS, POST")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Max-Age", "3600")
w.Header().Set("Server", USER_AGENT)
w.Header().Set("X-Powered-By", USER_AGENT)
w.Header().Set("Server", UserAgent)
w.Header().Set("X-Powered-By", UserAgent)

if r.Method == "OPTIONS" {
w.Header().Set("Content-Length", "0")
Expand Down
4 changes: 2 additions & 2 deletions pkg/doh/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
package doh

const (
VERSION = "2.3.4"
USER_AGENT = "DNS-over-HTTPS/" + VERSION + " (+https://github.com/m13253/dns-over-https)"
// UserAgent is the default User-Agent string for the HTTP client
UserAgent = "DNS-over-HTTPS (+https://github.com/mosajjal/sniproxy)"
)
2 changes: 1 addition & 1 deletion pkg/httpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func handle80(c *Config) http.HandlerFunc {
SrcIP: addr,
Domain: r.Host,
}
acl.MakeDecision(&connInfo, c.Acl)
acl.MakeDecision(&connInfo, c.ACL)
if connInfo.Decision == acl.Reject || connInfo.Decision == acl.OriginIP || err != nil {
httplog.Info().Str("src_ip", r.RemoteAddr).Msgf("rejected request")
http.Error(w, "Could not reach origin server", 403)
Expand Down
9 changes: 5 additions & 4 deletions pkg/https.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func handleTLS(c *Config, conn net.Conn, httpslog zerolog.Logger) error {
SrcIP: conn.RemoteAddr(),
Domain: sni,
}
acl.MakeDecision(&connInfo, c.Acl)
acl.MakeDecision(&connInfo, c.ACL)

if connInfo.Decision == acl.Reject {
httpslog.Warn().Msgf("ACL rejection srcip=%s", conn.RemoteAddr().String())
Expand All @@ -80,7 +80,7 @@ func handleTLS(c *Config, conn net.Conn, httpslog zerolog.Logger) error {
rPort = connInfo.DstIP.Port
} else {
// TODO: lookup needs to be both ipv4 and ipv6
rAddrTmp, err := c.DnsClient.lookupDomain(sni, c.PreferredVersion)
rAddrTmp, err := c.DNSClient.lookupDomain(sni, c.PreferredVersion)
if err != nil {
httpslog.Warn().Msg(err.Error())
return err
Expand Down Expand Up @@ -173,10 +173,11 @@ func getPortFromConn(conn net.Conn) int {
return portnum
}

// RunHTTPS starts the HTTPS server on the configured bind
// "bind" format is as ip:port
func RunHTTPS(c *Config, bind string, log zerolog.Logger) {
if l, err := net.Listen("tcp", bind); err != nil {
log.Error().Msg(err.Error())
panic(-1)
log.Fatal().Msg(err.Error())
} else {
log.Info().Msgf("listening https on %s", bind)
defer l.Close()
Expand Down

0 comments on commit 3b37c31

Please sign in to comment.