From 65311696cd871f0fe6651af443ec555e2004cac0 Mon Sep 17 00:00:00 2001 From: Trevor Slocum Date: Thu, 5 Sep 2024 12:42:21 -0700 Subject: [PATCH] Serve secure websocket connections directly --- cmd/bgammon-server/main.go | 7 ++++++- pkg/server/client_websocket.go | 7 +------ pkg/server/server.go | 11 ++++++++++- pkg/server/server_full.go | 25 ++++++++++++++++++++++++- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/cmd/bgammon-server/main.go b/cmd/bgammon-server/main.go index 74b3224..69f6733 100644 --- a/cmd/bgammon-server/main.go +++ b/cmd/bgammon-server/main.go @@ -51,6 +51,11 @@ func main() { resetSalt = os.Getenv("BGAMMON_SALT_RESET") ipSalt = os.Getenv("BGAMMON_SALT_IP") + certDomain := os.Getenv("BGAMMON_CERT_DOMAIN") + certFolder := os.Getenv("BGAMMON_CERT_FOLDER") + certEmail := os.Getenv("BGAMMON_CERT_EMAIL") + certAddress := os.Getenv("BGAMMON_CERT_ADDRESS") + if rollStatistics { printRollStatistics() return @@ -66,7 +71,7 @@ func main() { }() } - s := server.NewServer(tz, dataSource, mailServer, passwordSalt, resetSalt, ipSalt, false, verbose || debug > 0, debugCommands) + s := server.NewServer(tz, dataSource, mailServer, passwordSalt, resetSalt, ipSalt, certDomain, certFolder, certEmail, certAddress, false, verbose || debug > 0, debugCommands) if tcpAddress != "" { s.Listen("tcp", tcpAddress) } diff --git a/pkg/server/client_websocket.go b/pkg/server/client_websocket.go index e838b29..d58f3f6 100644 --- a/pkg/server/client_websocket.go +++ b/pkg/server/client_websocket.go @@ -35,14 +35,9 @@ func newWebSocketClient(r *http.Request, w http.ResponseWriter, commands chan<- return nil } - address := r.Header.Get("X-Forwarded-For") - if address == "" { - address = r.RemoteAddr - } - return &webSocketClient{ conn: conn, - address: address, + address: r.RemoteAddr, events: events, commands: commands, verbose: verbose, diff --git a/pkg/server/server.go b/pkg/server/server.go index b048c5c..72ebad9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -94,6 +94,11 @@ type server struct { languageTags []language.Tag languageNames [][]byte + certDomain string + certFolder string + certEmail string + certAddress string + relayChat bool // Chats are not relayed normally. This option is only used by local servers. verbose bool @@ -101,7 +106,7 @@ type server struct { shutdownReason string } -func NewServer(tz string, dataSource string, mailServer string, passwordSalt string, resetSalt string, ipSalt string, relayChat bool, verbose bool, allowDebug bool) *server { +func NewServer(tz string, dataSource string, mailServer string, passwordSalt string, resetSalt string, ipSalt string, certDomain string, certFolder string, certEmail string, certAddress string, relayChat bool, verbose bool, allowDebug bool) *server { const bufferSize = 10 s := &server{ newGameIDs: make(chan int), @@ -113,6 +118,10 @@ func NewServer(tz string, dataSource string, mailServer string, passwordSalt str passwordSalt: passwordSalt, resetSalt: resetSalt, ipSalt: ipSalt, + certDomain: certDomain, + certFolder: certFolder, + certEmail: certEmail, + certAddress: certAddress, relayChat: relayChat, verbose: verbose, } diff --git a/pkg/server/server_full.go b/pkg/server/server_full.go index b749dee..85e6a72 100644 --- a/pkg/server/server_full.go +++ b/pkg/server/server_full.go @@ -3,6 +3,7 @@ package server import ( + "crypto/tls" "encoding/json" "fmt" "log" @@ -14,6 +15,7 @@ import ( "code.rocket9labs.com/tslocum/bgammon" "github.com/gorilla/mux" + "golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/sha3" ) @@ -76,7 +78,28 @@ func (s *server) listenWebSocket(address string) { m.HandleFunc("/stats/{username:[A-Za-z0-9_\\-]+}/tabula.json", s.handleAccountStatsFunc(matchTypeCasual, bgammon.VariantTabula)) m.HandleFunc("/", s.handleWebSocket) - err := http.ListenAndServe(address, m) + certManager := autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(s.certFolder), + HostPolicy: autocert.HostWhitelist(s.certDomain), + Email: s.certEmail, + } + + server := &http.Server{ + Addr: address, + Handler: m, + TLSConfig: &tls.Config{ + GetCertificate: certManager.GetCertificate, + MinVersion: tls.VersionTLS12, + }, + } + + go func() { + err := http.ListenAndServe(s.certAddress, certManager.HTTPHandler(m)) + log.Fatalf("failed to listen on %s: %s", s.certAddress, err) + }() + + err := server.ListenAndServeTLS("", "") log.Fatalf("failed to listen on %s: %s", address, err) }