clair/vendor/github.com/tylerb/graceful/graceful_test.go
2017-05-04 13:59:57 -04:00

693 lines
14 KiB
Go

package graceful
import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"reflect"
"strings"
"sync"
"syscall"
"testing"
"time"
)
// Timing and sizing knobs shared by the tests in this file.
const (
	// port is the TCP port every test server listens on.
	port = 9654

	// concurrentRequestN is the number of queries fired per batch.
	concurrentRequestN = 8

	// killTime is the shutdown Timeout given to most test servers.
	killTime = 500 * time.Millisecond

	// timeoutTime bounds how long a test waits on StopChan.
	timeoutTime = time.Second

	// waitTime is the pause used to let other goroutines make progress.
	waitTime = 100 * time.Millisecond
)
// runQuery performs one GET against the test server and records at most one
// failure through the shared once.
//
// When shouldErr is true the request is expected to fail: an error accepted
// by checkErr ends the check early, and a successful request is reported.
// When shouldErr is false, any Get error is reported as-is. Otherwise the
// response status must equal expected. wg.Done is called on return.
func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) {
	defer wg.Done()
	client := http.Client{}
	r, err := client.Get(fmt.Sprintf("http://localhost:%d", port))
	if r != nil {
		// The caller must close the response body; leaving it open leaks
		// the underlying connection.
		defer r.Body.Close()
	}
	switch {
	case shouldErr && err == nil:
		once.Do(func() {
			t.Error("Expected an error but none was encountered.")
		})
	case shouldErr && err != nil:
		if checkErr(t, err, once) {
			return
		}
	case err != nil:
		// An unexpected failure: report the real error instead of the
		// generic "no response" message below.
		once.Do(func() {
			t.Error("Error on Get:", err)
		})
		return
	}
	if r == nil {
		once.Do(func() {
			t.Error("No response when a response was expected.")
		})
		return
	}
	if r.StatusCode != expected {
		once.Do(func() {
			t.Errorf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode)
		})
	}
}
// checkErr classifies an error returned by an HTTP GET that was expected to
// fail.
//
// It returns true for errors treated as acceptable in that situation:
//   - io.EOF
//   - a refused connection (ECONNREFUSED)
//   - an error message mentioning "transport closed before response was
//     received" or "server closed connection"
//
// Any other *net.OpError is reported through once. Other unrecognised errors
// are printed to stdout. In both cases it returns false.
//
// Errors that are not wrapped in *url.Error are examined directly instead of
// causing a panic. A nil error returns false.
func checkErr(t *testing.T, err error, once *sync.Once) bool {
	if err == nil {
		return false
	}
	// http.Client wraps transport failures in *url.Error; inspect the cause.
	cause := err
	if ue, ok := err.(*url.Error); ok {
		cause = ue.Err
	}
	if cause == io.EOF {
		return true
	}
	switch oe := cause.(type) {
	case *net.OpError:
		var errno syscall.Errno
		switch e := oe.Err.(type) {
		case syscall.Errno:
			errno = e
		case *os.SyscallError:
			// Comma-ok: the wrapped error is not guaranteed to be an Errno.
			if en, ok := e.Err.(syscall.Errno); ok {
				errno = en
			}
		}
		if errno == syscall.ECONNREFUSED {
			return true
		}
		once.Do(func() {
			t.Error("Error on Get:", err)
		})
	default:
		msg := err.Error()
		if strings.Contains(msg, "transport closed before response was received") ||
			strings.Contains(msg, "server closed connection") {
			return true
		}
		fmt.Printf("unknown err: %s, %#v\n", err, err)
	}
	return false
}
// createListener returns an http.Server whose only handler sleeps for sleep
// and then answers 200 OK, together with a TCP listener bound to the test
// port and any error from net.Listen. It does not start serving.
func createListener(sleep time.Duration) (*http.Server, net.Listener, error) {
	addr := fmt.Sprintf(":%d", port)
	handler := http.NewServeMux()
	handler.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		time.Sleep(sleep)
		w.WriteHeader(http.StatusOK)
	})
	srv := &http.Server{Addr: addr, Handler: handler}
	listener, err := net.Listen("tcp", addr)
	return srv, listener, err
}
// launchTestQueries runs the standard shutdown scenario.
//
// It fires a batch of queries that expect 200 OK, sends os.Interrupt on c,
// and then fires a second batch that expects an error. Each query is tracked
// on wg, as is this function itself (it calls wg.Done on return).
func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) {
	defer wg.Done()
	var once sync.Once
	fire := func(status int, wantErr bool) {
		for n := 0; n < concurrentRequestN; n++ {
			wg.Add(1)
			go runQuery(t, status, wantErr, wg, &once)
		}
	}

	fire(http.StatusOK, false)
	time.Sleep(waitTime)
	c <- os.Interrupt
	time.Sleep(waitTime)
	fire(0, true)
}
// TestGracefulRun serves with Timeout killTime and handlers that sleep
// killTime/2. Queries issued before the interrupt must succeed and queries
// issued after it must fail.
func TestGracefulRun(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime / 2)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(2)
	go func() {
		defer wg.Done()
		graceful := &Server{Timeout: killTime, Server: server, interrupt: sig}
		graceful.Serve(l)
	}()
	go launchTestQueries(t, &wg, sig)
}
// TestGracefulRunLimitKeepAliveListener runs the standard shutdown scenario
// (launchTestQueries) against a server configured with ListenLimit
// concurrentRequestN and a one-second TCPKeepAlive.
func TestGracefulRunLimitKeepAliveListener(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime / 2)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(2)
	go func() {
		defer wg.Done()
		graceful := &Server{
			Timeout:      killTime,
			ListenLimit:  concurrentRequestN,
			TCPKeepAlive: 1 * time.Second,
			Server:       server,
			interrupt:    sig,
		}
		graceful.Serve(l)
	}()
	go launchTestQueries(t, &wg, sig)
}
// TestGracefulRunTimesOut uses handlers that sleep ten times longer than the
// server's shutdown Timeout. Both the queries issued before the interrupt
// and those issued after it are expected to error.
func TestGracefulRunTimesOut(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 10)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(2)
	go func() {
		defer wg.Done()
		graceful := &Server{Timeout: killTime, Server: server, interrupt: sig}
		graceful.Serve(l)
	}()
	go func() {
		defer wg.Done()
		var once sync.Once
		expectFailures := func() {
			for n := 0; n < concurrentRequestN; n++ {
				wg.Add(1)
				go runQuery(t, 0, true, &wg, &once)
			}
		}
		expectFailures()
		time.Sleep(waitTime)
		sig <- os.Interrupt
		time.Sleep(waitTime)
		expectFailures()
	}()
}
// TestGracefulRunDoesntTimeOut serves with Timeout 0 and handlers that sleep
// 2*killTime. The queries in flight at the interrupt are still expected to
// complete with 200 OK.
func TestGracefulRunDoesntTimeOut(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 2)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(2)
	go func() {
		defer wg.Done()
		graceful := &Server{Timeout: 0, Server: server, interrupt: sig}
		graceful.Serve(l)
	}()
	go launchTestQueries(t, &wg, sig)
}
// TestGracefulRunDoesntTimeOutAfterConnectionCreated checks that a request
// sent after a previous connection was left idle is still allowed to finish
// when the interrupt arrives while it is in flight (server Timeout is 0).
func TestGracefulRunDoesntTimeOutAfterConnectionCreated(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()
	c := make(chan os.Signal, 1)
	server, l, err := createListener(killTime)
	if err != nil {
		t.Fatal(err)
	}
	wg.Add(1)
	go func() {
		defer wg.Done()
		srv := &Server{Timeout: 0, Server: server, interrupt: c}
		srv.Serve(l)
	}()
	time.Sleep(waitTime)
	// Make a sample first request. The connection will be left idle.
	resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
	if err != nil {
		// Interrupt the server before failing so Serve can return.
		// Otherwise the deferred wg.Wait blocks forever on the Serve
		// goroutine. (A panic here would also run that defer and hang.)
		c <- os.Interrupt
		t.Fatalf("first request failed: %v", err)
	}
	resp.Body.Close()
	wg.Add(1)
	go func() {
		defer wg.Done()
		// With idle connections improperly handled, the server doesn't wait for this
		// to complete and the request fails. It should be allowed to complete successfully.
		resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
		if err != nil {
			t.Errorf("Get failed: %v", err)
			return
		}
		resp.Body.Close()
	}()
	// Ensure the request goes out
	time.Sleep(waitTime)
	c <- os.Interrupt
	wg.Wait()
}
// TestGracefulRunNoRequests sends the interrupt without ever issuing a
// request. The test completes only once Serve returns, because the deferred
// wg.Wait waits for the serving goroutine.
func TestGracefulRunNoRequests(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 2)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		graceful := &Server{Timeout: 0, Server: server, interrupt: sig}
		graceful.Serve(l)
	}()
	sig <- os.Interrupt
}
// TestGracefulForwardsConnState installs a user ConnState hook and runs the
// standard shutdown scenario. Afterwards the hook must have recorded exactly
// concurrentRequestN transitions into each of StateNew, StateActive and
// StateClosed, and no other states.
func TestGracefulForwardsConnState(t *testing.T) {
	var mu sync.Mutex
	seen := make(map[http.ConnState]int)
	record := func(_ net.Conn, st http.ConnState) {
		mu.Lock()
		defer mu.Unlock()
		seen[st]++
	}

	var wg sync.WaitGroup
	defer wg.Wait()

	want := map[http.ConnState]int{
		http.StateNew:    concurrentRequestN,
		http.StateActive: concurrentRequestN,
		http.StateClosed: concurrentRequestN,
	}
	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime / 2)
	if err != nil {
		t.Fatal(err)
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		graceful := &Server{
			ConnState: record,
			Timeout:   killTime,
			Server:    server,
			interrupt: sig,
		}
		graceful.Serve(l)
	}()
	wg.Add(1)
	go launchTestQueries(t, &wg, sig)
	wg.Wait()

	mu.Lock()
	defer mu.Unlock()
	if !reflect.DeepEqual(seen, want) {
		t.Errorf("Incorrect connection state tracking.\n  actual: %v\nexpected: %v\n", seen, want)
	}
}
// TestGracefulExplicitStop starts a server with no interrupt channel and
// calls Stop(killTime) directly. StopChan must fire within timeoutTime.
func TestGracefulExplicitStop(t *testing.T) {
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}
	graceful := &Server{Timeout: killTime, Server: server}
	go func() {
		go graceful.Serve(l)
		time.Sleep(waitTime)
		graceful.Stop(killTime)
	}()

	// Wait on StopChan until the server has shut down, or give up.
	stopped, deadline := graceful.StopChan(), time.After(timeoutTime)
	select {
	case <-stopped:
	case <-deadline:
		t.Fatal("Timed out while waiting for explicit stop to complete")
	}
}
// TestGracefulExplicitStopOverride calls Stop with killTime/2, overriding the
// server's configured Timeout of killTime. StopChan must fire within killTime.
func TestGracefulExplicitStopOverride(t *testing.T) {
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}
	graceful := &Server{Timeout: killTime, Server: server}
	go func() {
		go graceful.Serve(l)
		time.Sleep(waitTime)
		graceful.Stop(killTime / 2)
	}()

	// Wait on StopChan until the server has shut down, or give up.
	stopped, deadline := graceful.StopChan(), time.After(killTime)
	select {
	case <-stopped:
	case <-deadline:
		t.Fatal("Timed out while waiting for explicit stop to complete")
	}
}
// TestBeforeShutdownAndShutdownInitiatedCallbacks stops the server
// explicitly. It checks that the BeforeShutdown hook (returning true) and the
// ShutdownInitiated hook each fire, with each observed within killTime of
// the previous event.
func TestBeforeShutdownAndShutdownInitiatedCallbacks(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}

	// Each hook closes its channel. After a channel is observed it is set to
	// nil, so that select case can no longer fire.
	beforeCh := make(chan struct{})
	initiatedCh := make(chan struct{})

	wg.Add(2)
	graceful := &Server{
		Server:            server,
		BeforeShutdown:    func() bool { close(beforeCh); return true },
		ShutdownInitiated: func() { close(initiatedCh) },
	}
	go func() {
		defer wg.Done()
		graceful.Serve(l)
	}()
	go func() {
		defer wg.Done()
		time.Sleep(waitTime)
		graceful.Stop(killTime)
	}()

	var sawBefore, sawInitiated bool
	for round := 0; round < 2; round++ {
		select {
		case <-beforeCh:
			beforeCh, sawBefore = nil, true
		case <-initiatedCh:
			initiatedCh, sawInitiated = nil, true
		case <-time.After(killTime):
			t.Fatal("Timed out while waiting for ShutdownInitiated callback to be called")
		}
	}

	if !sawBefore {
		t.Fatal("beforeShutdown should be true")
	}
	if !sawInitiated {
		t.Fatal("shutdownInitiated should be true")
	}
}
// TestBeforeShutdownCanceled installs a BeforeShutdown hook that returns
// false and calls Stop. The hook must run, ShutdownInitiated must not fire,
// and the wait must time out. The hook is then replaced by one returning
// true, and the server is stopped for real.
func TestBeforeShutdownCanceled(t *testing.T) {
	var wg sync.WaitGroup
	wg.Add(1)
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}

	// Each hook closes its channel. After a channel is observed it is set to
	// nil, so that select case can no longer fire.
	beforeCh := make(chan struct{})
	initiatedCh := make(chan struct{})
	graceful := &Server{
		Server:            server,
		BeforeShutdown:    func() bool { close(beforeCh); return false },
		ShutdownInitiated: func() { close(initiatedCh) },
	}
	go func() {
		graceful.Serve(l)
		wg.Done()
	}()
	go func() {
		time.Sleep(waitTime)
		graceful.Stop(killTime)
	}()

	var sawBefore, sawInitiated, sawTimeout bool
	for round := 0; round < 2; round++ {
		select {
		case <-beforeCh:
			beforeCh, sawBefore = nil, true
		case <-initiatedCh:
			initiatedCh, sawInitiated = nil, true
		case <-time.After(killTime):
			sawTimeout = true
		}
	}

	switch {
	case !sawBefore:
		t.Fatal("beforeShutdown should be true")
	case !sawTimeout:
		t.Fatal("timeouted should be true")
	case sawInitiated:
		t.Fatal("shutdownInitiated shouldn't be true")
	}

	graceful.BeforeShutdown = func() bool { return true }
	graceful.Stop(killTime)
	wg.Wait()
}
// hijackingListener returns an http.Server whose handler hijacks each
// connection, writes a bare "HTTP/1.1 200 OK" status line itself and closes
// the connection. It also returns a TCP listener bound to the test port and
// any error from net.Listen.
//
// srv is not used by this function (presumably kept so callers can pass the
// graceful server being configured).
func hijackingListener(srv *Server) (*http.Server, net.Listener, error) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
		// Comma-ok: not every ResponseWriter implements http.Hijacker, and
		// an unchecked assertion would panic inside the handler.
		hj, ok := rw.(http.Hijacker)
		if !ok {
			http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError)
			return
		}
		conn, bufrw, err := hj.Hijack()
		if err != nil {
			http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError)
			return
		}
		defer conn.Close()
		bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
		bufrw.Flush()
	})
	server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux}
	l, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	return server, l, err
}
// TestNotifyClosed serves through a handler that hijacks every connection.
// It issues concurrentRequestN sequential queries, stops the server, and
// checks that none of the hijacked connections remain in srv.connections.
func TestNotifyClosed(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	sig := make(chan os.Signal, 1)
	graceful := &Server{Timeout: killTime, interrupt: sig}
	server, l, err := hijackingListener(graceful)
	if err != nil {
		t.Fatal(err)
	}
	graceful.Server = server

	wg.Add(1)
	go func() {
		defer wg.Done()
		graceful.Serve(l)
	}()

	// Queries run one after another on this goroutine; each expects 200 OK.
	var once sync.Once
	for remaining := concurrentRequestN; remaining > 0; remaining-- {
		wg.Add(1)
		runQuery(t, http.StatusOK, false, &wg, &once)
	}

	graceful.Stop(0)

	// Wait on StopChan until the server has shut down, or give up.
	stopped, deadline := graceful.StopChan(), time.After(timeoutTime)
	select {
	case <-stopped:
	case <-deadline:
		t.Fatal("Timed out while waiting for explicit stop to complete")
	}

	if len(graceful.connections) > 0 {
		t.Fatal("hijacked connections should not be managed")
	}
}
// TestStopDeadlock calls Stop(0) on a server whose Serve call is delayed by
// waitTime, so Stop is likely to run first. Stop must return within
// timeoutTime; the listener is then closed.
func TestStopDeadlock(t *testing.T) {
	var wg sync.WaitGroup
	defer wg.Wait()

	stopped := make(chan struct{})
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}
	graceful := &Server{Server: server, NoSignalHandling: true}

	wg.Add(2)
	go func() {
		defer wg.Done()
		time.Sleep(waitTime)
		graceful.Serve(l)
	}()
	go func() {
		defer wg.Done()
		graceful.Stop(0)
		close(stopped)
	}()

	deadline := time.After(timeoutTime)
	select {
	case <-stopped:
		l.Close()
	case <-deadline:
		t.Fatal("Timed out while waiting for explicit stop to complete")
	}
}
// TestStopRace calls Stop concurrently from two goroutines while Serve is
// starting, then waits for StopChan. It is meant to be run under the race
// detector (go test -race).
func TestStopRace(t *testing.T) {
	server, l, err := createListener(1 * time.Millisecond)
	if err != nil {
		t.Fatal(err)
	}
	graceful := &Server{Timeout: killTime, Server: server}
	go func() {
		go graceful.Serve(l)
		graceful.Stop(killTime)
	}()
	graceful.Stop(0)

	stopped, deadline := graceful.StopChan(), time.After(timeoutTime)
	select {
	case <-stopped:
	case <-deadline:
		t.Fatal("Timed out while waiting for explicit stop to complete")
	}
}
// TestInterruptLog checks that a single interrupt produces exactly the log
// output "shutdown initiated" (as written by a prefix-less, flag-less
// log.Logger) on the server's Logger.
func TestInterruptLog(t *testing.T) {
	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 10)
	if err != nil {
		t.Fatal(err)
	}

	var got, want bytes.Buffer
	graceful := &Server{Timeout: killTime, Server: server, Logger: log.New(&got, "", 0), interrupt: sig}
	go func() { graceful.Serve(l) }()
	stop := graceful.StopChan()

	sig <- os.Interrupt
	log.New(&want, "", 0).Print("shutdown initiated")
	<-stop

	if got.String() != want.String() {
		t.Fatal("shutdown log incorrect - got '" + got.String() + "'")
	}
}
// TestMultiInterrupts sends one interrupt to start shutdown and ten more
// while shutdown is in progress. The server's log must be exactly one
// "shutdown initiated" line followed by ten "already shutting down" lines.
// SyncBuffer lets the test wait until all 11 log writes have happened.
func TestMultiInterrupts(t *testing.T) {
	c := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 10)
	if err != nil {
		t.Fatal(err)
	}
	var wg sync.WaitGroup
	var bu bytes.Buffer
	buf := SyncBuffer{&wg, &bu}
	var tbuf bytes.Buffer
	logger := log.New(&buf, "", 0)
	expected := log.New(&tbuf, "", 0)
	srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c}
	go func() { srv.Serve(l) }()
	stop := srv.StopChan()
	buf.Add(1 + 10) // Expecting 11 log calls
	c <- os.Interrupt
	expected.Printf("shutdown initiated")
	for i := 0; i < 10; i++ {
		c <- os.Interrupt
		expected.Printf("already shutting down")
	}
	<-stop
	wg.Wait()
	// Compare the complete logs. A byte-by-byte loop over the actual output
	// would index past the end of the expected log when extra text was
	// written, and would miss expected bytes that were never written.
	if !bytes.Equal(buf.Bytes(), tbuf.Bytes()) {
		t.Fatalf("shutdown log incorrect - got '%s', expected '%s'", buf.String(), tbuf.String())
	}
}
// TestLogFunc checks that a server configured with LogFunc routes at least
// one log call through it during an interrupt-triggered shutdown.
func TestLogFunc(t *testing.T) {
	sig := make(chan os.Signal, 1)
	server, l, err := createListener(killTime * 10)
	if err != nil {
		t.Fatal(err)
	}

	called := false
	logFunc := func(format string, args ...interface{}) {
		called = true
	}
	graceful := &Server{Timeout: killTime, Server: server, LogFunc: logFunc, interrupt: sig}
	stop := graceful.StopChan()
	go func() { graceful.Serve(l) }()

	sig <- os.Interrupt
	<-stop

	if !called {
		t.Fatal("Expected LogFunc to be called.")
	}
}
// SyncBuffer is a bytes.Buffer paired with a sync.WaitGroup. Every Write
// marks one unit of the wait group done, so a caller that has Add-ed N can
// Wait for N writes to complete.
type SyncBuffer struct {
	*sync.WaitGroup
	*bytes.Buffer
}

// Write appends b to the embedded buffer and returns its result. The wait
// group's Done is deferred, so it runs when Write returns.
func (sb *SyncBuffer) Write(b []byte) (int, error) {
	defer sb.WaitGroup.Done()
	n, err := sb.Buffer.Write(b)
	return n, err
}