You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

181 lines
3.4 KiB

  1. package tagengine
  2. import (
  3. "sort"
  4. )
  5. type RuleSet struct {
  6. root *node
  7. maxNgram int
  8. sanitize func(...string) string
  9. rules []*Rule
  10. }
  11. func NewRuleSet() *RuleSet {
  12. return &RuleSet{
  13. root: &node{
  14. Token: "/",
  15. Children: map[string]*node{},
  16. },
  17. sanitize: newSanitizer(),
  18. rules: []*Rule{},
  19. }
  20. }
  21. func NewRuleSetFromList(rules []Rule) *RuleSet {
  22. rs := NewRuleSet()
  23. rs.AddRule(rules...)
  24. return rs
  25. }
  26. func (t *RuleSet) Add(ruleOrGroup ...interface{}) {
  27. for _, ix := range ruleOrGroup {
  28. switch x := ix.(type) {
  29. case Rule:
  30. t.AddRule(x)
  31. case RuleGroup:
  32. t.AddRuleGroup(x)
  33. default:
  34. panic("Add expects either Rule or RuleGroup objects.")
  35. }
  36. }
  37. }
  38. func (t *RuleSet) AddRule(rules ...Rule) {
  39. for _, rule := range rules {
  40. rule := rule
  41. // Make sure rule is well-formed.
  42. rule.normalize()
  43. // Update maxNgram.
  44. N := rule.maxNGram()
  45. if N > t.maxNgram {
  46. t.maxNgram = N
  47. }
  48. t.rules = append(t.rules, &rule)
  49. t.root.AddRule(&rule)
  50. }
  51. }
  52. func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) {
  53. for _, rg := range ruleGroups {
  54. t.AddRule(rg.ToList()...)
  55. }
  56. }
  57. // MatchRules will return a list of all matching rules. The rules are sorted by
  58. // the match's "score". The best match will be first.
  59. func (t *RuleSet) MatchRules(input string) (rules []*Rule) {
  60. input = t.sanitize(input)
  61. tokens := Tokenize(input, t.maxNgram)
  62. rules = t.root.Match(tokens)
  63. if len(rules) == 0 {
  64. return rules
  65. }
  66. // Check excludes.
  67. l := rules[:0]
  68. for _, r := range rules {
  69. if !r.isExcluded(tokens) {
  70. l = append(l, r)
  71. }
  72. }
  73. rules = l
  74. // Sort rules descending.
  75. sort.Slice(rules, func(i, j int) bool {
  76. return ruleLess(rules[j], rules[i])
  77. })
  78. // Update rule stats.
  79. if len(rules) > 0 {
  80. rules[0].FirstCount++
  81. for _, r := range rules {
  82. r.MatchCount++
  83. }
  84. }
  85. return rules
  86. }
  // Match pairs a tag with the confidence assigned to it for a given input.
  type Match struct {
  	// Tag is the tag name taken from the matching rule(s).
  	Tag string

  	// Confidence is in the range (0,1].
  	Confidence float64
  }
  91. // Return a list of matches with confidence.
  92. func (t *RuleSet) Match(input string) []Match {
  93. rules := t.MatchRules(input)
  94. if len(rules) == 0 {
  95. return []Match{}
  96. }
  97. if len(rules) == 1 {
  98. return []Match{{
  99. Tag: rules[0].Tag,
  100. Confidence: 1,
  101. }}
  102. }
  103. // Create list of blocked tags.
  104. blocks := map[string]struct{}{}
  105. for _, rule := range rules {
  106. for _, tag := range rule.Blocks {
  107. blocks[tag] = struct{}{}
  108. }
  109. }
  110. // Remove rules for blocked tags.
  111. iOut := 0
  112. for _, rule := range rules {
  113. if _, ok := blocks[rule.Tag]; ok {
  114. continue
  115. }
  116. rules[iOut] = rule
  117. iOut++
  118. }
  119. rules = rules[:iOut]
  120. // Matches by index.
  121. matches := map[string]int{}
  122. out := []Match{}
  123. sum := float64(0)
  124. for _, rule := range rules {
  125. idx, ok := matches[rule.Tag]
  126. if !ok {
  127. idx = len(matches)
  128. matches[rule.Tag] = idx
  129. out = append(out, Match{Tag: rule.Tag})
  130. }
  131. out[idx].Confidence += float64(rule.score)
  132. sum += float64(rule.score)
  133. }
  134. for i := range out {
  135. out[i].Confidence /= sum
  136. }
  137. return out
  138. }
  139. // ListRules returns rules used in the ruleset sorted by the rules'
  140. // FirstCount. This is the number of times the given rule was the best match to
  141. // an input.
  142. func (t *RuleSet) ListRules() []*Rule {
  143. sort.Slice(t.rules, func(i, j int) bool {
  144. if t.rules[j].FirstCount != t.rules[i].FirstCount {
  145. return t.rules[j].FirstCount < t.rules[i].FirstCount
  146. }
  147. if t.rules[j].MatchCount != t.rules[i].MatchCount {
  148. return t.rules[j].MatchCount < t.rules[i].MatchCount
  149. }
  150. return t.rules[j].Tag < t.rules[i].Tag
  151. })
  152. return t.rules
  153. }