1
0
mirror of https://github.com/ericchiang/pup synced 2024-10-31 20:58:59 +00:00
pup/selector.go
2016-08-19 18:57:26 +02:00

639 lines
15 KiB
Go

package main
import (
"bytes"
"fmt"
"regexp"
"strconv"
"strings"
"text/scanner"
"golang.org/x/net/html"
)
type (
	// Selector decides whether a single HTML node is selected.
	Selector interface {
		Match(node *html.Node) bool
	}

	// SelectorFunc maps a list of context nodes to the nodes picked by one
	// selection step relative to them (descendants, children, siblings, ...).
	SelectorFunc func(nodes []*html.Node) []*html.Node
)
// Select returns a SelectorFunc for the descendant combinator. For every input
// node it walks the subtree below that node in document order and collects
// each node s matches. The walk does not descend into a node that matched, so
// matches nested inside another match are not reported separately. The result
// is always a non-nil (possibly empty) slice.
func Select(s Selector) SelectorFunc {
	return func(nodes []*html.Node) []*html.Node {
		found := []*html.Node{}
		// walk must be declared before it is assigned so it can call itself.
		var walk func(parent *html.Node)
		walk = func(parent *html.Node) {
			for c := parent.FirstChild; c != nil; c = c.NextSibling {
				if !s.Match(c) {
					walk(c)
					continue
				}
				found = append(found, c)
			}
		}
		for _, n := range nodes {
			walk(n)
		}
		return found
	}
}
// SelectNextSibling implements the adjacent-sibling combinator ('+'). For each
// input node it finds the first element node that follows it among its
// siblings, skipping text, comment and other non-element nodes. It selects
// that element only if s matches it. The result is always a non-nil (possibly
// empty) slice.
func SelectNextSibling(s Selector) SelectorFunc {
	return func(nodes []*html.Node) []*html.Node {
		selected := []*html.Node{}
		for _, node := range nodes {
			for ns := node.NextSibling; ns != nil; ns = ns.NextSibling {
				if ns.Type != html.ElementNode {
					continue
				}
				// Only the immediately following element is a candidate.
				if s.Match(ns) {
					selected = append(selected, ns)
				}
				break
			}
		}
		return selected
	}
}
// SelectFromChildren implements the child combinator ('>'). For each input
// node it selects every direct child that s matches. Unlike Select, it does
// not look below the children. The result is always a non-nil (possibly
// empty) slice.
func SelectFromChildren(s Selector) SelectorFunc {
	return func(nodes []*html.Node) []*html.Node {
		selected := []*html.Node{}
		for _, node := range nodes {
			for c := node.FirstChild; c != nil; c = c.NextSibling {
				if s.Match(c) {
					selected = append(selected, c)
				}
			}
		}
		return selected
	}
}
type (
	// PseudoClass is the predicate behind a pseudo class such as
	// :first-child or :contains("text").
	PseudoClass func(*html.Node) bool

	// CSSSelector is one compound selector, e.g. `div#id.btn[href]:empty`.
	// Match requires every populated part to hold.
	CSSSelector struct {
		// Tag is the required tag name; "" accepts any tag.
		Tag string
		// Attrs maps an attribute name to a regexp its value must match.
		// The attribute must be present on the node.
		Attrs map[string]*regexp.Regexp
		// Pseudo is an extra predicate; nil means none.
		Pseudo PseudoClass
	}
)
// Match reports whether node satisfies every part of the selector. To match,
// node must be:
//   - an element node;
//   - named s.Tag, when s.Tag is set;
//   - carrying every attribute in s.Attrs, with a value matching that
//     attribute's regexp;
//   - accepted by s.Pseudo, when it is set.
func (s CSSSelector) Match(node *html.Node) bool {
	if node.Type != html.ElementNode {
		return false
	}
	// Compare against node.Data (the tag name) rather than
	// node.DataAtom.String(). Elements whose names are not in the atom table,
	// such as custom elements like <my-widget>, have a zero DataAtom whose
	// String() is "", so they could never be matched by tag.
	if s.Tag != "" && s.Tag != node.Data {
		return false
	}
	for attrKey, matcher := range s.Attrs {
		val, found := "", false
		for _, attr := range node.Attr {
			if attr.Key == attrKey {
				val, found = attr.Val, true
				break
			}
		}
		// A missing attribute never matches, whatever the regexp.
		if !found || !matcher.MatchString(val) {
			return false
		}
	}
	return s.Pseudo == nil || s.Pseudo(node)
}
// ParseSelector parses one compound selector, e.g.
// `div#my-button.btn[href^="http"]`, into a CSSSelector. It does not
// interpret combinators; the sub-parsers (starting with ParseTagMatcher) fill
// in the selector through a pointer as they consume the input.
func ParseSelector(cmd string) (selector CSSSelector, err error) {
	selector.Attrs = make(map[string]*regexp.Regexp)
	var sc scanner.Scanner
	sc.Init(strings.NewReader(cmd))
	err = ParseTagMatcher(&selector, sc)
	return selector, err
}
// ParseTagMatcher parses the leading tag name of a selector (e.g. `div`). When
// it reaches '.', '#', '[' or ':' it hands off to the matching sub-parser. All
// other runes before that point form the tag name, which is stored in
// selector.Tag after the rest of the selector has been parsed.
func ParseTagMatcher(selector *CSSSelector, s scanner.Scanner) error {
	var name bytes.Buffer
	defer func() { selector.Tag = name.String() }()
	for r := s.Next(); r != scanner.EOF; r = s.Next() {
		switch r {
		case '.':
			return ParseClassMatcher(selector, s)
		case '#':
			return ParseIdMatcher(selector, s)
		case '[':
			return ParseAttrMatcher(selector, s)
		case ':':
			return ParsePseudo(selector, s)
		}
		if _, err := name.WriteRune(r); err != nil {
			return err
		}
	}
	return nil
}
// ParseClassMatcher parses a class matcher such as `.btn`, starting just after
// the '.'. When it reaches '.', '#', '[' or ':' it hands off to the next
// sub-parser.
//
// The class must appear as a whitespace-separated word of the element's class
// attribute. Attrs holds only one regexp per attribute, so two requirements on
// "class" cannot both live there. This happens when a selector chains several
// classes (`.btn.primary`) or combines a class with a [class...] matcher. In
// that case the requirement already stored under "class" is moved into Pseudo
// rather than overwritten, so every class is enforced.
func ParseClassMatcher(selector *CSSSelector, s scanner.Scanner) error {
	var class bytes.Buffer
	defer func() {
		// Deferred calls run while the recursive descent unwinds, i.e. after
		// every later part of the selector has been parsed. A pseudo class
		// always consumes the rest of the input, so it is already set by now.
		// Wrapping Pseudo therefore cannot trip ParsePseudo's "multiple pseudo
		// classes" check.
		if prev, ok := selector.Attrs["class"]; ok {
			rest := selector.Pseudo
			selector.Pseudo = func(n *html.Node) bool {
				for _, attr := range n.Attr {
					if attr.Key == "class" {
						return prev.MatchString(attr.Val) && (rest == nil || rest(n))
					}
				}
				return false
			}
		}
		regexpStr := `(\A|\s)` + regexp.QuoteMeta(class.String()) + `(\s|\z)`
		selector.Attrs["class"] = regexp.MustCompile(regexpStr)
	}()
	for {
		c := s.Next()
		switch c {
		case scanner.EOF:
			return nil
		case '.':
			return ParseClassMatcher(selector, s)
		case '#':
			return ParseIdMatcher(selector, s)
		case '[':
			return ParseAttrMatcher(selector, s)
		case ':':
			return ParsePseudo(selector, s)
		default:
			if _, err := class.WriteRune(c); err != nil {
				return err
			}
		}
	}
}
// ParseIdMatcher parses an id matcher such as `#my-picture`, starting just
// after the '#'. When it reaches '.', '#', '[' or ':' it hands off to the next
// sub-parser. Once parsing finishes, the id attribute must equal the
// collected text exactly.
func ParseIdMatcher(selector *CSSSelector, s scanner.Scanner) error {
	var id bytes.Buffer
	defer func() {
		exact := "^" + regexp.QuoteMeta(id.String()) + "$"
		selector.Attrs["id"] = regexp.MustCompile(exact)
	}()
	for r := s.Next(); r != scanner.EOF; r = s.Next() {
		switch r {
		case '.':
			return ParseClassMatcher(selector, s)
		case '#':
			return ParseIdMatcher(selector, s)
		case '[':
			return ParseAttrMatcher(selector, s)
		case ':':
			return ParsePseudo(selector, s)
		}
		if _, err := id.WriteRune(r); err != nil {
			return err
		}
	}
	return nil
}
// ParseAttrMatcher parses an attribute matcher such as `[attr^="http"]`,
// starting just after the '['. After the closing ']' it hands off to the next
// sub-parser.
//
// Supported forms:
//
//	[attr]      attribute present, any value
//	[attr=v]    value equals v
//	[attr*=v]   value contains v
//	[attr^=v]   value starts with v
//	[attr$=v]   value ends with v
//	[attr~=v]   v is one of the whitespace-separated words of the value
//
// The value v may be bare or double-quoted. Inside quotes, a backslash makes
// the next character literal. An unterminated bracket or quote is reported as
// an error.
//
// Attrs can hold only one matcher per attribute, as in `[class~=a].b`. If a
// matcher for the same attribute is already stored, it is moved into Pseudo
// so both must hold.
func ParseAttrMatcher(selector *CSSSelector, s scanner.Scanner) error {
	var attrKey bytes.Buffer
	var attrVal bytes.Buffer
	hasMatchVal := false
	matchType := '='
	defer func() {
		key := attrKey.String()
		// (?s) lets '.' match '\n' as well, so [attr] also accepts values
		// that span several lines.
		regexpStr := `(?s)^.*$`
		if hasMatchVal {
			val := regexp.QuoteMeta(attrVal.String())
			switch matchType {
			case '=':
				regexpStr = `^` + val + `$`
			case '*':
				regexpStr = val
			case '$':
				regexpStr = val + `$`
			case '^':
				regexpStr = `^` + val
			case '~':
				regexpStr = `(\A|\s)` + val + `(\s|\z)`
			}
		}
		// This runs while the recursive descent unwinds, after any pseudo
		// class has been parsed. Wrapping Pseudo therefore cannot trip
		// ParsePseudo's "multiple pseudo classes" check.
		if prev, ok := selector.Attrs[key]; ok {
			rest := selector.Pseudo
			selector.Pseudo = func(n *html.Node) bool {
				for _, attr := range n.Attr {
					if attr.Key == key {
						return prev.MatchString(attr.Val) && (rest == nil || rest(n))
					}
				}
				return false
			}
		}
		selector.Attrs[key] = regexp.MustCompile(regexpStr)
	}()
	// proceed continues with whatever follows the closing ']'.
	proceed := func() error {
		switch s.Next() {
		case scanner.EOF:
			return nil
		case '.':
			return ParseClassMatcher(selector, s)
		case '#':
			return ParseIdMatcher(selector, s)
		case '[':
			return ParseAttrMatcher(selector, s)
		case ':':
			return ParsePseudo(selector, s)
		default:
			return fmt.Errorf("Expected selector indicator after ']'")
		}
	}
	// Read the attribute name, up to ']' (no value) or the match operator.
	for !hasMatchVal {
		c := s.Next()
		switch c {
		case scanner.EOF:
			return fmt.Errorf("Unmatched open brace '['")
		case ']':
			// No attribute value matcher, proceed!
			return proceed()
		case '$', '^', '~', '*':
			matchType = c
			hasMatchVal = true
			if s.Next() != '=' {
				return fmt.Errorf("'%c' must be followed by a '='", matchType)
			}
		case '=':
			matchType = c
			hasMatchVal = true
		default:
			if _, err := attrKey.WriteRune(c); err != nil {
				return err
			}
		}
	}
	// The first character of the value tells whether it is quoted.
	c := s.Next()
	switch c {
	case scanner.EOF:
		return fmt.Errorf("Unmatched open brace '['")
	case ']':
		// Empty value, e.g. [attr=].
		return proceed()
	case '"':
		// Quoted value: it runs to the closing quote, which must be followed
		// directly by ']'.
		for {
			c = s.Next()
			switch c {
			case scanner.EOF:
				// The scanner keeps returning EOF, so without this case an
				// unterminated quote would loop forever.
				return fmt.Errorf("Unmatched quote '\"'")
			case '\\':
				// Consume the escaped character and keep it literally.
				if c = s.Next(); c == scanner.EOF {
					return fmt.Errorf("Unmatched open brace '['")
				}
			case '"':
				if s.Next() != ']' {
					return fmt.Errorf("Quote must end at ']'")
				}
				return proceed()
			}
			if _, err := attrVal.WriteRune(c); err != nil {
				return err
			}
		}
	}
	// Bare value: the character read above belongs to it, and the value runs
	// to the closing ']'.
	if _, err := attrVal.WriteRune(c); err != nil {
		return err
	}
	for {
		c = s.Next()
		switch c {
		case scanner.EOF:
			return fmt.Errorf("Unmatched open brace '['")
		case ']':
			return proceed()
		}
		if _, err := attrVal.WriteRune(c); err != nil {
			return err
		}
	}
}
// ParsePseudo parses the pseudo class that follows ':' and stores it in
// selector.Pseudo. A pseudo class always ends the selector, so everything left
// in s is taken as its text. Only one pseudo class per selector is allowed.
func ParsePseudo(selector *CSSSelector, s scanner.Scanner) error {
	if selector.Pseudo != nil {
		return fmt.Errorf("Combined multiple pseudo classes")
	}
	var text bytes.Buffer
	for s.Peek() != scanner.EOF {
		if _, err := text.WriteRune(s.Next()); err != nil {
			return err
		}
	}
	cmd := text.String()

	// Pseudo classes that take no argument.
	fixed := map[string]PseudoClass{
		"empty": func(n *html.Node) bool {
			return n.FirstChild == nil
		},
		"first-child": firstChildPseudo,
		"last-child":  lastChildPseudo,
		"only-child": func(n *html.Node) bool {
			return firstChildPseudo(n) && lastChildPseudo(n)
		},
		"first-of-type": firstOfTypePseudo,
		"last-of-type":  lastOfTypePseudo,
		"only-of-type": func(n *html.Node) bool {
			return firstOfTypePseudo(n) && lastOfTypePseudo(n)
		},
	}
	if p, ok := fixed[cmd]; ok {
		selector.Pseudo = p
		return nil
	}

	// Functional pseudo classes; each argument parser validates the rest.
	var (
		p   PseudoClass
		err error
	)
	switch {
	case strings.HasPrefix(cmd, "contains("):
		p, err = parseContainsPseudo(strings.TrimPrefix(cmd, "contains("))
	case strings.HasPrefix(cmd, "nth-child("),
		strings.HasPrefix(cmd, "nth-last-child("),
		strings.HasPrefix(cmd, "nth-last-of-type("),
		strings.HasPrefix(cmd, "nth-of-type("):
		p, err = parseNthPseudo(cmd)
	case strings.HasPrefix(cmd, "not("):
		p, err = parseNotPseudo(strings.TrimPrefix(cmd, "not("))
	case strings.HasPrefix(cmd, "parent-of("):
		p, err = parseParentOfPseudo(strings.TrimPrefix(cmd, "parent-of("))
	default:
		return fmt.Errorf("%s not a valid pseudo class", cmd)
	}
	if err != nil {
		return err
	}
	selector.Pseudo = p
	return nil
}
// firstChildPseudo implements :first-child. It reports whether no element
// node precedes n among its siblings. Text, comment and other non-element
// siblings are ignored.
func firstChildPseudo(n *html.Node) bool {
	for c := n.PrevSibling; c != nil; c = c.PrevSibling {
		if c.Type == html.ElementNode {
			return false
		}
	}
	return true
}
// lastChildPseudo implements :last-child. It reports whether no element node
// follows n among its siblings. Text, comment and other non-element siblings
// are ignored.
func lastChildPseudo(n *html.Node) bool {
	for c := n.NextSibling; c != nil; c = c.NextSibling {
		if c.Type == html.ElementNode {
			return false
		}
	}
	return true
}
// firstOfTypePseudo implements :first-of-type. It reports whether node is an
// element and no earlier sibling is an element with the same tag name.
//
// Siblings are compared by Type and Data (the tag name), not by DataAtom
// alone. Text and comment nodes and elements missing from the atom table
// (custom elements such as <x-card>) all have a zero DataAtom. Comparing atoms
// alone would count all of them as the "same type".
func firstOfTypePseudo(node *html.Node) bool {
	if node.Type != html.ElementNode {
		return false
	}
	for n := node.PrevSibling; n != nil; n = n.PrevSibling {
		if n.Type == html.ElementNode && n.Data == node.Data {
			return false
		}
	}
	return true
}
// lastOfTypePseudo implements :last-of-type. It reports whether node is an
// element and no later sibling is an element with the same tag name.
//
// Siblings are compared by Type and Data (the tag name), not by DataAtom
// alone. Text and comment nodes and elements missing from the atom table
// (custom elements such as <x-card>) all have a zero DataAtom. Comparing atoms
// alone would count all of them as the "same type".
func lastOfTypePseudo(node *html.Node) bool {
	if node.Type != html.ElementNode {
		return false
	}
	for n := node.NextSibling; n != nil; n = n.NextSibling {
		if n.Type == html.ElementNode && n.Data == node.Data {
			return false
		}
	}
	return true
}
// nthFormulaRegexp matches the An+B notation of CSS Selectors Level 3, e.g.
// `3n+4`, `n+2`, `-n+3`, `2n` or `4n - 1`. It captures A (possibly empty or
// just a sign), the sign of B, and B's digits.
var nthFormulaRegexp = regexp.MustCompile(`^([+-]?[0-9]*)n(?:\s*([+-])\s*([0-9]+))?$`)

