diff --git a/proxy/upstreams.go b/proxy/upstreams.go index 0692cd576..241d2f9cc 100644 --- a/proxy/upstreams.go +++ b/proxy/upstreams.go @@ -34,8 +34,10 @@ type UpstreamConfig struct { // type check var _ io.Closer = (*UpstreamConfig)(nil) -// ParseUpstreamsConfig returns UpstreamConfig and error if upstreams -// configuration is invalid. +// ParseUpstreamsConfig returns an UpstreamConfig and nil error if the upstream +// configuration is valid. Otherwise returns a partially filled UpstreamConfig +// and wrapped error containing lines with errors. It also skips empty lines +// and comments (lines starting with "#"). // // # Simple upstreams // @@ -84,13 +86,16 @@ var _ io.Closer = (*UpstreamConfig)(nil) // // TODO(e.burkov): Consider supporting multiple upstreams in a single line for // default upstream syntax. -func ParseUpstreamsConfig(upstreamConfig []string, options *upstream.Options) (*UpstreamConfig, error) { - if options == nil { - options = &upstream.Options{} +func ParseUpstreamsConfig( + lines []string, + opts *upstream.Options, +) (conf *UpstreamConfig, err error) { + if opts == nil { + opts = &upstream.Options{} } p := &configParser{ - options: options, + options: opts, upstreamsIndex: map[string]upstream.Upstream{}, domainReservedUpstreams: map[string][]upstream.Upstream{}, specifiedDomainUpstreams: map[string][]upstream.Upstream{}, @@ -98,9 +103,33 @@ func ParseUpstreamsConfig(upstreamConfig []string, options *upstream.Options) (* subdomainsOnlyExclusions: stringutil.NewSet(), } - return p.parse(upstreamConfig) + return p.parse(lines) } +// ParseError is an error which contains an index of the line of the upstream +// list. +type ParseError struct { + // err is the original error. + err error + + // Idx is an index of the lines. See [ParseUpstreamsConfig]. + Idx int +} + +// type check +var _ error = (*ParseError)(nil) + +// Error implements the [error] interface for *ParseError. +func (e *ParseError) Error() (msg string) { + return fmt.Sprintf("parsing error at index %d: %s", e.Idx, e.err) +} + +// type check +var _ errors.Wrapper = (*ParseError)(nil) + +// Unwrap implements the [errors.Wrapper] interface for *ParseError. +func (e *ParseError) Unwrap() (unwrapped error) { return e.err } + // configParser collects the results of parsing an upstream config. type configParser struct { // options contains upstream properties. @@ -129,10 +158,11 @@ type configParser struct { } // parse returns UpstreamConfig and error if upstreams configuration is invalid. -func (p *configParser) parse(conf []string) (c *UpstreamConfig, err error) { - for i, l := range conf { +func (p *configParser) parse(lines []string) (c *UpstreamConfig, err error) { + var errs []error + for i, l := range lines { if err = p.parseLine(i, l); err != nil { - return nil, err + errs = append(errs, &ParseError{Idx: i, err: err}) } } @@ -147,12 +177,16 @@ func (p *configParser) parse(conf []string) (c *UpstreamConfig, err error) { DomainReservedUpstreams: p.domainReservedUpstreams, SpecifiedDomainUpstreams: p.specifiedDomainUpstreams, SubdomainExclusions: p.subdomainsOnlyExclusions, - }, nil + }, errors.Join(errs...) } // parseLine returns an error if upstream configuration line is invalid. func (p *configParser) parseLine(idx int, confLine string) (err error) { - upstreams, domains, err := splitConfigLine(idx, confLine) + if len(confLine) == 0 || confLine[0] == '#' { + return nil + } + + upstreams, domains, err := splitConfigLine(confLine) if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -165,7 +199,7 @@ func (p *configParser) parseLine(idx int, confLine string) (err error) { } for _, u := range upstreams { - err = p.specifyUpstream(domains, u, idx, confLine) + err = p.specifyUpstream(domains, u, idx) if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -177,15 +211,15 @@ func (p *configParser) parseLine(idx int, confLine string) (err error) { // splitConfigLine parses upstream configuration line and returns list upstream // addresses (one or many), list of domains for which this upstream is reserved -// (may be nil) or error if something went wrong. -func splitConfigLine(idx int, confLine string) (upstreams, domains []string, err error) { +// (may be nil). It returns an error if the upstream format is incorrect. +func splitConfigLine(confLine string) (upstreams, domains []string, err error) { if !strings.HasPrefix(confLine, "[/") { return []string{confLine}, nil, nil } domainsLine, upstreamsLine, found := strings.Cut(confLine[len("[/"):], "/]") if !found || upstreamsLine == "" { - return nil, nil, fmt.Errorf("wrong upstream specification %d %q", idx, confLine) + return nil, nil, errors.Error("wrong upstream format") } // split domains list @@ -209,20 +243,14 @@ func splitConfigLine(idx int, confLine string) (upstreams, domains []string, err } // specifyUpstream specifies the upstream for domains. -func (p *configParser) specifyUpstream( - domains []string, - u string, - idx int, - confLine string, -) (err error) { +func (p *configParser) specifyUpstream(domains []string, u string, idx int) (err error) { dnsUpstream, ok := p.upstreamsIndex[u] // TODO(e.burkov): Improve identifying duplicate upstreams. if !ok { // create an upstream dnsUpstream, err = upstream.AddressToUpstream(u, p.options.Clone()) if err != nil { - return fmt.Errorf("cannot prepare the upstream %d %q (%s): %s", - idx, confLine, p.options.Bootstrap, err) + return fmt.Errorf("cannot prepare the upstream: %s", err) } // save to the index @@ -231,11 +259,19 @@ func (p *configParser) specifyUpstream( addr := dnsUpstream.Address() if len(domains) == 0 { - log.Debug("dnsproxy: upstream at index %d: %s", idx, addr) + // TODO(s.chzhen): Handle duplicates. p.upstreams = append(p.upstreams, dnsUpstream) + + // TODO(s.chzhen): Logs without index. + log.Debug("dnsproxy: upstream at index %d: %s", idx, addr) } else { - log.Debug("dnsproxy: upstream at index %d: %s is reserved for %s", idx, addr, domains) p.includeToReserved(dnsUpstream, domains) + + log.Debug("dnsproxy: upstream at index %d: %s is reserved for %d domains", + idx, + addr, + len(domains), + ) } return nil diff --git a/upstream/upstream.go b/upstream/upstream.go index 9cd345d87..47627a328 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -151,6 +151,7 @@ const ( // AddressToUpstream converts addr to an Upstream using the specified options. // addr can be either a URL, or a plain address, either a domain name or an IP. // +// - 1.2.3.4 or 1.2.3.4:4321 for plain DNS using IP address; // - udp://5.3.5.3:53 or 5.3.5.3:53 for plain DNS using IP address; // - udp://name.server:53 or name.server:53 for plain DNS using domain name; // - tcp://5.3.5.3:53 for plain DNS-over-TCP using IP address; @@ -178,31 +179,57 @@ func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) { var uu *url.URL if strings.Contains(addr, "://") { - // Parse as URL. uu, err = url.Parse(addr) if err != nil { return nil, fmt.Errorf("failed to parse %s: %w", addr, err) } } else { - // Probably, plain UDP upstream defined by address or address:port. - _, port, splitErr := net.SplitHostPort(addr) - if splitErr == nil { - // Validate port. - _, err = strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, fmt.Errorf("invalid address %s: %w", addr, err) - } - } - uu = &url.URL{ Scheme: "udp", Host: addr, } } + err = validateUpstreamURL(uu) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return nil, err + } + return urlToUpstream(uu, opts) } +// validateUpstreamURL returns an error if the upstream URL is not valid. +func validateUpstreamURL(u *url.URL) (err error) { + if u.Scheme == "sdns" { + return nil + } + + host := u.Host + h, port, splitErr := net.SplitHostPort(host) + if splitErr == nil { + // Validate port. + _, err = strconv.ParseUint(port, 10, 16) + if err != nil { + return fmt.Errorf("invalid port %s: %w", port, err) + } + + host = h + } + + _, err = netip.ParseAddr(host) + if err == nil { + return nil + } + + err = netutil.ValidateHostname(host) + if err != nil { + return fmt.Errorf("invalid address %s: %w", host, err) + } + + return nil +} + // urlToUpstream converts uu to an Upstream using opts. func urlToUpstream(uu *url.URL, opts *Options) (u Upstream, err error) { switch sch := uu.Scheme; sch { diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index 0ad53b660..cabe37c9e 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -228,6 +228,14 @@ func TestAddressToUpstream(t *testing.T) { addr: "1.1.1.1", opt: nil, want: "1.1.1.1:53", + }, { + addr: "1.1.1.1:5353", + opt: nil, + want: "1.1.1.1:5353", + }, { + addr: "one:5353", + opt: nil, + want: "one:5353", }, { addr: "one.one.one.one", opt: nil, @@ -274,15 +282,48 @@ func TestAddressToUpstream_bads(t *testing.T) { wantErrMsg: "unsupported url scheme: asdf", }, { addr: "12345.1.1.1:1234567", - wantErrMsg: `invalid address 12345.1.1.1:1234567: ` + - `strconv.ParseUint: parsing "1234567": value out of range`, + wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` + + `value out of range`, }, { addr: ":1234567", - wantErrMsg: `invalid address :1234567: ` + - `strconv.ParseUint: parsing "1234567": value out of range`, + wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` + + `value out of range`, }, { addr: "host:", - wantErrMsg: `invalid address host:: strconv.ParseUint: parsing "": invalid syntax`, + wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`, + }, { + addr: ":53", + wantErrMsg: `invalid address : bad hostname "": hostname is empty`, + }, { + addr: "!!!", + wantErrMsg: `invalid address !!!: bad hostname "!!!": bad top-level domain name ` + + `label "!!!": bad top-level domain name label rune '!'`, + }, { + addr: "123", + wantErrMsg: `invalid address 123: bad hostname "123": bad top-level domain name ` + + `label "123": all octets are numeric`, + }, { + addr: "tcp://12345.1.1.1:1234567", + wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` + + `value out of range`, + }, { + addr: "tcp://:1234567", + wantErrMsg: `invalid port 1234567: strconv.ParseUint: parsing "1234567": ` + + `value out of range`, + }, { + addr: "tcp://host:", + wantErrMsg: `invalid port : strconv.ParseUint: parsing "": invalid syntax`, + }, { + addr: "tcp://:53", + wantErrMsg: `invalid address : bad hostname "": hostname is empty`, + }, { + addr: "tcp://!!!", + wantErrMsg: `invalid address !!!: bad hostname "!!!": bad top-level domain name ` + + `label "!!!": bad top-level domain name label rune '!'`, + }, { + addr: "tcp://123", + wantErrMsg: `invalid address 123: bad hostname "123": bad top-level domain name ` + + `label "123": all octets are numeric`, }} for _, tc := range testCases {