package main import ( "bytes" "fmt" "regexp" "strconv" "strings" "text/scanner" "golang.org/x/net/html" ) type Selector interface { Match(node *html.Node) bool } type SelectorFunc func(nodes []*html.Node) []*html.Node func Select(s Selector) SelectorFunc { // have to define first to be able to do recursion var selectChildren func(node *html.Node) []*html.Node selectChildren = func(node *html.Node) []*html.Node { selected := []*html.Node{} for child := node.FirstChild; child != nil; child = child.NextSibling { if s.Match(child) { selected = append(selected, child) } else { selected = append(selected, selectChildren(child)...) } } return selected } return func(nodes []*html.Node) []*html.Node { selected := []*html.Node{} for _, node := range nodes { selected = append(selected, selectChildren(node)...) } return selected } } // Defined for the '>' selector func SelectNextSibling(s Selector) SelectorFunc { return func(nodes []*html.Node) []*html.Node { selected := []*html.Node{} for _, node := range nodes { for ns := node.NextSibling; ns != nil; ns = ns.NextSibling { if ns.Type == html.ElementNode { if s.Match(ns) { selected = append(selected, ns) } break } } } return selected } } // Defined for the '+' selector func SelectFromChildren(s Selector) SelectorFunc { return func(nodes []*html.Node) []*html.Node { selected := []*html.Node{} for _, node := range nodes { for c := node.FirstChild; c != nil; c = c.NextSibling { if s.Match(c) { selected = append(selected, c) } } } return selected } } type PseudoClass func(*html.Node) bool type CSSSelector struct { Tag string Attrs map[string]*regexp.Regexp Pseudo PseudoClass } func (s CSSSelector) Match(node *html.Node) bool { if node.Type != html.ElementNode { return false } if s.Tag != "" { if s.Tag != node.DataAtom.String() { return false } } for attrKey, matcher := range s.Attrs { matched := false for _, attr := range node.Attr { if attrKey == attr.Key { if !matcher.MatchString(attr.Val) { return false } matched = true break } } if !matched { return false } } if s.Pseudo == nil { return true } return s.Pseudo(node) } // Parse a selector // e.g. `div#my-button.btn[href^="http"]` func ParseSelector(cmd string) (selector CSSSelector, err error) { selector = CSSSelector{ Tag: "", Attrs: map[string]*regexp.Regexp{}, Pseudo: nil, } var s scanner.Scanner s.Init(strings.NewReader(cmd)) err = ParseTagMatcher(&selector, s) return } // Parse the initial tag // e.g. `div` func ParseTagMatcher(selector *CSSSelector, s scanner.Scanner) error { tag := bytes.NewBuffer([]byte{}) defer func() { selector.Tag = tag.String() }() for { c := s.Next() switch c { case scanner.EOF: return nil case '.': return ParseClassMatcher(selector, s) case '#': return ParseIdMatcher(selector, s) case '[': return ParseAttrMatcher(selector, s) case ':': return ParsePseudo(selector, s) default: if _, err := tag.WriteRune(c); err != nil { return err } } } } // Parse a class matcher // e.g. `.btn` func ParseClassMatcher(selector *CSSSelector, s scanner.Scanner) error { var class bytes.Buffer defer func() { regexpStr := `(\A|\s)` + regexp.QuoteMeta(class.String()) + `(\s|\z)` selector.Attrs["class"] = regexp.MustCompile(regexpStr) }() for { c := s.Next() switch c { case scanner.EOF: return nil case '.': return ParseClassMatcher(selector, s) case '#': return ParseIdMatcher(selector, s) case '[': return ParseAttrMatcher(selector, s) case ':': return ParsePseudo(selector, s) default: if _, err := class.WriteRune(c); err != nil { return err } } } } // Parse an id matcher // e.g. `#my-picture` func ParseIdMatcher(selector *CSSSelector, s scanner.Scanner) error { var id bytes.Buffer defer func() { regexpStr := `^` + regexp.QuoteMeta(id.String()) + `$` selector.Attrs["id"] = regexp.MustCompile(regexpStr) }() for { c := s.Next() switch c { case scanner.EOF: return nil case '.': return ParseClassMatcher(selector, s) case '#': return ParseIdMatcher(selector, s) case '[': return ParseAttrMatcher(selector, s) case ':': return ParsePseudo(selector, s) default: if _, err := id.WriteRune(c); err != nil { return err } } } } // Parse an attribute matcher // e.g. `[attr^="http"]` func ParseAttrMatcher(selector *CSSSelector, s scanner.Scanner) error { var attrKey bytes.Buffer var attrVal bytes.Buffer hasMatchVal := false matchType := '=' defer func() { if hasMatchVal { var regexpStr string switch matchType { case '=': regexpStr = `^` + regexp.QuoteMeta(attrVal.String()) + `$` case '*': regexpStr = regexp.QuoteMeta(attrVal.String()) case '$': regexpStr = regexp.QuoteMeta(attrVal.String()) + `$` case '^': regexpStr = `^` + regexp.QuoteMeta(attrVal.String()) case '~': regexpStr = `(\A|\s)` + regexp.QuoteMeta(attrVal.String()) + `(\s|\z)` } selector.Attrs[attrKey.String()] = regexp.MustCompile(regexpStr) } else { selector.Attrs[attrKey.String()] = regexp.MustCompile(`^.*$`) } }() // After reaching ']' proceed proceed := func() error { switch s.Next() { case scanner.EOF: return nil case '.': return ParseClassMatcher(selector, s) case '#': return ParseIdMatcher(selector, s) case '[': return ParseAttrMatcher(selector, s) case ':': return ParsePseudo(selector, s) default: return fmt.Errorf("Expected selector indicator after ']'") } } // Parse the attribute key matcher for !hasMatchVal { c := s.Next() switch c { case scanner.EOF: return fmt.Errorf("Unmatched open brace '['") case ']': // No attribute value matcher, proceed! return proceed() case '$', '^', '~', '*': matchType = c hasMatchVal = true if s.Next() != '=' { return fmt.Errorf("'%c' must be followed by a '='", matchType) } case '=': matchType = c hasMatchVal = true default: if _, err := attrKey.WriteRune(c); err != nil { return err } } } // figure out if the value is quoted c := s.Next() inQuote := false switch c { case scanner.EOF: return fmt.Errorf("Unmatched open brace '['") case ']': return proceed() case '"': inQuote = true default: if _, err := attrVal.WriteRune(c); err != nil { return err } } if inQuote { for { c := s.Next() switch c { case '\\': // consume another character if c = s.Next(); c == scanner.EOF { return fmt.Errorf("Unmatched open brace '['") } case '"': switch s.Next() { case ']': return proceed() default: return fmt.Errorf("Quote must end at ']'") } } if _, err := attrVal.WriteRune(c); err != nil { return err } } } else { for { c := s.Next() switch c { case scanner.EOF: return fmt.Errorf("Unmatched open brace '['") case ']': // No attribute value matcher, proceed! return proceed() } if _, err := attrVal.WriteRune(c); err != nil { return err } } } } // Parse the selector after ':' func ParsePseudo(selector *CSSSelector, s scanner.Scanner) error { if selector.Pseudo != nil { return fmt.Errorf("Combined multiple pseudo classes") } var b bytes.Buffer for s.Peek() != scanner.EOF { if _, err := b.WriteRune(s.Next()); err != nil { return err } } cmd := b.String() var err error switch { case cmd == "empty": selector.Pseudo = func(n *html.Node) bool { return n.FirstChild == nil } case cmd == "first-child": selector.Pseudo = firstChildPseudo case cmd == "last-child": selector.Pseudo = lastChildPseudo case cmd == "only-child": selector.Pseudo = func(n *html.Node) bool { return firstChildPseudo(n) && lastChildPseudo(n) } case cmd == "first-of-type": selector.Pseudo = firstOfTypePseudo case cmd == "last-of-type": selector.Pseudo = lastOfTypePseudo case cmd == "only-of-type": selector.Pseudo = func(n *html.Node) bool { return firstOfTypePseudo(n) && lastOfTypePseudo(n) } case strings.HasPrefix(cmd, "contains("): selector.Pseudo, err = parseContainsPseudo(cmd[len("contains("):]) if err != nil { return err } case strings.HasPrefix(cmd, "nth-child("), strings.HasPrefix(cmd, "nth-last-child("), strings.HasPrefix(cmd, "nth-last-of-type("), strings.HasPrefix(cmd, "nth-of-type("): if selector.Pseudo, err = parseNthPseudo(cmd); err != nil { return err } case strings.HasPrefix(cmd, "not("): if selector.Pseudo, err = parseNotPseudo(cmd[len("not("):]); err != nil { return err } case strings.HasPrefix(cmd, "parent-of("): if selector.Pseudo, err = parseParentOfPseudo(cmd[len("parent-of("):]); err != nil { return err } default: return fmt.Errorf("%s not a valid pseudo class", cmd) } return nil } // :first-of-child func firstChildPseudo(n *html.Node) bool { for c := n.PrevSibling; c != nil; c = c.PrevSibling { if c.Type == html.ElementNode { return false } } return true } // :last-of-child func lastChildPseudo(n *html.Node) bool { for c := n.NextSibling; c != nil; c = c.NextSibling { if c.Type == html.ElementNode { return false } } return true } // :first-of-type func firstOfTypePseudo(node *html.Node) bool { if node.Type != html.ElementNode { return false } for n := node.PrevSibling; n != nil; n = n.PrevSibling { if n.DataAtom == node.DataAtom { return false } } return true } // :last-of-type func lastOfTypePseudo(node *html.Node) bool { if node.Type != html.ElementNode { return false } for n := node.NextSibling; n != nil; n = n.NextSibling { if n.DataAtom == node.DataAtom { return false } } return true } func parseNthPseudo(cmd string) (PseudoClass, error) { i := strings.IndexRune(cmd, '(') if i < 0 { // really, we should never get here return nil, fmt.Errorf("Fatal error, '%s' does not contain a '('", cmd) } pseudoName := cmd[:i] // Figure out how the counting function works var countNth func(*html.Node) int switch pseudoName { case "nth-child": countNth = func(n *html.Node) int { nth := 1 for sib := n.PrevSibling; sib != nil; sib = sib.PrevSibling { if sib.Type == html.ElementNode { nth++ } } return nth } case "nth-of-type": countNth = func(n *html.Node) int { nth := 1 for sib := n.PrevSibling; sib != nil; sib = sib.PrevSibling { if sib.Type == html.ElementNode && sib.DataAtom == n.DataAtom { nth++ } } return nth } case "nth-last-child": countNth = func(n *html.Node) int { nth := 1 for sib := n.NextSibling; sib != nil; sib = sib.NextSibling { if sib.Type == html.ElementNode { nth++ } } return nth } case "nth-last-of-type": countNth = func(n *html.Node) int { nth := 1 for sib := n.NextSibling; sib != nil; sib = sib.NextSibling { if sib.Type == html.ElementNode && sib.DataAtom == n.DataAtom { nth++ } } return nth } default: return nil, fmt.Errorf("Unrecognized pseudo '%s'", pseudoName) } nthString := cmd[i+1:] i = strings.IndexRune(nthString, ')') if i < 0 { return nil, fmt.Errorf("Unmatched '(' for pseudo class %s", pseudoName) } else if i != len(nthString)-1 { return nil, fmt.Errorf("%s(n) must end selector", pseudoName) } number := nthString[:i] // Check if the number is 'odd' or 'even' oddOrEven := -1 switch number { case "odd": oddOrEven = 1 case "even": oddOrEven = 0 } if oddOrEven > -1 { return func(n *html.Node) bool { return n.Type == html.ElementNode && countNth(n)%2 == oddOrEven }, nil } // Check against '3n+4' pattern r := regexp.MustCompile(`([0-9]+)n[ ]?\+[ ]?([0-9])`) subMatch := r.FindAllStringSubmatch(number, -1) if len(subMatch) == 1 && len(subMatch[0]) == 3 { cycle, _ := strconv.Atoi(subMatch[0][1]) offset, _ := strconv.Atoi(subMatch[0][2]) return func(n *html.Node) bool { return n.Type == html.ElementNode && countNth(n)%cycle == offset }, nil } // check against 'n+2' pattern r = regexp.MustCompile(`n[ ]?\+[ ]?([0-9])`) subMatch = r.FindAllStringSubmatch(number, -1) if len(subMatch) == 1 && len(subMatch[0]) == 2 { offset, _ := strconv.Atoi(subMatch[0][1]) return func(n *html.Node) bool { return n.Type == html.ElementNode && countNth(n) >= offset }, nil } // the only other option is a numeric value nth, err := strconv.Atoi(nthString[:i]) if err != nil { return nil, err } else if nth <= 0 { return nil, fmt.Errorf("Argument to '%s' must be greater than 0", pseudoName) } return func(n *html.Node) bool { return n.Type == html.ElementNode && countNth(n) == nth }, nil } // Parse a :contains("") selector // expects the input to be everything after the open parenthesis // e.g. for `contains("Help")` the argument would be `"Help")` func parseContainsPseudo(cmd string) (PseudoClass, error) { var s scanner.Scanner s.Init(strings.NewReader(cmd)) switch s.Next() { case '"': default: return nil, fmt.Errorf("Malformed 'contains(\"\")' selector") } textToContain := bytes.NewBuffer([]byte{}) for { r := s.Next() switch r { case '"': // ')' then EOF must follow '"' if s.Next() != ')' { return nil, fmt.Errorf("Malformed 'contains(\"\")' selector") } if s.Next() != scanner.EOF { return nil, fmt.Errorf("'contains(\"\")' must end selector") } text := textToContain.String() contains := func(node *html.Node) bool { for c := node.FirstChild; c != nil; c = c.NextSibling { if c.Type == html.TextNode { if strings.Contains(c.Data, text) { return true } } } return false } return contains, nil case '\\': s.Next() case scanner.EOF: return nil, fmt.Errorf("Malformed 'contains(\"\")' selector") default: if _, err := textToContain.WriteRune(r); err != nil { return nil, err } } } } // Parse a :not(selector) selector // expects the input to be everything after the open parenthesis // e.g. for `not(div#id)` the argument would be `div#id)` func parseNotPseudo(cmd string) (PseudoClass, error) { if len(cmd) < 2 { return nil, fmt.Errorf("malformed ':not' selector") } endQuote, cmd := cmd[len(cmd)-1], cmd[:len(cmd)-1] selector, err := ParseSelector(cmd) if err != nil { return nil, err } if endQuote != ')' { return nil, fmt.Errorf("unmatched '('") } return func(n *html.Node) bool { return !selector.Match(n) }, nil } // Parse a :parent-of(selector) selector // expects the input to be everything after the open parenthesis // e.g. for `parent-of(div#id)` the argument would be `div#id)` func parseParentOfPseudo(cmd string) (PseudoClass, error) { if len(cmd) < 2 { return nil, fmt.Errorf("malformed ':parent-of' selector") } endQuote, cmd := cmd[len(cmd)-1], cmd[:len(cmd)-1] selector, err := ParseSelector(cmd) if err != nil { return nil, err } if endQuote != ')' { return nil, fmt.Errorf("unmatched '('") } return func(n *html.Node) bool { for c := n.FirstChild; c != nil; c = c.NextSibling { if c.Type == html.ElementNode && selector.Match(c) { return true } } return false }, nil }