protofilters/matcher/matcher.go

269 lines
6.6 KiB
Go

package matcher
import (
"errors"
"fmt"
"regexp"
"strings"
"sync"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
pf "go.linka.cloud/protofilters"
)
// Matcher evaluates field filters against a protobuf message.
type Matcher interface {
// Match evaluates the filter set f against m. The implementation in this
// file returns true only when every filter in f matches, and treats a nil
// f as a match.
Match(m proto.Message, f *pf.FieldsFilter) (bool, error)
// MatchFilters is the variadic form of Match: the implementation in this
// file combines fs with pf.New and evaluates the result against m.
MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error)
}
// CachingMatcher is a Matcher that keeps an internal cache which callers can
// discard on demand.
type CachingMatcher interface {
Matcher
// ResetCache drops everything the matcher has cached so far.
ResetCache()
}
// defaultMatcher is the shared matcher used by the package-level Match and
// MatchFilters functions. It starts with an empty field descriptor cache.
var defaultMatcher CachingMatcher = &matcher{cache: make(map[string]pref.FieldDescriptor)}
// Match evaluates the filter set f against m using the package-level default
// matcher and returns its result unchanged.
func Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
return defaultMatcher.Match(m, f)
}
// MatchFilters evaluates the given field filters against m using the
// package-level default matcher and returns its result unchanged.
func MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
return defaultMatcher.MatchFilters(m, fs...)
}
// matcher is the Matcher implementation behind the package-level helpers. It
// remembers the field descriptors it has already resolved.
type matcher struct {
// mu is held around the cache accesses in ResetCache and lookup.
mu sync.RWMutex
// cache maps "<message full name>.<field path>" to the descriptor that
// lookup resolved for that path.
cache map[string]pref.FieldDescriptor
}
// Match reports whether m satisfies every filter in f.
//
// A nil message is an error and a nil filter set matches everything. Each key
// of f.Filters is a dot-separated field path. Evaluation stops at the first
// path that cannot be resolved, whose filter returns an error, or whose filter
// does not match. Map iteration order is unspecified, so when several filters
// would fail, which failure is reported is unspecified too.
func (x *matcher) Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
	if m == nil {
		return false, errors.New("message is null")
	}
	if f == nil {
		return true, nil
	}
	for path, filter := range f.Filters {
		// Resolve the full path first: this validates every segment and
		// reports unknown fields against the root message and whole path.
		fd, err := x.lookup(m, path)
		if err != nil {
			return false, err
		}
		// The leaf descriptor of a nested path belongs to a sub-message, and
		// protoreflect panics when a descriptor is used on a message it does
		// not belong to. Descend to the message that actually owns the leaf.
		// The lookup above guarantees every intermediate segment is a
		// singular message (or group) field, so Message() is safe here.
		owner, rest := m, path
		for {
			dot := strings.IndexByte(rest, '.')
			if dot < 0 {
				break
			}
			step, err := x.lookup(owner, rest[:dot])
			if err != nil {
				return false, err
			}
			owner = owner.ProtoReflect().Get(step).Message().Interface()
			rest = rest[dot+1:]
		}
		ok, err := match(owner, fd, filter)
		if err != nil {
			return false, err
		}
		if !ok {
			return false, nil
		}
	}
	return true, nil
}
// MatchFilters bundles fs into a single filter set with pf.New and evaluates
// it against m through Match.
func (x *matcher) MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
	return x.Match(m, pf.New(fs...))
}
// ResetCache replaces the descriptor cache with a fresh, empty map while
// holding the write lock.
func (x *matcher) ResetCache() {
	x.mu.Lock()
	defer x.mu.Unlock()
	x.cache = map[string]pref.FieldDescriptor{}
}
// lookup resolves the dot-separated field path against the descriptor of m
// and returns the descriptor of the last field in the path.
//
// Results are cached under "<message full name>.<path>". Failed lookups are
// not cached and yield an error naming the message and the path.
func (x *matcher) lookup(m proto.Message, path string) (pref.FieldDescriptor, error) {
	root := m.ProtoReflect().Descriptor()
	key := string(root.FullName()) + "." + path

	// Reading a nil map is safe, so the fast path only needs the read lock.
	x.mu.RLock()
	fd, ok := x.cache[key]
	x.mu.RUnlock()
	if ok {
		return fd, nil
	}

	// md is the message descriptor the next path segment is searched in; it
	// becomes nil once the path can no longer descend.
	md := root
	fd, ok = rangeFields(path, func(field string) (pref.FieldDescriptor, bool) {
		// Search the field within the message.
		if md == nil {
			return nil, false // not within a message
		}
		found := md.Fields().ByName(pref.Name(field))
		// The real field name of a group is the message name.
		if found == nil {
			gd := md.Fields().ByName(pref.Name(strings.ToLower(field)))
			if gd != nil && gd.Kind() == pref.GroupKind && string(gd.Message().Name()) == field {
				found = gd
			}
		} else if found.Kind() == pref.GroupKind && string(found.Message().Name()) != field {
			found = nil
		}
		if found == nil {
			return nil, false // message does not have this field
		}
		// Identify the next message to search within.
		md = found.Message() // may be nil
		// Repeated fields are only allowed at the last position.
		if found.IsList() || found.IsMap() {
			md = nil
		}
		return found, true
	})
	if !ok {
		return nil, fmt.Errorf("%s does not contain '%s'", root.FullName(), path)
	}

	// Initialise the cache lazily under the write lock, so the nil check
	// cannot race with ResetCache or with another goroutine's lookup.
	x.mu.Lock()
	if x.cache == nil {
		x.cache = make(map[string]pref.FieldDescriptor)
	}
	x.cache[key] = fd
	x.mu.Unlock()
	return fd, nil
}
// match dispatches f to the evaluator for its filter variant (string, number,
// bool or null). A filter with no recognised variant set never matches.
func match(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
	variant := f.GetMatch()
	switch variant.(type) {
	case *pf.Filter_Null:
		return matchNull(m, fd, f)
	case *pf.Filter_Bool:
		return matchBool(m, fd, f)
	case *pf.Filter_Number:
		return matchNumber(m, fd, f)
	case *pf.Filter_String_:
		return matchString(m, fd, f)
	default:
		return false, nil
	}
}
// matchNull evaluates a null filter against the field fd of m.
//
// Only message and group fields are accepted; any other kind is an error. A
// field counts as null when the message does not report it as populated. The
// result is inverted when the filter's Not flag is set.
func matchNull(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
	switch fd.Kind() {
	case pref.MessageKind, pref.GroupKind:
		// Both kinds are handled below.
	default:
		return false, fmt.Errorf("cannot use null filter on %s", fd.Kind().String())
	}
	// Has works for singular and repeated fields alike (a list is "had" only
	// when it is non-empty). Calling List() on a singular group value, as the
	// group branch used to do, panics because the value holds a Message.
	isNull := !m.ProtoReflect().Has(fd)
	if f.GetNull().GetNot() {
		return !isNull, nil
	}
	return isNull, nil
}
// matchBool evaluates a bool filter against the field fd of m. The field must
// be of bool kind; the result is whether its value equals the filter's Equals.
func matchBool(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
	if kind := fd.Kind(); kind != pref.BoolKind {
		return false, fmt.Errorf("cannot use bool filter on %s", kind.String())
	}
	got := m.ProtoReflect().Get(fd).Bool()
	want := f.GetBool().GetEquals()
	return got == want, nil
}
// matchString evaluates a string filter against the field fd of m.
//
// String and enum fields are accepted; enum values are compared by their
// declared name, and an enum number with no declared value never matches.
// The supported conditions are equality, regular expression and membership
// in a list. The case-insensitive flag applies to equality and membership.
// The result is inverted when the filter's Not flag is set.
func matchString(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
	kind := fd.Kind()
	isEnum := kind == pref.EnumKind
	if !isEnum && kind != pref.StringKind {
		return false, fmt.Errorf("cannot use string filter on %s", kind.String())
	}

	sf := f.GetString_()
	foldCase := sf.GetCaseInsensitive()

	raw := m.ProtoReflect().Get(fd)
	var got string
	if isEnum {
		ev := fd.Enum().Values().ByNumber(raw.Enum())
		if ev == nil {
			// Unknown enum number: reported as no match, regardless of Not.
			return false, nil
		}
		got = string(ev.Name())
	} else {
		got = raw.String()
	}

	matched := false
	switch sf.GetCondition().(type) {
	case *pf.StringFilter_Equals:
		want := sf.GetEquals()
		if foldCase {
			matched = strings.ToLower(want) == strings.ToLower(got)
		} else {
			matched = got == want
		}
	case *pf.StringFilter_Regex:
		re, err := regexp.Compile(sf.GetRegex())
		if err != nil {
			return false, err
		}
		matched = re.MatchString(got)
	case *pf.StringFilter_In_:
		for _, candidate := range sf.GetIn().GetValues() {
			if candidate == got || (foldCase && strings.ToLower(candidate) == strings.ToLower(got)) {
				matched = true
				break
			}
		}
	}

	// Not flips the outcome; comparing the two booleans expresses that XOR.
	return matched != sf.GetNot(), nil
}
// matchNumber evaluates a number filter against the field fd of m.
//
// Integer, floating point and enum fields are accepted (enums compare by
// number); any other kind is an error. The field value is converted to
// float64 before comparison. The supported conditions are equality and
// membership in a list. The result is inverted when the filter's Not flag
// is set.
func matchNumber(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
	rval := m.ProtoReflect().Get(fd)
	var val float64
	switch fd.Kind() {
	case pref.Int32Kind,
		pref.Sint32Kind,
		pref.Sfixed32Kind,
		pref.Int64Kind,
		pref.Sint64Kind,
		pref.Sfixed64Kind:
		val = float64(rval.Int())
	// fixed32/fixed64 are unsigned on the Go side, so they must be read with
	// Uint(); Int() panics on a value holding a uint32/uint64.
	case pref.Uint32Kind,
		pref.Fixed32Kind,
		pref.Uint64Kind,
		pref.Fixed64Kind:
		val = float64(rval.Uint())
	case pref.FloatKind, pref.DoubleKind:
		val = rval.Float()
	case pref.EnumKind:
		val = float64(rval.Enum())
	default:
		return false, fmt.Errorf("cannot use number filter on %s", fd.Kind().String())
	}

	var match bool
	switch f.GetNumber().GetCondition().(type) {
	case *pf.NumberFilter_Equals:
		match = val == f.GetNumber().GetEquals()
	case *pf.NumberFilter_In_:
		for _, v := range f.GetNumber().GetIn().GetValues() {
			if val == v {
				match = true
				break
			}
		}
	}
	if f.GetNumber().GetNot() {
		return !match, nil
	}
	return match, nil
}
// rangeFields walks the dot-separated segments of path in order, like
// strings.Split(path, ".") would produce them, but without building a slice:
// each segment is handed to the iterator function f in place.
//
// It stops and reports failure as soon as f rejects a segment; otherwise it
// returns whatever f produced for the final segment. An empty segment (from a
// leading, trailing or doubled dot) is passed to f like any other.
func rangeFields(path string, f func(field string) (pref.FieldDescriptor, bool)) (pref.FieldDescriptor, bool) {
	rest := path
	for {
		// Without a further dot, the remainder is the final segment.
		name, last := rest, true
		if dot := strings.IndexByte(rest, '.'); dot >= 0 {
			name, rest, last = rest[:dot], rest[dot+1:], false
		}
		fd, ok := f(name)
		if !ok {
			return nil, false
		}
		if last {
			return fd, true
		}
	}
}