mirror of
https://github.com/ericchiang/pup
synced 2025-01-28 16:41:32 +00:00
639 lines
15 KiB
Go
639 lines
15 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"text/scanner"
|
|
|
|
"golang.org/x/net/html"
|
|
)
|
|
|
|
type Selector interface {
|
|
Match(node *html.Node) bool
|
|
}
|
|
|
|
// SelectorFunc takes a slice of nodes and returns the slice of nodes
// selected from them.
type SelectorFunc func(nodes []*html.Node) []*html.Node
|
|
|
|
func Select(s Selector) SelectorFunc {
|
|
// have to define first to be able to do recursion
|
|
var selectChildren func(node *html.Node) []*html.Node
|
|
selectChildren = func(node *html.Node) []*html.Node {
|
|
selected := []*html.Node{}
|
|
for child := node.FirstChild; child != nil; child = child.NextSibling {
|
|
if s.Match(child) {
|
|
selected = append(selected, child)
|
|
} else {
|
|
selected = append(selected, selectChildren(child)...)
|
|
}
|
|
}
|
|
return selected
|
|
}
|
|
return func(nodes []*html.Node) []*html.Node {
|
|
selected := []*html.Node{}
|
|
for _, node := range nodes {
|
|
selected = append(selected, selectChildren(node)...)
|
|
}
|
|
return selected
|
|
}
|
|
}
|
|
|
|
// Defined for the '>' selector
|
|
func SelectNextSibling(s Selector) SelectorFunc {
|
|
return func(nodes []*html.Node) []*html.Node {
|
|
selected := []*html.Node{}
|
|
for _, node := range nodes {
|
|
for ns := node.NextSibling; ns != nil; ns = ns.NextSibling {
|
|
if ns.Type == html.ElementNode {
|
|
if s.Match(ns) {
|
|
selected = append(selected, ns)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return selected
|
|
}
|
|
}
|
|
|
|
// Defined for the '+' selector
|
|
func SelectFromChildren(s Selector) SelectorFunc {
|
|
return func(nodes []*html.Node) []*html.Node {
|
|
selected := []*html.Node{}
|
|
for _, node := range nodes {
|
|
for c := node.FirstChild; c != nil; c = c.NextSibling {
|
|
if s.Match(c) {
|
|
selected = append(selected, c)
|
|
}
|
|
}
|
|
}
|
|
return selected
|
|
}
|
|
}
|
|
|
|
// PseudoClass is a predicate over a single node, used to implement the
// pseudo-class part of a selector (the text after ':').
type PseudoClass func(*html.Node) bool
|
|
|
|
type CSSSelector struct {
|
|
Tag string
|
|
Attrs map[string]*regexp.Regexp
|
|
Pseudo PseudoClass
|
|
}
|
|
|
|
func (s CSSSelector) Match(node *html.Node) bool {
|
|
if node.Type != html.ElementNode {
|
|
return false
|
|
}
|
|
if s.Tag != "" {
|
|
if s.Tag != node.DataAtom.String() {
|
|
return false
|
|
}
|
|
}
|
|
for attrKey, matcher := range s.Attrs {
|
|
matched := false
|
|
for _, attr := range node.Attr {
|
|
if attrKey == attr.Key {
|
|
if !matcher.MatchString(attr.Val) {
|
|
return false
|
|
}
|
|
matched = true
|
|
break
|
|
}
|
|
}
|
|
if !matched {
|
|
return false
|
|
}
|
|
}
|
|
if s.Pseudo == nil {
|
|
return true
|
|
}
|
|
return s.Pseudo(node)
|
|
}
|
|
|
|
// Parse a selector
|
|
// e.g. `div#my-button.btn[href^="http"]`
|
|
func ParseSelector(cmd string) (selector CSSSelector, err error) {
|
|
selector = CSSSelector{
|
|
Tag: "",
|
|
Attrs: map[string]*regexp.Regexp{},
|
|
Pseudo: nil,
|
|
}
|
|
var s scanner.Scanner
|
|
s.Init(strings.NewReader(cmd))
|
|
err = ParseTagMatcher(&selector, s)
|
|
return
|
|
}
|
|
|
|
// Parse the initial tag
|
|
// e.g. `div`
|
|
func ParseTagMatcher(selector *CSSSelector, s scanner.Scanner) error {
|
|
tag := bytes.NewBuffer([]byte{})
|
|
defer func() {
|
|
selector.Tag = tag.String()
|
|
}()
|
|
for {
|
|
c := s.Next()
|
|
switch c {
|
|
case scanner.EOF:
|
|
return nil
|
|
case '.':
|
|
return ParseClassMatcher(selector, s)
|
|
case '#':
|
|
return ParseIdMatcher(selector, s)
|
|
case '[':
|
|
return ParseAttrMatcher(selector, s)
|
|
case ':':
|
|
return ParsePseudo(selector, s)
|
|
default:
|
|
if _, err := tag.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse a class matcher
|
|
// e.g. `.btn`
|
|
func ParseClassMatcher(selector *CSSSelector, s scanner.Scanner) error {
|
|
var class bytes.Buffer
|
|
defer func() {
|
|
regexpStr := `(\A|\s)` + regexp.QuoteMeta(class.String()) + `(\s|\z)`
|
|
selector.Attrs["class"] = regexp.MustCompile(regexpStr)
|
|
}()
|
|
for {
|
|
c := s.Next()
|
|
switch c {
|
|
case scanner.EOF:
|
|
return nil
|
|
case '.':
|
|
return ParseClassMatcher(selector, s)
|
|
case '#':
|
|
return ParseIdMatcher(selector, s)
|
|
case '[':
|
|
return ParseAttrMatcher(selector, s)
|
|
case ':':
|
|
return ParsePseudo(selector, s)
|
|
default:
|
|
if _, err := class.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse an id matcher
|
|
// e.g. `#my-picture`
|
|
func ParseIdMatcher(selector *CSSSelector, s scanner.Scanner) error {
|
|
var id bytes.Buffer
|
|
defer func() {
|
|
regexpStr := `^` + regexp.QuoteMeta(id.String()) + `$`
|
|
selector.Attrs["id"] = regexp.MustCompile(regexpStr)
|
|
}()
|
|
for {
|
|
c := s.Next()
|
|
switch c {
|
|
case scanner.EOF:
|
|
return nil
|
|
case '.':
|
|
return ParseClassMatcher(selector, s)
|
|
case '#':
|
|
return ParseIdMatcher(selector, s)
|
|
case '[':
|
|
return ParseAttrMatcher(selector, s)
|
|
case ':':
|
|
return ParsePseudo(selector, s)
|
|
default:
|
|
if _, err := id.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse an attribute matcher
|
|
// e.g. `[attr^="http"]`
|
|
func ParseAttrMatcher(selector *CSSSelector, s scanner.Scanner) error {
|
|
var attrKey bytes.Buffer
|
|
var attrVal bytes.Buffer
|
|
hasMatchVal := false
|
|
matchType := '='
|
|
defer func() {
|
|
if hasMatchVal {
|
|
var regexpStr string
|
|
switch matchType {
|
|
case '=':
|
|
regexpStr = `^` + regexp.QuoteMeta(attrVal.String()) + `$`
|
|
case '*':
|
|
regexpStr = regexp.QuoteMeta(attrVal.String())
|
|
case '$':
|
|
regexpStr = regexp.QuoteMeta(attrVal.String()) + `$`
|
|
case '^':
|
|
regexpStr = `^` + regexp.QuoteMeta(attrVal.String())
|
|
case '~':
|
|
regexpStr = `(\A|\s)` + regexp.QuoteMeta(attrVal.String()) + `(\s|\z)`
|
|
}
|
|
selector.Attrs[attrKey.String()] = regexp.MustCompile(regexpStr)
|
|
} else {
|
|
selector.Attrs[attrKey.String()] = regexp.MustCompile(`^.*$`)
|
|
}
|
|
}()
|
|
// After reaching ']' proceed
|
|
proceed := func() error {
|
|
switch s.Next() {
|
|
case scanner.EOF:
|
|
return nil
|
|
case '.':
|
|
return ParseClassMatcher(selector, s)
|
|
case '#':
|
|
return ParseIdMatcher(selector, s)
|
|
case '[':
|
|
return ParseAttrMatcher(selector, s)
|
|
case ':':
|
|
return ParsePseudo(selector, s)
|
|
default:
|
|
return fmt.Errorf("Expected selector indicator after ']'")
|
|
}
|
|
}
|
|
// Parse the attribute key matcher
|
|
for !hasMatchVal {
|
|
c := s.Next()
|
|
switch c {
|
|
case scanner.EOF:
|
|
return fmt.Errorf("Unmatched open brace '['")
|
|
case ']':
|
|
// No attribute value matcher, proceed!
|
|
return proceed()
|
|
case '$', '^', '~', '*':
|
|
matchType = c
|
|
hasMatchVal = true
|
|
if s.Next() != '=' {
|
|
return fmt.Errorf("'%c' must be followed by a '='", matchType)
|
|
}
|
|
case '=':
|
|
matchType = c
|
|
hasMatchVal = true
|
|
default:
|
|
if _, err := attrKey.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
// figure out if the value is quoted
|
|
c := s.Next()
|
|
inQuote := false
|
|
switch c {
|
|
case scanner.EOF:
|
|
return fmt.Errorf("Unmatched open brace '['")
|
|
case ']':
|
|
return proceed()
|
|
case '"':
|
|
inQuote = true
|
|
default:
|
|
if _, err := attrVal.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if inQuote {
|
|
for {
|
|
c := s.Next()
|
|
switch c {
|
|
case '\\':
|
|
// consume another character
|
|
if c = s.Next(); c == scanner.EOF {
|
|
return fmt.Errorf("Unmatched open brace '['")
|
|
}
|
|
case '"':
|
|
switch s.Next() {
|
|
case ']':
|
|
return proceed()
|
|
default:
|
|
return fmt.Errorf("Quote must end at ']'")
|
|
}
|
|
}
|
|
if _, err := attrVal.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
for {
|
|
c := s.Next()
|
|
switch c {
|
|
case scanner.EOF:
|
|
return fmt.Errorf("Unmatched open brace '['")
|
|
case ']':
|
|
// No attribute value matcher, proceed!
|
|
return proceed()
|
|
}
|
|
if _, err := attrVal.WriteRune(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse the selector after ':'
|
|
func ParsePseudo(selector *CSSSelector, s scanner.Scanner) error {
|
|
if selector.Pseudo != nil {
|
|
return fmt.Errorf("Combined multiple pseudo classes")
|
|
}
|
|
var b bytes.Buffer
|
|
for s.Peek() != scanner.EOF {
|
|
if _, err := b.WriteRune(s.Next()); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
cmd := b.String()
|
|
var err error
|
|
switch {
|
|
case cmd == "empty":
|
|
selector.Pseudo = func(n *html.Node) bool {
|
|
return n.FirstChild == nil
|
|
}
|
|
case cmd == "first-child":
|
|
selector.Pseudo = firstChildPseudo
|
|
case cmd == "last-child":
|
|
selector.Pseudo = lastChildPseudo
|
|
case cmd == "only-child":
|
|
selector.Pseudo = func(n *html.Node) bool {
|
|
return firstChildPseudo(n) && lastChildPseudo(n)
|
|
}
|
|
case cmd == "first-of-type":
|
|
selector.Pseudo = firstOfTypePseudo
|
|
case cmd == "last-of-type":
|
|
selector.Pseudo = lastOfTypePseudo
|
|
case cmd == "only-of-type":
|
|
selector.Pseudo = func(n *html.Node) bool {
|
|
return firstOfTypePseudo(n) && lastOfTypePseudo(n)
|
|
}
|
|
case strings.HasPrefix(cmd, "contains("):
|
|
selector.Pseudo, err = parseContainsPseudo(cmd[len("contains("):])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case strings.HasPrefix(cmd, "nth-child("),
|
|
strings.HasPrefix(cmd, "nth-last-child("),
|
|
strings.HasPrefix(cmd, "nth-last-of-type("),
|
|
strings.HasPrefix(cmd, "nth-of-type("):
|
|
if selector.Pseudo, err = parseNthPseudo(cmd); err != nil {
|
|
return err
|
|
}
|
|
case strings.HasPrefix(cmd, "not("):
|
|
if selector.Pseudo, err = parseNotPseudo(cmd[len("not("):]); err != nil {
|
|
return err
|
|
}
|
|
case strings.HasPrefix(cmd, "parent-of("):
|
|
if selector.Pseudo, err = parseParentOfPseudo(cmd[len("parent-of("):]); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return fmt.Errorf("%s not a valid pseudo class", cmd)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// :first-of-child
|
|
func firstChildPseudo(n *html.Node) bool {
|
|
for c := n.PrevSibling; c != nil; c = c.PrevSibling {
|
|
if c.Type == html.ElementNode {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// :last-of-child
|
|
func lastChildPseudo(n *html.Node) bool {
|
|
for c := n.NextSibling; c != nil; c = c.NextSibling {
|
|
if c.Type == html.ElementNode {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// :first-of-type
|
|
func firstOfTypePseudo(node *html.Node) bool {
|
|
if node.Type != html.ElementNode {
|
|
return false
|
|
}
|
|
for n := node.PrevSibling; n != nil; n = n.PrevSibling {
|
|
if n.DataAtom == node.DataAtom {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// :last-of-type
|
|
func lastOfTypePseudo(node *html.Node) bool {
|
|
if node.Type != html.ElementNode {
|
|
return false
|
|
}
|
|
for n := node.NextSibling; n != nil; n = n.NextSibling {
|
|
if n.DataAtom == node.DataAtom {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func parseNthPseudo(cmd string) (PseudoClass, error) {
|
|
i := strings.IndexRune(cmd, '(')
|
|
if i < 0 {
|
|
// really, we should never get here
|
|
return nil, fmt.Errorf("Fatal error, '%s' does not contain a '('", cmd)
|
|
}
|
|
pseudoName := cmd[:i]
|
|
// Figure out how the counting function works
|
|
var countNth func(*html.Node) int
|
|
switch pseudoName {
|
|
case "nth-child":
|
|
countNth = func(n *html.Node) int {
|
|
nth := 1
|
|
for sib := n.PrevSibling; sib != nil; sib = sib.PrevSibling {
|
|
if sib.Type == html.ElementNode {
|
|
nth++
|
|
}
|
|
}
|
|
return nth
|
|
}
|
|
case "nth-of-type":
|
|
countNth = func(n *html.Node) int {
|
|
nth := 1
|
|
for sib := n.PrevSibling; sib != nil; sib = sib.PrevSibling {
|
|
if sib.Type == html.ElementNode && sib.DataAtom == n.DataAtom {
|
|
nth++
|
|
}
|
|
}
|
|
return nth
|
|
}
|
|
case "nth-last-child":
|
|
countNth = func(n *html.Node) int {
|
|
nth := 1
|
|
for sib := n.NextSibling; sib != nil; sib = sib.NextSibling {
|
|
if sib.Type == html.ElementNode {
|
|
nth++
|
|
}
|
|
}
|
|
return nth
|
|
}
|
|
case "nth-last-of-type":
|
|
countNth = func(n *html.Node) int {
|
|
nth := 1
|
|
for sib := n.NextSibling; sib != nil; sib = sib.NextSibling {
|
|
if sib.Type == html.ElementNode && sib.DataAtom == n.DataAtom {
|
|
nth++
|
|
}
|
|
}
|
|
return nth
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("Unrecognized pseudo '%s'", pseudoName)
|
|
}
|
|
|
|
nthString := cmd[i+1:]
|
|
i = strings.IndexRune(nthString, ')')
|
|
if i < 0 {
|
|
return nil, fmt.Errorf("Unmatched '(' for pseudo class %s", pseudoName)
|
|
} else if i != len(nthString)-1 {
|
|
return nil, fmt.Errorf("%s(n) must end selector", pseudoName)
|
|
}
|
|
number := nthString[:i]
|
|
|
|
// Check if the number is 'odd' or 'even'
|
|
oddOrEven := -1
|
|
switch number {
|
|
case "odd":
|
|
oddOrEven = 1
|
|
case "even":
|
|
oddOrEven = 0
|
|
}
|
|
if oddOrEven > -1 {
|
|
return func(n *html.Node) bool {
|
|
return n.Type == html.ElementNode && countNth(n)%2 == oddOrEven
|
|
}, nil
|
|
}
|
|
// Check against '3n+4' pattern
|
|
r := regexp.MustCompile(`([0-9]+)n[ ]?\+[ ]?([0-9])`)
|
|
subMatch := r.FindAllStringSubmatch(number, -1)
|
|
if len(subMatch) == 1 && len(subMatch[0]) == 3 {
|
|
cycle, _ := strconv.Atoi(subMatch[0][1])
|
|
offset, _ := strconv.Atoi(subMatch[0][2])
|
|
return func(n *html.Node) bool {
|
|
return n.Type == html.ElementNode && countNth(n)%cycle == offset
|
|
}, nil
|
|
}
|
|
// check against 'n+2' pattern
|
|
r = regexp.MustCompile(`n[ ]?\+[ ]?([0-9])`)
|
|
subMatch = r.FindAllStringSubmatch(number, -1)
|
|
if len(subMatch) == 1 && len(subMatch[0]) == 2 {
|
|
offset, _ := strconv.Atoi(subMatch[0][1])
|
|
return func(n *html.Node) bool {
|
|
return n.Type == html.ElementNode && countNth(n) >= offset
|
|
}, nil
|
|
}
|
|
// the only other option is a numeric value
|
|
nth, err := strconv.Atoi(nthString[:i])
|
|
if err != nil {
|
|
return nil, err
|
|
} else if nth <= 0 {
|
|
return nil, fmt.Errorf("Argument to '%s' must be greater than 0", pseudoName)
|
|
}
|
|
return func(n *html.Node) bool {
|
|
return n.Type == html.ElementNode && countNth(n) == nth
|
|
}, nil
|
|
}
|
|
|
|
// Parse a :contains("") selector
|
|
// expects the input to be everything after the open parenthesis
|
|
// e.g. for `contains("Help")` the argument would be `"Help")`
|
|
func parseContainsPseudo(cmd string) (PseudoClass, error) {
|
|
var s scanner.Scanner
|
|
s.Init(strings.NewReader(cmd))
|
|
switch s.Next() {
|
|
case '"':
|
|
default:
|
|
return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
|
|
}
|
|
textToContain := bytes.NewBuffer([]byte{})
|
|
for {
|
|
r := s.Next()
|
|
switch r {
|
|
case '"':
|
|
// ')' then EOF must follow '"'
|
|
if s.Next() != ')' {
|
|
return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
|
|
}
|
|
if s.Next() != scanner.EOF {
|
|
return nil, fmt.Errorf("'contains(\"\")' must end selector")
|
|
}
|
|
text := textToContain.String()
|
|
contains := func(node *html.Node) bool {
|
|
for c := node.FirstChild; c != nil; c = c.NextSibling {
|
|
if c.Type == html.TextNode {
|
|
if strings.Contains(c.Data, text) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
return contains, nil
|
|
case '\\':
|
|
s.Next()
|
|
case scanner.EOF:
|
|
return nil, fmt.Errorf("Malformed 'contains(\"\")' selector")
|
|
default:
|
|
if _, err := textToContain.WriteRune(r); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse a :not(selector) selector
|
|
// expects the input to be everything after the open parenthesis
|
|
// e.g. for `not(div#id)` the argument would be `div#id)`
|
|
func parseNotPseudo(cmd string) (PseudoClass, error) {
|
|
if len(cmd) < 2 {
|
|
return nil, fmt.Errorf("malformed ':not' selector")
|
|
}
|
|
endQuote, cmd := cmd[len(cmd)-1], cmd[:len(cmd)-1]
|
|
selector, err := ParseSelector(cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if endQuote != ')' {
|
|
return nil, fmt.Errorf("unmatched '('")
|
|
}
|
|
return func(n *html.Node) bool {
|
|
return !selector.Match(n)
|
|
}, nil
|
|
}
|
|
|
|
// Parse a :parent-of(selector) selector
|
|
// expects the input to be everything after the open parenthesis
|
|
// e.g. for `parent-of(div#id)` the argument would be `div#id)`
|
|
func parseParentOfPseudo(cmd string) (PseudoClass, error) {
|
|
if len(cmd) < 2 {
|
|
return nil, fmt.Errorf("malformed ':parent-of' selector")
|
|
}
|
|
endQuote, cmd := cmd[len(cmd)-1], cmd[:len(cmd)-1]
|
|
selector, err := ParseSelector(cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if endQuote != ')' {
|
|
return nil, fmt.Errorf("unmatched '('")
|
|
}
|
|
return func(n *html.Node) bool {
|
|
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
|
if c.Type == html.ElementNode && selector.Match(c) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, nil
|
|
}
|