/* * * Copyright 2016, Google Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ package credentials import ( "crypto/tls" "net" "testing" "golang.org/x/net/context" ) func TestTLSOverrideServerName(t *testing.T) { expectedServerName := "server.name" c := NewTLS(nil) c.OverrideServerName(expectedServerName) if c.Info().ServerName != expectedServerName { t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) } } func TestTLSClone(t *testing.T) { expectedServerName := "server.name" c := NewTLS(nil) c.OverrideServerName(expectedServerName) cc := c.Clone() if cc.Info().ServerName != expectedServerName { t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) } cc.OverrideServerName("") if c.Info().ServerName != expectedServerName { t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) } } const tlsDir = "../test/testdata/" type serverHandshake func(net.Conn) (AuthInfo, error) func TestClientHandshakeReturnsAuthInfo(t *testing.T) { done := make(chan AuthInfo, 1) lis := launchServer(t, tlsServerHandshake, done) defer lis.Close() lisAddr := lis.Addr().String() clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) // wait until server sends serverAuthInfo or fails. serverAuthInfo, ok := <-done if !ok { t.Fatalf("Error at server-side") } if !compare(clientAuthInfo, serverAuthInfo) { t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) } } func TestServerHandshakeReturnsAuthInfo(t *testing.T) { done := make(chan AuthInfo, 1) lis := launchServer(t, gRPCServerHandshake, done) defer lis.Close() clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) // wait until server sends serverAuthInfo or fails. serverAuthInfo, ok := <-done if !ok { t.Fatalf("Error at server-side") } if !compare(clientAuthInfo, serverAuthInfo) { t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) } } func TestServerAndClientHandshake(t *testing.T) { done := make(chan AuthInfo, 1) lis := launchServer(t, gRPCServerHandshake, done) defer lis.Close() clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) // wait until server sends serverAuthInfo or fails. serverAuthInfo, ok := <-done if !ok { t.Fatalf("Error at server-side") } if !compare(clientAuthInfo, serverAuthInfo) { t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) } } func compare(a1, a2 AuthInfo) bool { if a1.AuthType() != a2.AuthType() { return false } switch a1.AuthType() { case "tls": state1 := a1.(TLSInfo).State state2 := a2.(TLSInfo).State if state1.Version == state2.Version && state1.HandshakeComplete == state2.HandshakeComplete && state1.CipherSuite == state2.CipherSuite && state1.NegotiatedProtocol == state2.NegotiatedProtocol { return true } return false default: return false } } func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen: %v", err) } go serverHandle(t, hs, done, lis) return lis } // Is run in a seperate goroutine. func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { serverRawConn, err := lis.Accept() if err != nil { t.Errorf("Server failed to accept connection: %v", err) close(done) return } serverAuthInfo, err := hs(serverRawConn) if err != nil { t.Errorf("Server failed while handshake. Error: %v", err) serverRawConn.Close() close(done) return } done <- serverAuthInfo } func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { conn, err := net.Dial("tcp", lisAddr) if err != nil { t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) } defer conn.Close() clientAuthInfo, err := hs(conn, lisAddr) if err != nil { t.Fatalf("Error on client while handshake. Error: %v", err) } return clientAuthInfo } // Server handshake implementation in gRPC. func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { return nil, err } _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) if err != nil { return nil, err } return serverAuthInfo, nil } // Client handshake implementation in gRPC. func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) if err != nil { return nil, err } return authInfo, nil } func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { return nil, err } serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} serverConn := tls.Server(conn, serverTLSConfig) err = serverConn.Handshake() if err != nil { return nil, err } return TLSInfo{State: serverConn.ConnectionState()}, nil } func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { clientTLSConfig := &tls.Config{InsecureSkipVerify: true} clientConn := tls.Client(conn, clientTLSConfig) if err := clientConn.Handshake(); err != nil { return nil, err } return TLSInfo{State: clientConn.ConnectionState()}, nil }