269 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			269 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package matcher
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 
 | |
| 	"google.golang.org/protobuf/proto"
 | |
| 	pref "google.golang.org/protobuf/reflect/protoreflect"
 | |
| 
 | |
| 	pf "go.linka.cloud/protofilters"
 | |
| )
 | |
| 
 | |
| type Matcher interface {
 | |
| 	Match(m proto.Message, f *pf.FieldsFilter) (bool, error)
 | |
| 	MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error)
 | |
| }
 | |
| 
 | |
| type CachingMatcher interface {
 | |
| 	Matcher
 | |
| 	ResetCache()
 | |
| }
 | |
| 
 | |
| var defaultMatcher CachingMatcher = &matcher{cache: make(map[string]pref.FieldDescriptor)}
 | |
| 
 | |
| func Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
 | |
| 	return defaultMatcher.Match(m, f)
 | |
| }
 | |
| 
 | |
| func MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
 | |
| 	return defaultMatcher.MatchFilters(m, fs...)
 | |
| }
 | |
| 
 | |
| type matcher struct {
 | |
| 	mu sync.RWMutex
 | |
| 	cache map[string]pref.FieldDescriptor
 | |
| }
 | |
| 
 | |
| func (x *matcher) Match(m proto.Message, f *pf.FieldsFilter) (bool, error) {
 | |
| 	if m == nil {
 | |
| 		return false, errors.New("message is null")
 | |
| 	}
 | |
| 	if f == nil {
 | |
| 		return true, nil
 | |
| 	}
 | |
| 	for path, filter := range f.Filters {
 | |
| 		fd, err := x.lookup(m, path)
 | |
| 		if err != nil {
 | |
| 			return false, err
 | |
| 		}
 | |
| 		ok, err := match(m, fd, filter)
 | |
| 		if err != nil {
 | |
| 			return false, err
 | |
| 		}
 | |
| 		if !ok {
 | |
| 			return false, nil
 | |
| 		}
 | |
| 	}
 | |
| 	return true, nil
 | |
| }
 | |
| 
 | |
| func (x *matcher) MatchFilters(m proto.Message, fs ...*pf.FieldFilter) (bool, error) {
 | |
| 	f := pf.New(fs...)
 | |
| 	return x.Match(m, f)
 | |
| }
 | |
| 
 | |
| func (x *matcher) ResetCache() {
 | |
| 	x.mu.Lock()
 | |
| 	x.cache = make(map[string]pref.FieldDescriptor)
 | |
| 	x.mu.Unlock()
 | |
| }
 | |
| 
 | |
| func (x *matcher) lookup(m proto.Message, path string) (pref.FieldDescriptor, error) {
 | |
| 	if x.cache == nil {
 | |
| 		x.mu.Lock()
 | |
| 		x.cache = make(map[string]pref.FieldDescriptor)
 | |
| 		x.mu.Unlock()
 | |
| 	}
 | |
| 	key := fmt.Sprintf("%s.%s", m.ProtoReflect().Descriptor().FullName(), path)
 | |
| 	x.mu.RLock()
 | |
| 	fd, ok := x.cache[key]
 | |
| 	x.mu.RUnlock()
 | |
| 	if ok {
 | |
| 		return fd, nil
 | |
| 	}
 | |
| 	md0 := m.ProtoReflect().Descriptor()
 | |
| 	md := md0
 | |
| 	fd, ok = rangeFields(path, func(field string) (pref.FieldDescriptor, bool) {
 | |
| 		// Search the field within the message.
 | |
| 		if md == nil {
 | |
| 			return nil, false // not within a message
 | |
| 		}
 | |
| 		fd := md.Fields().ByName(pref.Name(field))
 | |
| 		// The real field name of a group is the message name.
 | |
| 		if fd == nil {
 | |
| 			gd := md.Fields().ByName(pref.Name(strings.ToLower(field)))
 | |
| 			if gd != nil && gd.Kind() == pref.GroupKind && string(gd.Message().Name()) == field {
 | |
| 				fd = gd
 | |
| 			}
 | |
| 		} else if fd.Kind() == pref.GroupKind && string(fd.Message().Name()) != field {
 | |
| 			fd = nil
 | |
| 		}
 | |
| 		if fd == nil {
 | |
| 			return nil, false // message does not have this field
 | |
| 		}
 | |
| 		// Identify the next message to search within.
 | |
| 		md = fd.Message() // may be nil
 | |
| 
 | |
| 		// Repeated fields are only allowed at the last postion.
 | |
| 		if fd.IsList() || fd.IsMap() {
 | |
| 			md = nil
 | |
| 		}
 | |
| 
 | |
| 		return fd, true
 | |
| 	})
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("%s does not contain '%s'", md0.FullName(), path)
 | |
| 	}
 | |
| 	x.mu.Lock()
 | |
| 	x.cache[key] = fd
 | |
| 	x.mu.Unlock()
 | |
| 	return fd, nil
 | |
| }
 | |
| 
 | |
| func match(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
 | |
| 	switch f.GetMatch().(type) {
 | |
| 	case *pf.Filter_String_:
 | |
| 		return matchString(m, fd, f)
 | |
| 	case *pf.Filter_Number:
 | |
| 		return matchNumber(m, fd, f)
 | |
| 	case *pf.Filter_Bool:
 | |
| 		return matchBool(m, fd, f)
 | |
| 	case *pf.Filter_Null:
 | |
| 		return matchNull(m, fd, f)
 | |
| 	}
 | |
| 	return false, nil
 | |
| }
 | |
| 
 | |
