diff --git a/split.go b/split.go index d25942c..4b33a0e 100644 --- a/split.go +++ b/split.go @@ -10,6 +10,7 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/metrics" clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/upstream" "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -18,6 +19,18 @@ import ( // friends to log. var log = clog.NewWithPlugin("split") +const noFallback = "split-no-fallback" + +func isNoFallback(ctx context.Context) bool { + if ctx == nil { + return false + } + if v, ok := ctx.Value(noFallback).(bool); ok { + return v + } + return false +} + // Split is an example plugin to show how to write a plugin. type Split struct { Next plugin.Handler @@ -48,7 +61,7 @@ func (s Split) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( log.Debug("Received response") // Wrap. - pw := s.NewResponsePrinter(w, r) + pw := s.NewResponsePrinter(ctx, w, r) // Export metric with the server label set to the current server handling the request. requestCount.WithLabelValues(metrics.WithServer(ctx)).Inc() @@ -63,6 +76,7 @@ func (s Split) Name() string { return "split" } // ResponsePrinter wrap a dns.ResponseWriter and will write example to standard output when WriteMsg is called. type ResponsePrinter struct { dns.ResponseWriter + ctx context.Context state request.Request r *dns.Msg src net.IP @@ -70,30 +84,22 @@ type ResponsePrinter struct { } // NewResponsePrinter returns ResponseWriter. -func (s Split) NewResponsePrinter(w dns.ResponseWriter, r *dns.Msg) *ResponsePrinter { +func (s Split) NewResponsePrinter(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) *ResponsePrinter { state := request.Request{W: w, Req: r} ip := net.ParseIP(state.IP()) - return &ResponsePrinter{ResponseWriter: w, r: r, src: ip, rules: s.Rules, state: state} + return &ResponsePrinter{ctx: ctx, ResponseWriter: w, r: r, src: ip, rules: s.Rules, state: state} } // WriteMsg calls the underlying ResponseWriter's WriteMsg method and prints "example" to standard output. func (r *ResponsePrinter) WriteMsg(res *dns.Msg) error { - var rule Rule - for _, v := range r.rules { - zone := plugin.Zones(v.Zones).Matches(r.state.Name()) - if zone == "" { - continue - } - rule = v - break - } - var answers []dns.RR - var netAnswers []dns.RR - for _, v := range res.Answer { - rec, ok := v.(*dns.A) - if !ok { - answers = append(answers, v) - continue + filter := func(rec *dns.A) (rule Rule, allowed, match bool) { + for _, v := range r.rules { + zone := plugin.Zones(v.Zones).Matches(r.state.Name()) + if zone == "" { + continue + } + rule = v + break } var net *Network for _, vv := range rule.Networks { @@ -103,34 +109,91 @@ func (r *ResponsePrinter) WriteMsg(res *dns.Msg) error { } } if net == nil { - answers = append(answers, v) - continue + return rule, true, false } - allowed := false + for _, vv := range net.Allowed { if vv.Contains(r.src) { - allowed = true - break + return rule, true, true } } - if allowed { + return rule, false, true + } + var ( + rule Rule + answers []dns.RR + netAnswers []dns.RR + ) + + for _, v := range res.Answer { + switch rec := v.(type) { + case *dns.A: + var allowed, match bool + rule, allowed, match = filter(rec) + if !match { + answers = append(answers, v) + continue + } + if allowed { + answers = append(answers, v) + netAnswers = append(netAnswers, v) + continue + } + log.Infof("request source %s: %s: filtering %s", r.src.String(), rec.Hdr.Name, rec.A) + case *dns.CNAME: + res, err := r.query(rec.Target) + if err != nil { + log.Errorf("error querying %s: %s", rec.Target, err) + continue + } + if res == nil || len(res.Answer) == 0 { + log.Debugf("no answers for %s", rec.Target) + continue + } answers = append(answers, v) - netAnswers = append(netAnswers, v) - continue + case *dns.SRV: + res, err := r.query(rec.Target) + if err != nil { + log.Errorf("error querying %s: %s", rec.Target, err) + continue + } + if res == nil || len(res.Answer) == 0 { + log.Debugf("no answers for %s", rec.Target) + continue + } + answers = append(answers, v) + case *dns.PTR: + a, err := r.query(rec.Ptr) + if err != nil { + log.Errorf("error querying %s: %s", rec.Ptr, err) + continue + } + if res == nil || len(a.Answer) == 0 { + log.Debugf("no answer for %s", rec.Ptr) + continue + } + answers = append(answers, v) + default: + return r.ResponseWriter.WriteMsg(res) } - log.Infof("request source %s: %s: filtering %s", r.src.String(), rec.Hdr.Name, rec.A) } if len(netAnswers) != 0 { res.Answer = netAnswers } else { res.Answer = answers } - if len(res.Answer) != 0 { + if len(res.Answer) != 0 || len(rule.Zones) == 0 { + return r.ResponseWriter.WriteMsg(res) + } + if isNoFallback(r.ctx) { + log.Debugf("no fallback requested for %s", r.state.Name()) return r.ResponseWriter.WriteMsg(res) } if rule.Fallback == nil { - return nil + log.Debugf("no fallback configured for zones %v", rule.Zones) + return r.ResponseWriter.WriteMsg(res) } + log.Debugf("request source %s: %s: using fallback %s", r.src.String(), r.state.Name(), rule.Fallback) c := new(dns.Client) req := r.state.Req.Copy() req.Id = dns.Id() @@ -141,3 +204,13 @@ func (r *ResponsePrinter) WriteMsg(res *dns.Msg) error { res.Answer = append(res.Answer, in.Answer...) return r.ResponseWriter.WriteMsg(res) } + +func (r *ResponsePrinter) query(name string) (*dns.Msg, error) { + log.Debugf("internally querying %s", name) + ctx := context.WithValue(r.ctx, noFallback, true) + res, err := upstream.New().Lookup(ctx, r.state, name, dns.TypeA) + if err != nil { + return nil, err + } + return res, nil +}