Clear write queue on client disconnect

This commit is contained in:
Trevor Slocum 2023-10-26 13:07:22 -07:00
parent e01ccd8e3c
commit e62a9053a3
4 changed files with 50 additions and 15 deletions

View file

@ -34,8 +34,12 @@ func (c *socketClient) HandleReadWrite() {
return
}
go c.writeEvents()
closeWrite := make(chan struct{}, 1)
go c.writeEvents(closeWrite)
c.readCommands()
closeWrite <- struct{}{}
}
func (c *socketClient) Write(message []byte) {
@ -60,7 +64,7 @@ func (c *socketClient) readCommands() {
var scanner = bufio.NewScanner(c.conn)
for scanner.Scan() {
if c.terminated {
continue // TODO wait group
return
}
if scanner.Err() != nil {
@ -73,11 +77,12 @@ func (c *socketClient) readCommands() {
c.commands <- buf
logClientRead(scanner.Bytes())
setTimeout()
}
}
func (c *socketClient) writeEvents() {
func (c *socketClient) writeEvents(closeWrite chan struct{}) {
setTimeout := func() {
err := c.conn.SetWriteDeadline(time.Now().Add(clientTimeout))
if err != nil {
@ -88,13 +93,26 @@ func (c *socketClient) writeEvents() {
setTimeout()
var event []byte
for event = range c.events {
for {
select {
case <-closeWrite:
for {
select {
case <-c.events:
c.wgEvents.Done()
default:
return
}
}
case event = <-c.events:
}
if c.terminated {
c.wgEvents.Done()
continue
}
setTimeout()
setTimeout()
_, err := c.conn.Write(append(event, '\n'))
if err != nil {
c.Terminate(err.Error())
@ -115,8 +133,8 @@ func (c *socketClient) Terminate(reason string) {
}
c.terminated = true
c.conn.Close()
go func() {
time.Sleep(5 * time.Second)
c.wgEvents.Wait()
close(c.events)
close(c.commands)

View file

@ -41,8 +41,12 @@ func (c *webSocketClient) HandleReadWrite() {
return
}
go c.writeEvents()
closeWrite := make(chan struct{}, 1)
go c.writeEvents(closeWrite)
c.readCommands()
closeWrite <- struct{}{}
}
func (c *webSocketClient) Write(message []byte) {
@ -63,12 +67,12 @@ func (c *webSocketClient) readCommands() {
}
}
setTimeout()
for {
if c.terminated {
continue // TODO wait group
return
}
setTimeout()
msg, op, err := wsutil.ReadClientData(c.conn)
if err != nil {
c.Terminate(err.Error())
@ -82,11 +86,10 @@ func (c *webSocketClient) readCommands() {
c.commands <- buf
logClientRead(msg)
setTimeout()
}
}
func (c *webSocketClient) writeEvents() {
func (c *webSocketClient) writeEvents(closeWrite chan struct{}) {
setTimeout := func() {
err := c.conn.SetWriteDeadline(time.Now().Add(clientTimeout))
if err != nil {
@ -98,12 +101,25 @@ func (c *webSocketClient) writeEvents() {
setTimeout()
var event []byte
for event = range c.events {
select {
case <-closeWrite:
for {
select {
case <-c.events:
c.wgEvents.Done()
default:
return
}
}
case event = <-c.events:
}
if c.terminated {
c.wgEvents.Done()
continue
}
setTimeout()
setTimeout()
err := wsutil.WriteServerMessage(c.conn, ws.OpText, event)
if err != nil {
c.Terminate(err.Error())
@ -124,6 +140,7 @@ func (c *webSocketClient) Terminate(reason string) {
}
c.terminated = true
c.conn.Close()
go func() {
time.Sleep(5 * time.Second)
c.wgEvents.Wait()

2
go.mod
View file

@ -7,5 +7,5 @@ require github.com/gobwas/ws v1.3.0
require (
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/sys v0.13.0 // indirect
)

4
go.sum
View file

@ -5,5 +5,5 @@ github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm
github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0=
github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=