Add mutual TLS support

This commit is contained in:
Douglas De Toni Machado
2021-06-24 11:54:12 -03:00
committed by GitHub
parent 6516f0ace6
commit 9baca36b2c

34
app.go
View File

@ -2,10 +2,13 @@ package main
import ( import (
"bytes" "bytes"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -31,6 +34,7 @@ const (
var ( var (
cert string cert string
key string key string
ca string
port string port string
name string name string
) )
@ -38,6 +42,7 @@ var (
func init() { func init() {
flag.StringVar(&cert, "cert", "", "give me a certificate") flag.StringVar(&cert, "cert", "", "give me a certificate")
flag.StringVar(&key, "key", "", "give me a key") flag.StringVar(&key, "key", "", "give me a key")
flag.StringVar(&ca, "cacert", "", "give me a CA chain, enforces mutual TLS")
flag.StringVar(&port, "port", "80", "give me a port number") flag.StringVar(&port, "port", "80", "give me a port number")
flag.StringVar(&name, "name", os.Getenv("WHOAMI_NAME"), "give me a name") flag.StringVar(&name, "name", os.Getenv("WHOAMI_NAME"), "give me a name")
} }
@ -60,11 +65,38 @@ func main() {
fmt.Println("Starting up on port " + port) fmt.Println("Starting up on port " + port)
if len(cert) > 0 && len(key) > 0 { if len(cert) > 0 && len(key) > 0 {
log.Fatal(http.ListenAndServeTLS(":"+port, cert, key, nil)) server := &http.Server{
Addr: ":" + port,
}
if len(ca) > 0 {
server.TLSConfig = setupMutualTLS(ca)
}
log.Fatal(server.ListenAndServeTLS(cert, key))
} }
log.Fatal(http.ListenAndServe(":"+port, nil)) log.Fatal(http.ListenAndServe(":"+port, nil))
} }
func setupMutualTLS(ca string) *tls.Config {
clientCACert, err := ioutil.ReadFile(ca)
if err != nil {
log.Fatal(err)
}
clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
PreferServerCipherSuites: true,
MinVersion: tls.VersionTLS12,
}
return tlsConfig
}
func benchHandler(w http.ResponseWriter, _ *http.Request) { func benchHandler(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")