| func matchNull(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
 | |
| 	var match bool
 | |
| 	switch fd.Kind() {
 | |
| 	case pref.MessageKind:
 | |
| 		match = !m.ProtoReflect().Has(fd)
 | |
| 	case pref.GroupKind:
 | |
| 		match = m.ProtoReflect().Get(fd).List().Len() == 0
 | |
| 	default:
 | |
| 		return false, fmt.Errorf("cannot use null filter on %s", fd.Kind().String())
 | |
| 	}
 | |
| 	if f.GetNull().GetNot() {
 | |
| 		return !match, nil
 | |
| 	}
 | |
| 	return match, nil
 | |
| }
 | |
| 
 | |
| func matchBool(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
 | |
| 	if fd.Kind() != pref.BoolKind {
 | |
| 		return false, fmt.Errorf("cannot use bool filter on %s", fd.Kind().String())
 | |
| 	}
 | |
| 	return m.ProtoReflect().Get(fd).Bool() == f.GetBool().GetEquals(), nil
 | |
| }
 | |
| 
 | |
| func matchString(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
 | |
| 	if fd.Kind() != pref.StringKind && fd.Kind() != pref.EnumKind {
 | |
| 		return false, fmt.Errorf("cannot use string filter on %s", fd.Kind().String())
 | |
| 	}
 | |
| 	insensitive := f.GetString_().GetCaseInsensitive()
 | |
| 	rval := m.ProtoReflect().Get(fd)
 | |
| 	value := rval.String()
 | |
| 	if fd.Kind() == pref.EnumKind {
 | |
| 		e := fd.Enum().Values().ByNumber(rval.Enum())
 | |
| 		if e == nil {
 | |
| 			return false, nil
 | |
| 		}
 | |
| 		value = string(e.Name())
 | |
| 	}
 | |
| 	var match bool
 | |
| 	switch f.GetString_().GetCondition().(type) {
 | |
| 	case *pf.StringFilter_Equals:
 | |
| 		if insensitive {
 | |
| 			match = strings.ToLower(f.GetString_().GetEquals()) == strings.ToLower(value)
 | |
| 		} else {
 | |
| 			match = value == f.GetString_().GetEquals()
 | |
| 		}
 | |
| 	case *pf.StringFilter_Regex:
 | |
| 		reg, err := regexp.Compile(f.GetString_().GetRegex())
 | |
| 		if err != nil {
 | |
| 			return false, err
 | |
| 		}
 | |
| 		match = reg.MatchString(value)
 | |
| 	case *pf.StringFilter_In_:
 | |
| 	lookup:
 | |
| 		for _, v := range f.GetString_().GetIn().GetValues() {
 | |
| 			if (insensitive && strings.ToLower(v) == strings.ToLower(value)) || v == value {
 | |
| 				match = true
 | |
| 				break lookup
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if f.GetString_().GetNot() {
 | |
| 		return !match, nil
 | |
| 	}
 | |
| 	return match, nil
 | |
| }
 | |
| 
 | |
| func matchNumber(m proto.Message, fd pref.FieldDescriptor, f *pf.Filter) (bool, error) {
 | |
| 	rval := m.ProtoReflect().Get(fd)
 | |
| 	var val float64
 | |
| 	switch fd.Kind() {
 | |
| 	case pref.Int32Kind,
 | |
| 		pref.Sint32Kind,
 | |
| 		pref.Int64Kind,
 | |
| 		pref.Sint64Kind,
 | |
| 		pref.Sfixed32Kind,
 | |
| 		pref.Fixed32Kind,
 | |
| 		pref.Sfixed64Kind,
 | |
| 		pref.Fixed64Kind:
 | |
| 		val = float64(rval.Int())
 | |
| 	case pref.Uint32Kind, pref.Uint64Kind:
 | |
| 		val = float64(rval.Uint())
 | |
| 	case pref.FloatKind, pref.DoubleKind:
 | |
| 		val = rval.Float()
 | |
| 	case pref.EnumKind:
 | |
| 		val = float64(rval.Enum())
 | |
| 	default:
 | |
| 		return false, fmt.Errorf("cannot use number filter on %s", fd.Kind().String())
 | |
| 	}
 | |
| 	var match bool
 | |
| 	switch f.GetNumber().GetCondition().(type) {
 | |
| 	case *pf.NumberFilter_Equals:
 | |
| 		match = val == f.GetNumber().GetEquals()
 | |
| 	case *pf.NumberFilter_In_:
 | |
| 	lookup:
 | |
| 		for _, v := range f.GetNumber().GetIn().GetValues() {
 | |
| 			if val == v {
 | |
| 				match = true
 | |
| 				break lookup
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if f.GetNumber().GetNot() {
 | |
| 		return !match, nil
 | |
| 	}
 | |
| 	return match, nil
 | |
| }
 | |
| 
 | |
| // rangeFields is like strings.Split(path, "."), but avoids allocations by
 | |
| // iterating over each field in place and calling a iterator function.
 | |
| func rangeFields(path string, f func(field string) (pref.FieldDescriptor, bool)) (pref.FieldDescriptor, bool) {
 | |
| 	for {
 | |
| 		var field string
 | |
| 		if i := strings.IndexByte(path, '.'); i >= 0 {
 | |
| 			field, path = path[:i], path[i:]
 | |
| 		} else {
 | |
| 			field, path = path, ""
 | |
| 		}
 | |
| 		v, ok := f(field)
 | |
| 		if !ok {
 | |
| 			return nil, false
 | |
| 		}
 | |
| 		if len(path) == 0 {
 | |
| 			return v, true
 | |
| 		}
 | |
| 		path = strings.TrimPrefix(path, ".")
 | |
| 	}
 | |
| }
 | |
| 
 |