// parseNthFormula parses an An+B expression into a and b. It reports
// ok == false when expr is not in that form.
func parseNthFormula(expr string) (a, b int, ok bool) {
	m := nthFormulaRegexp.FindStringSubmatch(expr)
	if m == nil {
		return 0, 0, false
	}
	var err error
	switch m[1] {
	case "", "+":
		a = 1
	case "-":
		a = -1
	default:
		if a, err = strconv.Atoi(m[1]); err != nil {
			return 0, 0, false
		}
	}
	if m[3] != "" {
		if b, err = strconv.Atoi(m[3]); err != nil {
			return 0, 0, false
		}
		if m[2] == "-" {
			b = -b
		}
	}
	return a, b, true
}

// parseNthPseudo parses :nth-child, :nth-of-type, :nth-last-child and
// :nth-last-of-type. cmd is the pseudo class text without the leading ':',
// e.g. `nth-child(2n+1)`. The closing ')' must end the selector.
//
// The argument may be one of:
//   - `odd` or `even`;
//   - a positive integer;
//   - an An+B formula (see parseNthFormula).
//
// Positions are 1-based and counted among element siblings. The *-of-type
// variants count only siblings with the same tag name. The nth-last-*
// variants count from the last sibling. As CSS defines it, an element matches
// when its position equals a*m+b for some integer m >= 0.
func parseNthPseudo(cmd string) (PseudoClass, error) {
	i := strings.IndexRune(cmd, '(')
	if i < 0 {
		// really, we should never get here
		return nil, fmt.Errorf("Fatal error, '%s' does not contain a '('", cmd)
	}
	pseudoName := cmd[:i]

	var ofType, fromEnd bool
	switch pseudoName {
	case "nth-child":
	case "nth-of-type":
		ofType = true
	case "nth-last-child":
		fromEnd = true
	case "nth-last-of-type":
		ofType, fromEnd = true, true
	default:
		return nil, fmt.Errorf("Unrecognized pseudo '%s'", pseudoName)
	}
	step := func(n *html.Node) *html.Node { return n.PrevSibling }
	if fromEnd {
		step = func(n *html.Node) *html.Node { return n.NextSibling }
	}
	// countNth returns n's 1-based position. Tags are compared by Data rather
	// than DataAtom so that different custom elements, which all have a zero
	// DataAtom, are not counted as the same type.
	countNth := func(n *html.Node) int {
		nth := 1
		for sib := step(n); sib != nil; sib = step(sib) {
			if sib.Type == html.ElementNode && (!ofType || sib.Data == n.Data) {
				nth++
			}
		}
		return nth
	}

	nthString := cmd[i+1:]
	i = strings.IndexRune(nthString, ')')
	if i < 0 {
		return nil, fmt.Errorf("Unmatched '(' for pseudo class %s", pseudoName)
	}
	if i != len(nthString)-1 {
		return nil, fmt.Errorf("%s(n) must end selector", pseudoName)
	}
	number := strings.TrimSpace(nthString[:i])

	var a, b int
	switch number {
	case "odd":
		a, b = 2, 1
	case "even":
		a, b = 2, 0
	default:
		var ok bool
		if a, b, ok = parseNthFormula(number); !ok {
			// the only other option is a plain positive integer
			nth, err := strconv.Atoi(number)
			if err != nil {
				return nil, err
			}
			if nth <= 0 {
				return nil, fmt.Errorf("Argument to '%s' must be greater than 0", pseudoName)
			}
			a, b = 0, nth
		}
	}
	return func(n *html.Node) bool {
		if n.Type != html.ElementNode {
			return false
		}
		pos := countNth(n)
		if a == 0 {
			return pos == b
		}
		// pos == a*m + b requires pos-b to be an exact multiple of a with a
		// non-negative quotient m. The remainder test guarantees the division
		// is exact, so the sign test on the quotient is reliable.
		d := pos - b
		return d%a == 0 && d/a >= 0
	}, nil
}
// parseContainsPseudo parses a :contains("text") pseudo class. cmd is
// everything after the open parenthesis; for `contains("Help")` it is
// `"Help")`. Inside the quotes a backslash makes the next character literal,
// so `contains("say \"hi\"")` searches for `say "hi"`. The closing `")` must
// end the selector.
//
// The returned PseudoClass reports whether any direct text child of the node
// contains the text. Text in deeper descendants is not examined.
func parseContainsPseudo(cmd string) (PseudoClass, error) {
	var s scanner.Scanner
	s.Init(strings.NewReader(cmd))
	if s.Next() != '"' {
		return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
	}
	var textToContain bytes.Buffer
	for {
		r := s.Next()
		switch r {
		case '"':
			// ')' then EOF must follow '"'
			if s.Next() != ')' {
				return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
			}
			if s.Next() != scanner.EOF {
				return nil, fmt.Errorf("'contains(\"\")' must end selector")
			}
			text := textToContain.String()
			return func(node *html.Node) bool {
				for c := node.FirstChild; c != nil; c = c.NextSibling {
					if c.Type == html.TextNode && strings.Contains(c.Data, text) {
						return true
					}
				}
				return false
			}, nil
		case '\\':
			// Keep the escaped character itself; the backslash only removes
			// its special meaning.
			if r = s.Next(); r == scanner.EOF {
				return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
			}
		case scanner.EOF:
			return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
		}
		if _, err := textToContain.WriteRune(r); err != nil {
			return nil, err
		}
	}
}
// parseNotPseudo parses a :not(selector) pseudo class. cmd is everything
// after the open parenthesis; for `not(div#id)` it is `div#id)`. The inner
// selector is parsed with ParseSelector. The returned PseudoClass reports
// whether a node does NOT match it.
func parseNotPseudo(cmd string) (PseudoClass, error) {
	if len(cmd) < 2 {
		return nil, fmt.Errorf("malformed ':not' selector")
	}
	// Check for the closing parenthesis before parsing the inner selector.
	// That way a missing ')' is reported as such, not as a confusing parse
	// error about whatever text happens to end the input.
	if !strings.HasSuffix(cmd, ")") {
		return nil, fmt.Errorf("unmatched '('")
	}
	selector, err := ParseSelector(strings.TrimSuffix(cmd, ")"))
	if err != nil {
		return nil, err
	}
	return func(n *html.Node) bool {
		return !selector.Match(n)
	}, nil
}
// parseParentOfPseudo parses a :parent-of(selector) pseudo class. cmd is
// everything after the open parenthesis; for `parent-of(div#id)` it is
// `div#id)`. The inner selector is parsed with ParseSelector. The returned
// PseudoClass reports whether any direct element child of a node matches it.
func parseParentOfPseudo(cmd string) (PseudoClass, error) {
	if len(cmd) < 2 {
		return nil, fmt.Errorf("malformed ':parent-of' selector")
	}
	// Check for the closing parenthesis before parsing the inner selector.
	// That way a missing ')' is reported as such, not as a confusing parse
	// error about whatever text happens to end the input.
	if !strings.HasSuffix(cmd, ")") {
		return nil, fmt.Errorf("unmatched '('")
	}
	selector, err := ParseSelector(strings.TrimSuffix(cmd, ")"))
	if err != nil {
		return nil, err
	}
	return func(n *html.Node) bool {
		for c := n.FirstChild; c != nil; c = c.NextSibling {
			if c.Type == html.ElementNode && selector.Match(c) {
				return true
			}
		}
		return false
	}, nil
}