init
This commit is contained in:
268
matcher/matcher.go
Normal file
268
matcher/matcher.go
Normal file
@ -0,0 +1,268 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
pref "google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
pf "go.linka.cloud/protofilters"
|
||||
)
|
||||
|
||||
type Matcher interface {
|
||||
Match(m proto.Message, f *pf.FieldsFilter) (bool, error)
|
||||
MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error)
|
||||
}
|
||||
|
||||
type CachingMatcher interface {
|
||||
Matcher
|
||||
ResetCache()
|
||||
}
|
||||
|
||||
var defaultMatcher CachingMatcher = &matcher{cache: make(map[string]pref.FieldDescriptor)}
|
||||
|
||||
func Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
|
||||
return defaultMatcher.Match(m, f)
|
||||
}
|
||||
|
||||
func MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
|
||||
return defaultMatcher.MatchFilters(m, fs...)
|
||||
}
|
||||
|
||||
type matcher struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]pref.FieldDescriptor
|
||||
}
|
||||
|
||||
func (x *matcher) Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
|
||||
if m == nil {
|
||||
return false, errors.New("message is null")
|
||||
}
|
||||
if f == nil {
|
||||
return true, nil
|
||||
}
|
||||
for path, filter := range f.Filters {
|
||||
fd, err := x.lookup(m, path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ok, err := match(m, fd, filter)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (x *matcher) MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
|
||||
f := pf.New(fs...)
|
||||
return x.Match(m, f)
|
||||
}
|
||||
|
||||
func (x *matcher) ResetCache() {
|
||||
x.mu.Lock()
|
||||
x.cache = make(map[string]pref.FieldDescriptor)
|
||||
x.mu.Unlock()
|
||||
}
|
||||
|
||||
func (x *matcher) lookup(m proto.Message, path string) (pref.FieldDescriptor, error) {
|
||||
if x.cache == nil {
|
||||
x.mu.Lock()
|
||||
x.cache = make(map[string]pref.FieldDescriptor)
|
||||
x.mu.Unlock()
|
||||
}
|
||||
key := fmt.Sprintf("%s.%s", m.ProtoReflect().Descriptor().FullName(), path)
|
||||
x.mu.RLock()
|
||||
fd, ok := x.cache[key]
|
||||
x.mu.RUnlock()
|
||||
if ok {
|
||||
return fd, nil
|
||||
}
|
||||
md0 := m.ProtoReflect().Descriptor()
|
||||
md := md0
|
||||
fd, ok = rangeFields(path, func(field string) (pref.FieldDescriptor, bool) {
|
||||
// Search the field within the message.
|
||||
if md == nil {
|
||||
return nil, false // not within a message
|
||||
}
|
||||
fd := md.Fields().ByName(pref.Name(field))
|
||||
// The real field name of a group is the message name.
|
||||
if fd == nil {
|
||||
gd := md.Fields().ByName(pref.Name(strings.ToLower(field)))
|
||||
if gd != nil && gd.Kind() == pref.GroupKind && string(gd.Message().Name()) == field {
|
||||
fd = gd
|
||||
}
|
||||
} else if fd.Kind() == pref.GroupKind && string(fd.Message().Name()) != field {
|
||||
fd = nil
|
||||
}
|
||||
if fd == nil {
|
||||
return nil, false // message does not have this field
|
||||
}
|
||||
// Identify the next message to search within.
|
||||
md = fd.Message() // may be nil
|
||||
|
||||
// Repeated fields are only allowed at the last postion.
|
||||
if fd.IsList() || fd.IsMap() {
|
||||
md = nil
|
||||
}
|
||||
|
||||
return fd, true
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s does not contain '%s'", md0.FullName(), path)
|
||||
}
|
||||
x.mu.Lock()
|
||||
x.cache[key] = fd
|
||||
x.mu.Unlock()
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func match(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
||||
switch f.GetMatch().(type) {
|
||||
case *pf.Filter_String_:
|
||||
return matchString(m, fd, f)
|
||||
case *pf.Filter_Number:
|
||||
return matchNumber(m, fd, f)
|
||||
case *pf.Filter_Bool:
|
||||
return matchBool(m, fd, f)
|
||||
case *pf.Filter_Null:
|
||||
return matchNull(m, fd, f)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func matchNull(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
||||
var match bool
|
||||
switch fd.Kind() {
|
||||
case pref.MessageKind:
|
||||
match = !m.ProtoReflect().Has(fd)
|
||||
case pref.GroupKind:
|
||||
match = m.ProtoReflect().Get(fd).List().Len() == 0
|
||||
default:
|
||||
return false, fmt.Errorf("cannot use null filter on %s", fd.Kind().String())
|
||||
}
|
||||
if f.GetNull().GetNot() {
|
||||
return !match, nil
|
||||
}
|
||||
return match, nil
|
||||
}
|
||||
|
||||
func matchBool(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
||||
if fd.Kind() != pref.BoolKind {
|
||||
return false, fmt.Errorf("cannot use bool filter on %s", fd.Kind().String())
|
||||
}
|
||||
return m.ProtoReflect().Get(fd).Bool() == f.GetBool().GetEquals(), nil
|
||||
}
|
||||
|
||||
func matchString(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
||||
if fd.Kind() != pref.StringKind && fd.Kind() != pref.EnumKind {
|
||||
return false, fmt.Errorf("cannot use string filter on %s", fd.Kind().String())
|
||||
}
|
||||
insensitive := f.GetString_().GetCaseInsensitive()
|
||||
rval := m.ProtoReflect().Get(fd)
|
||||
value := rval.String()
|
||||
if fd.Kind() == pref.EnumKind {
|
||||
e := fd.Enum().Values().ByNumber(rval.Enum())
|
||||
if e == nil {
|
||||
return false, nil
|
||||
}
|
||||
value = string(e.Name())
|
||||
}
|
||||
var match bool
|
||||
switch f.GetString_().GetCondition().(type) {
|
||||
case *pf.StringFilter_Equals:
|
||||
if insensitive {
|
||||
match = strings.ToLower(f.GetString_().GetEquals()) == strings.ToLower(value)
|
||||
} else {
|
||||
match = value == f.GetString_().GetEquals()
|
||||
}
|
||||
case *pf.StringFilter_Regex:
|
||||
reg, err := regexp.Compile(f.GetString_().GetRegex())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
match = reg.MatchString(value)
|
||||
case *pf.StringFilter_In_:
|
||||
lookup:
|
||||
for _, v := range f.GetString_().GetIn().GetValues() {
|
||||
if (insensitive && strings.ToLower(v) == strings.ToLower(value)) || v == value {
|
||||
match = true
|
||||
break lookup
|
||||
}
|
||||
}
|
||||
}
|
||||
if f.GetString_().GetNot() {
|
||||
return !match, nil
|
||||
}
|
||||
return match, nil
|
||||
}
|
||||
|
||||
func matchNumber(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
|
||||
rval := m.ProtoReflect().Get(fd)
|
||||
var val float64
|
||||
switch fd.Kind() {
|
||||
case pref.Int32Kind,
|
||||
pref.Sint32Kind,
|
||||
pref.Int64Kind,
|
||||
pref.Sint64Kind,
|
||||
pref.Sfixed32Kind,
|
||||
pref.Fixed32Kind,
|
||||
pref.Sfixed64Kind,
|
||||
pref.Fixed64Kind:
|
||||
val = float64(rval.Int())
|
||||
case pref.Uint32Kind, pref.Uint64Kind:
|
||||
val = float64(rval.Uint())
|
||||
case pref.FloatKind, pref.DoubleKind:
|
||||
val = rval.Float()
|
||||
case pref.EnumKind:
|
||||
val = float64(rval.Enum())
|
||||
default:
|
||||
return false, fmt.Errorf("cannot use number filter on %s", fd.Kind().String())
|
||||
}
|
||||
var match bool
|
||||
switch f.GetNumber().GetCondition().(type) {
|
||||
case *pf.NumberFilter_Equals:
|
||||
match = val == f.GetNumber().GetEquals()
|
||||
case *pf.NumberFilter_In_:
|
||||
lookup:
|
||||
for _, v := range f.GetNumber().GetIn().GetValues() {
|
||||
if val == v {
|
||||
match = true
|
||||
break lookup
|
||||
}
|
||||
}
|
||||
}
|
||||
if f.GetNumber().GetNot() {
|
||||
return !match, nil
|
||||
}
|
||||
return match, nil
|
||||
}
|
||||
|
||||
// rangeFields is like strings.Split(path, "."), but avoids allocations by
|
||||
// iterating over each field in place and calling a iterator function.
|
||||
func rangeFields(path string, f func(field string) (pref.FieldDescriptor, bool)) (pref.FieldDescriptor, bool) {
|
||||
for {
|
||||
var field string
|
||||
if i := strings.IndexByte(path, '.'); i >= 0 {
|
||||
field, path = path[:i], path[i:]
|
||||
} else {
|
||||
field, path = path, ""
|
||||
}
|
||||
v, ok := f(field)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if len(path) == 0 {
|
||||
return v, true
|
||||
}
|
||||
path = strings.TrimPrefix(path, ".")
|
||||
}
|
||||
}
|
||||
|
85
matcher/matcher_test.go
Normal file
85
matcher/matcher_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
pf "go.linka.cloud/protofilters"
|
||||
test "go.linka.cloud/protofilters/tests/pb"
|
||||
)
|
||||
|
||||
func TestFieldFilter(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
m := &test.Test{StringField: "ok"}
|
||||
ok, err := Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"noop": pf.StringEquals("ok"),
|
||||
}})
|
||||
assert.Error(err)
|
||||
assert.False(ok)
|
||||
ok, err = Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"messageField": pf.Null(),
|
||||
}})
|
||||
assert.Error(err)
|
||||
assert.False(ok)
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"string_field": pf.StringEquals("ok"),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"string_field": pf.StringIN("other", "ok"),
|
||||
}}))
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"string_field": pf.StringIN("other", "noop"),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"string_field": pf.StringNotIN(),
|
||||
"enum_field": pf.StringIN("NONE"),
|
||||
}}))
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"string_field": pf.StringNotRegex(`[a-z](.+)`),
|
||||
}}))
|
||||
m.EnumField = test.Test_Type(42)
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"enum_field": pf.StringIN("OTHER"),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"message_field": pf.Null(),
|
||||
}}))
|
||||
m.MessageField = m
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"message_field": pf.Null(),
|
||||
}}))
|
||||
ok, err = Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"message_field.string_field.message_field": pf.Null(),
|
||||
}})
|
||||
assert.Error(err)
|
||||
assert.False(ok)
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"message_field.message_field.message_field": pf.Null(),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"message_field.message_field.message_field.string_field": pf.StringIN("ok"),
|
||||
}}))
|
||||
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"enum_field": pf.StringIN("OTHER"),
|
||||
}}))
|
||||
|
||||
m.NumberField = 42
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"number_field": pf.NumberEquals(0),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"number_field": pf.NumberEquals(42),
|
||||
}}))
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"number_field": pf.NumberIN(0, 22),
|
||||
}}))
|
||||
assert.True(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"enum_field": pf.NumberIN(0, 42),
|
||||
}}))
|
||||
assert.False(Match(m, &pf.FieldsFilter{Filters: map[string]*pf.Filter{
|
||||
"enum_field": pf.NumberNotIN(0, 42),
|
||||
}}))
|
||||
|
||||
}
|
Reference in New Issue
Block a user