269 lines
6.6 KiB
Go
269 lines
6.6 KiB
Go
package matcher
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
pref "google.golang.org/protobuf/reflect/protoreflect"
|
|
|
|
pf "go.linka.cloud/protofilters"
|
|
)
|
|
|
|
type Matcher interface {
|
|
Match(m proto.Message, f *pf.FieldsFilter) (bool, error)
|
|
MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error)
|
|
}
|
|
|
|
type CachingMatcher interface {
|
|
Matcher
|
|
ResetCache()
|
|
}
|
|
|
|
// defaultMatcher is the shared matcher behind the package-level Match and
// MatchFilters functions.
var defaultMatcher CachingMatcher = &matcher{
	cache: make(map[string]pref.FieldDescriptor),
}
|
|
|
|
func Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
|
|
return defaultMatcher.Match(m, f)
|
|
}
|
|
|
|
func MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
|
|
return defaultMatcher.MatchFilters(m, fs...)
|
|
}
|
|
|
|
type matcher struct {
|
|
mu sync.RWMutex
|
|
cache map[string]pref.FieldDescriptor
|
|
}
|
|
|
|
func (x *matcher) Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
|
|
if m == nil {
|
|
return false, errors.New("message is null")
|
|
}
|
|
if f == nil {
|
|
return true, nil
|
|
}
|
|
for path, filter := range f.Filters {
|
|
fd, err := x.lookup(m, path)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
ok, err := match(m, fd, filter)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if !ok {
|
|
return false, nil
|
|
}
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (x *matcher) MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
|
|
f := pf.New(fs...)
|
|
return x.Match(m, f)
|
|
}
|
|
|
|
func (x *matcher) ResetCache() {
|
|
x.mu.Lock()
|
|
x.cache = make(map[string]pref.FieldDescriptor)
|
|
x.mu.Unlock()
|
|
}
|
|
|
|
func (x *matcher) lookup(m proto.Message, path string) (pref.FieldDescriptor, error) {
|
|
if x.cache == nil {
|
|
x.mu.Lock()
|
|
x.cache = make(map[string]pref.FieldDescriptor)
|
|
x.mu.Unlock()
|
|
}
|
|
key := fmt.Sprintf("%s.%s", m.ProtoReflect().Descriptor().FullName(), path)
|
|
x.mu.RLock()
|
|
fd, ok := x.cache[key]
|
|
x.mu.RUnlock()
|
|
if ok {
|
|
return fd, nil
|
|
}
|
|
md0 := m.ProtoReflect().Descriptor()
|
|
md := md0
|
|
fd, ok = rangeFields(path, func(field string) (pref.FieldDescriptor, bool) {
|
|
// Search the field within the message.
|
|
if md == nil {
|
|
return nil, false // not within a message
|
|
}
|
|
fd := md.Fields().ByName(pref.Name(field))
|
|
// The real field name of a group is the message name.
|
|
if fd == nil {
|
|
gd := md.Fields().ByName(pref.Name(strings.ToLower(field)))
|
|
if gd != nil && gd.Kind() == pref.GroupKind && string(gd.Message().Name()) == field {
|
|
fd = gd
|
|
}
|
|
} else if fd.Kind() == pref.GroupKind && string(fd.Message().Name()) != field {
|
|
fd = nil
|
|
}
|
|
if fd == nil {
|
|
return nil, false // message does not have this field
|
|
}
|
|
// Identify the next message to search within.
|
|
md = fd.Message() // may be nil
|
|
|
|
// Repeated fields are only allowed at the last postion.
|
|
if fd.IsList() || fd.IsMap() {
|
|
md = nil
|
|
}
|
|
|
|
return fd, true
|
|
})
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s does not contain '%s'", md0.FullName(), path)
|
|
}
|
|
x.mu.Lock()
|
|
x.cache[key] = fd
|
|
x.mu.Unlock()
|
|
return fd, nil
|
|
}
|
|
|
|
func match(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
|
switch f.GetMatch().(type) {
|
|
case *pf.Filter_String_:
|
|
return matchString(m, fd, f)
|
|
case *pf.Filter_Number:
|
|
return matchNumber(m, fd, f)
|
|
case *pf.Filter_Bool:
|
|
return matchBool(m, fd, f)
|
|
case *pf.Filter_Null:
|
|
return matchNull(m, fd, f)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func matchNull(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
|
var match bool
|
|
switch fd.Kind() {
|
|
case pref.MessageKind:
|
|
match = !m.ProtoReflect().Has(fd)
|
|
case pref.GroupKind:
|
|
match = m.ProtoReflect().Get(fd).List().Len() == 0
|
|
default:
|
|
return false, fmt.Errorf("cannot use null filter on %s", fd.Kind().String())
|
|
}
|
|
if f.GetNull().GetNot() {
|
|
return !match, nil
|
|
}
|
|
return match, nil
|
|
}
|
|
|
|
func matchBool(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
|
if fd.Kind() != pref.BoolKind {
|
|
return false, fmt.Errorf("cannot use bool filter on %s", fd.Kind().String())
|
|
}
|
|
return m.ProtoReflect().Get(fd).Bool() == f.GetBool().GetEquals(), nil
|
|
}
|
|
|
|
func matchString(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
|
if fd.Kind() != pref.StringKind && fd.Kind() != pref.EnumKind {
|
|
return false, fmt.Errorf("cannot use string filter on %s", fd.Kind().String())
|
|
}
|
|
insensitive := f.GetString_().GetCaseInsensitive()
|
|
rval := m.ProtoReflect().Get(fd)
|
|
value := rval.String()
|
|
if fd.Kind() == pref.EnumKind {
|
|
e := fd.Enum().Values().ByNumber(rval.Enum())
|
|
if e == nil {
|
|
return false, nil
|
|
}
|
|
value = string(e.Name())
|
|
}
|
|
var match bool
|
|
switch f.GetString_().GetCondition().(type) {
|
|
case *pf.StringFilter_Equals:
|
|
if insensitive {
|
|
match = strings.ToLower(f.GetString_().GetEquals()) == strings.ToLower(value)
|
|
} else {
|
|
match = value == f.GetString_().GetEquals()
|
|
}
|
|
case *pf.StringFilter_Regex:
|
|
reg, err := regexp.Compile(f.GetString_().GetRegex())
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
match = reg.MatchString(value)
|
|
case *pf.StringFilter_In_:
|
|
lookup:
|
|
for _, v := range f.GetString_().GetIn().GetValues() {
|
|
if (insensitive && strings.ToLower(v) == strings.ToLower(value)) || v == value {
|
|
match = true
|
|
break lookup
|
|
}
|
|
}
|
|
}
|
|
if f.GetString_().GetNot() {
|
|
return !match, nil
|
|
}
|
|
return match, nil
|
|
}
|
|
|
|
func matchNumber(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
|
rval := m.ProtoReflect().Get(fd)
|
|
var val float64
|
|
switch fd.Kind() {
|
|
case pref.Int32Kind,
|
|
pref.Sint32Kind,
|
|
pref.Int64Kind,
|
|
pref.Sint64Kind,
|
|
pref.Sfixed32Kind,
|
|
pref.Fixed32Kind,
|
|
pref.Sfixed64Kind,
|
|
pref.Fixed64Kind:
|
|
val = float64(rval.Int())
|
|
case pref.Uint32Kind, pref.Uint64Kind:
|
|
val = float64(rval.Uint())
|
|
case pref.FloatKind, pref.DoubleKind:
|
|
val = rval.Float()
|
|
case pref.EnumKind:
|
|
val = float64(rval.Enum())
|
|
default:
|
|
return false, fmt.Errorf("cannot use number filter on %s", fd.Kind().String())
|
|
}
|
|
var match bool
|
|
switch f.GetNumber().GetCondition().(type) {
|
|
case *pf.NumberFilter_Equals:
|
|
match = val == f.GetNumber().GetEquals()
|
|
case *pf.NumberFilter_In_:
|
|
lookup:
|
|
for _, v := range f.GetNumber().GetIn().GetValues() {
|
|
if val == v {
|
|
match = true
|
|
break lookup
|
|
}
|
|
}
|
|
}
|
|
if f.GetNumber().GetNot() {
|
|
return !match, nil
|
|
}
|
|
return match, nil
|
|
}
|
|
|
|
// rangeFields is like strings.Split(path, "."), but avoids allocations by
|
|
// iterating over each field in place and calling a iterator function.
|
|
func rangeFields(path string, f func(field string) (pref.FieldDescriptor, bool)) (pref.FieldDescriptor, bool) {
|
|
for {
|
|
var field string
|
|
if i := strings.IndexByte(path, '.'); i >= 0 {
|
|
field, path = path[:i], path[i:]
|
|
} else {
|
|
field, path = path, ""
|
|
}
|
|
v, ok := f(field)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if len(path) == 0 {
|
|
return v, true
|
|
}
|
|
path = strings.TrimPrefix(path, ".")
|
|
}
|
|
}
|
|
|