Support WebSocket connections

This commit is contained in:
Trevor Slocum 2023-09-11 22:10:04 -07:00
parent 74f3446111
commit 5ecbc5d583
6 changed files with 219 additions and 16 deletions

View file

@ -0,0 +1,134 @@
package main
import (
"log"
"net"
"net/http"
"sync"
"time"
"code.rocket9labs.com/tslocum/bgammon"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
var _ bgammon.Client = &webSocketClient{}
type webSocketClient struct {
conn net.Conn
events chan []byte
commands chan<- []byte
terminated bool
wgEvents sync.WaitGroup
}
func newWebSocketClient(r *http.Request, w http.ResponseWriter, commands chan<- []byte, events chan []byte) *webSocketClient {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
return nil
}
c := &webSocketClient{
conn: conn,
events: events,
commands: commands,
}
return c
}
func (c *webSocketClient) HandleReadWrite() {
if c.terminated {
return
}
go c.writeEvents()
c.readCommands()
}
func (c *webSocketClient) Write(message []byte) {
if c.terminated {
return
}
c.wgEvents.Add(1)
c.events <- message
}
func (c *webSocketClient) readCommands() {
setTimeout := func() {
err := c.conn.SetReadDeadline(time.Now().Add(clientTimeout))
if err != nil {
c.Terminate(err.Error())
return
}
}
setTimeout()
for {
if c.terminated {
continue // TODO wait group
}
msg, op, err := wsutil.ReadClientData(c.conn)
if err != nil {
c.Terminate(err.Error())
return
} else if op != ws.OpText {
continue
}
buf := make([]byte, len(msg))
copy(buf, msg)
c.commands <- buf
log.Printf("<- %s", msg)
setTimeout()
}
}
func (c *webSocketClient) writeEvents() {
setTimeout := func() {
err := c.conn.SetWriteDeadline(time.Now().Add(clientTimeout))
if err != nil {
c.Terminate(err.Error())
return
}
}
var event []byte
for event = range c.events {
if c.terminated {
c.wgEvents.Done()
continue
}
setTimeout()
err := wsutil.WriteServerMessage(c.conn, ws.OpText, event)
if err != nil {
c.Terminate(err.Error())
c.wgEvents.Done()
continue
}
log.Printf("-> %s", event)
c.wgEvents.Done()
}
}
func (c *webSocketClient) Terminate(reason string) {
if c.terminated {
return
}
c.terminated = true
c.conn.Close()
go func() {
time.Sleep(5 * time.Second)
c.wgEvents.Wait()
close(c.events)
close(c.commands)
}()
}
func (c *webSocketClient) Terminated() bool {
return c.terminated
}

View file

@ -9,12 +9,20 @@ import (
)
func main() {
var tcpAddress string
var debug int
var (
tcpAddress string
wsAddress string
debug int
)
flag.StringVar(&tcpAddress, "tcp", "localhost:1337", "TCP listen address")
flag.StringVar(&wsAddress, "ws", "localhost:1338", "WebSocket listen address")
flag.IntVar(&debug, "debug", 0, "print debug information and serve pprof on specified port")
flag.Parse()
if tcpAddress == "" && wsAddress == "" {
log.Fatal("Error: A TCP and/or WebSocket listen address must be specified.")
}
if debug > 0 {
go func() {
log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%d", debug), nil))
@ -22,6 +30,11 @@ func main() {
}
s := newServer()
s.listen("tcp", tcpAddress)
if tcpAddress != "" {
s.listen("tcp", tcpAddress)
}
if wsAddress != "" {
s.listen("ws", wsAddress)
}
select {}
}

View file

@ -6,6 +6,7 @@ import (
"log"
"math/rand"
"net"
"net/http"
"regexp"
"strconv"
"strings"
@ -50,7 +51,41 @@ func newServer() *server {
return s
}
func (s *server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
const bufferSize = 8
commands := make(chan []byte, bufferSize)
events := make(chan []byte, bufferSize)
wsClient := newWebSocketClient(r, w, commands, events)
if wsClient == nil {
return
}
now := time.Now().Unix()
c := &serverClient{
id: <-s.newClientIDs,
account: -1,
connected: now,
lastActive: now,
commands: commands,
Client: wsClient,
}
s.handleClient(c)
}
func (s *server) listenWebSocket(address string) {
log.Printf("Listening for WebSocket connections on %s...", address)
err := http.ListenAndServe(address, http.HandlerFunc(s.handleWebSocket))
log.Fatalf("failed to listen on %s: %s", address, err)
}
func (s *server) listen(network string, address string) {
if strings.ToLower(network) == "ws" {
go s.listenWebSocket(address)
return
}
log.Printf("Listening for %s connections on %s...", strings.ToUpper(network), address)
listener, err := net.Listen(network, address)
if err != nil {
@ -128,6 +163,22 @@ func (s *server) handleTerminatedGames() {
}
}
func (s *server) handleClient(c *serverClient) {
s.addClient(c)
log.Printf("Client %s connected", c.label())
go s.handlePingClient(c)
go s.handleClientCommands(c)
c.HandleReadWrite()
// Remove client.
s.removeClient(c)
log.Printf("Client %s disconnected", c.label())
}
func (s *server) handleConnection(conn net.Conn) {
const bufferSize = 8
commands := make(chan []byte, bufferSize)
@ -144,19 +195,7 @@ func (s *server) handleConnection(conn net.Conn) {
Client: newSocketClient(conn, commands, events),
}
s.sendHello(c)
s.addClient(c)
log.Printf("Client %s connected", c.label())
go s.handlePingClient(c)
go s.handleClientCommands(c)
c.HandleReadWrite()
// Remove client.
s.removeClient(c)
log.Printf("Client %s disconnected", c.label())
s.handleClient(c)
}
func (s *server) handlePingClient(c *serverClient) {

8
go.mod
View file

@ -1,3 +1,11 @@
module code.rocket9labs.com/tslocum/bgammon
go 1.20
require github.com/gobwas/ws v1.3.0
require (
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
golang.org/x/sys v0.12.0 // indirect
)

9
go.sum Normal file
View file

@ -0,0 +1,9 @@
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0=
github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=