From c92554fe6269bc20b149d1706acbd29b833c8084 Mon Sep 17 00:00:00 2001 From: Trevor Slocum Date: Sat, 15 Apr 2017 14:31:16 -0700 Subject: [PATCH] Potentially resolve concurrent read/write crashes --- anonircd.go | 3 +- channel.go | 4 +- entity.go | 30 +++--- server.go | 284 ++++++++++++++++++++++++++++++++-------------------- 4 files changed, 197 insertions(+), 124 deletions(-) diff --git a/anonircd.go b/anonircd.go index 5b3d366..4942a6b 100644 --- a/anonircd.go +++ b/anonircd.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "github.com/orcaman/concurrent-map" irc "gopkg.in/sorcix/irc.v2" "os" "os/signal" @@ -80,7 +81,7 @@ func randomIdentifier() string { func main() { rand.Seed(time.Now().UTC().UnixNano()) - server := Server{&Config{}, time.Now().Unix(), make(map[string]*Client), make(map[string]*Channel), make(chan bool, 1), make(chan bool, 1), new(sync.RWMutex)} + server := Server{&Config{}, time.Now().Unix(), cmap.New(), cmap.New(), make(chan bool, 1), make(chan bool, 1), new(sync.RWMutex)} server.loadConfig() sighup := make(chan os.Signal, 1) diff --git a/channel.go b/channel.go index 87bc29b..62415df 100644 --- a/channel.go +++ b/channel.go @@ -1,9 +1,11 @@ package main +import "github.com/orcaman/concurrent-map" + type Channel struct { Entity - clients map[string]int + clients cmap.ConcurrentMap topic string topictime int64 diff --git a/entity.go b/entity.go index 434c872..55fc07b 100644 --- a/entity.go +++ b/entity.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/orcaman/concurrent-map" "strings" "sync" ) @@ -17,17 +18,21 @@ type Entity struct { entitytype int identifier string created int64 - modes map[string]string + modes cmap.ConcurrentMap *sync.RWMutex } -func (e *Entity) hasMode(mode string) bool { - if _, ok := e.modes[mode]; ok { - return true +func (e *Entity) getModes() map[string]string { + modes := make(map[string]string) + for ms := range e.modes.IterBuffered() { + modes[ms.Key] = ms.Val.(string) } + return modes +} - return false +func (e *Entity) hasMode(mode string) bool { + return e.modes.Has(mode) } func (e *Entity) addMode(mode string, value string) { @@ -39,7 +44,7 @@ func (e *Entity) addMode(mode string, value string) { } if strings.Index(allowedmodes, mode) != -1 && !e.hasMode(mode) { - e.modes[mode] = value + e.modes.Set(mode, value) } } @@ -51,7 +56,7 @@ func (e *Entity) addModes(modes string) { func (e *Entity) removeMode(mode string) { if e.hasMode(mode) { - delete(e.modes, mode) + e.modes.Remove(mode) } } @@ -64,17 +69,18 @@ func (e *Entity) removeModes(modes string) { func (e *Entity) diffModes(lastmodes map[string]string) (map[string]string, map[string]string) { addedmodes := make(map[string]string) if lastmodes != nil { - for mode := range e.modes { - if _, ok := lastmodes[mode]; !ok { - addedmodes[mode] = lastmodes[mode] + for m := range e.modes.IterBuffered() { + if _, ok := lastmodes[m.Key]; !ok { + addedmodes[m.Key] = lastmodes[m.Key] } } } removedmodes := make(map[string]string) for mode := range lastmodes { - if _, ok := e.modes[mode]; !ok { - removedmodes[mode] = e.modes[mode] + if e.hasMode(mode) { + m, _ := e.modes.Get(mode) + removedmodes[mode] = m.(string) } } diff --git a/server.go b/server.go index 641cf12..143f75b 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "github.com/BurntSushi/toml" + cmap "github.com/orcaman/concurrent-map" irc "gopkg.in/sorcix/irc.v2" "math/rand" "os" @@ -25,8 +26,8 @@ type Config struct { type Server struct { config *Config created int64 - clients map[string]*Client - channels map[string]*Channel + clients cmap.ConcurrentMap + channels cmap.ConcurrentMap restartplain chan bool restartssl chan bool @@ -42,19 +43,27 @@ func (s *Server) getAnonymousPrefix(i int) *irc.Prefix { return &prefix } +func (s *Server) getChannel(channel string) *Channel { + if ch, ok := s.channels.Get(channel); ok { + return ch.(*Channel) + } + + return nil +} + func (s *Server) getChannels(client string) map[string]*Channel { channels := make(map[string]*Channel) - for channelname, channel := range s.channels { - if s.inChannel(channelname, client) { - channels[channelname] = channel + for chs := range s.channels.IterBuffered() { + if s.inChannel(chs.Key, client) { + channels[chs.Key] = chs.Val.(*Channel) } } return channels } func (s *Server) getClient(client string) *Client { - if _, ok := s.clients[client]; ok { - return s.clients[client] + if cl, ok := s.clients.Get(client); ok { + return cl.(*Client) } return nil @@ -62,22 +71,18 @@ func (s *Server) getClient(client string) *Client { func (s *Server) getClients(channel string) map[string]*Client { clients := make(map[string]*Client) - if !s.channelExists(channel) { - return clients - } - for clientname := range s.channels[channel].clients { - cl := s.getClient(clientname) - if cl != nil { - clients[clientname] = cl - } + ch := s.getChannel(channel) + + for cls := range ch.clients.IterBuffered() { + clients[cls.Key] = cls.Val.(*Client) } return clients } func (s *Server) channelExists(channel string) bool { - if _, ok := s.channels[channel]; ok { + if _, ok := s.channels.Get(channel); ok { return true } @@ -85,10 +90,9 @@ func (s *Server) channelExists(channel string) bool { } func (s *Server) inChannel(channel string, client string) bool { - if s.channelExists(channel) { - if _, ok := s.channels[channel].clients[client]; ok { - return true - } + ch := s.getChannel(channel) + if ch != nil { + return ch.clients.Has(client) } return false @@ -99,35 +103,41 @@ func (s *Server) joinChannel(channel string, client string) { return // Already in channel } - if !s.channelExists(channel) { - s.channels[channel] = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), make(map[string]string), new(sync.RWMutex)}, make(map[string]int), "", 0} - } else if s.channels[channel].hasMode("z") && !s.clients[client].ssl { - s.clients[client].sendNotice("Unable to join " + channel + ": SSL connections only (channel mode +z)") + ch := s.getChannel(channel) + cl := s.getClient(client) + + if cl == nil { return } - s.channels[channel].Lock() - s.channels[channel].clients[client] = s.getClientCount(channel, client) - s.channels[channel].Unlock() - s.clients[client].write(&irc.Message{s.clients[client].getPrefix(), irc.JOIN, []string{channel}}) + if ch == nil { + ch = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), cmap.New(), new(sync.RWMutex)}, cmap.New(), "", 0} + s.channels.Set(channel, ch) + } else if ch.hasMode("z") && !cl.ssl { + cl.sendNotice("Unable to join " + channel + ": SSL connections only (channel mode +z)") + return + } + + ch.clients.Set(client, s.getClientCount(channel, client)+1) + cl.write(&irc.Message{cl.getPrefix(), irc.JOIN, []string{channel}}) s.sendNames(channel, client) - s.updateClientCount(channel, "") + s.updateClientCount(channel) s.sendTopic(channel, client, false) } func (s *Server) partChannel(channel string, client string, reason string) { - if !s.inChannel(channel, client) { + ch := s.getChannel(channel) + cl := s.getClient(client) + + if cl == nil || !s.inChannel(channel, client) { return } - s.clients[client].write(&irc.Message{s.clients[client].getPrefix(), irc.PART, []string{channel, reason}}) + cl.write(&irc.Message{cl.getPrefix(), irc.PART, []string{channel, reason}}) + ch.clients.Remove(client) - s.channels[channel].Lock() - delete(s.channels[channel].clients, client) - s.channels[channel].Unlock() - - s.updateClientCount(channel, "") + s.updateClientCount(channel) } func (s *Server) partAllChannels(client string) { @@ -137,9 +147,11 @@ func (s *Server) partAllChannels(client string) { } func (s *Server) enforceModes(channel string) { - if s.channels[channel].hasMode("z") { - for client := range s.getClients(channel) { - if !s.clients[client].ssl { + ch := s.getChannel(channel) + + if ch != nil && ch.hasMode("z") { + for client, cl := range s.getClients(channel) { + if !cl.ssl { s.partChannel(channel, client, "Only SSL connections are allowed in this channel") } } @@ -147,42 +159,55 @@ func (s *Server) enforceModes(channel string) { } func (s *Server) getClientCount(channel string, client string) int { - ccount := len(s.channels[channel].clients) + ch := s.getChannel(channel) + cl := s.getClient(client) - if (s.clients[client].hasMode("c") || s.channels[channel].hasMode("c")) && ccount >= 2 { + if ch == nil || cl == nil { + return 0 + } + + ccount := ch.clients.Count() + + if (ch.hasMode("c") || cl.hasMode("c")) && ccount >= 2 { return 2 } return ccount } -func (s *Server) updateClientCount(channel string, client string) { - clients := make(map[string]int) - if client != "" { - clients[client] = s.channels[channel].clients[client] - } else { - clients = s.channels[channel].clients +func (s *Server) updateClientCount(channel string) { + ch := s.getChannel(channel) + + if ch == nil { + return } - for cclient, ccount := range clients { + + for cls := range ch.clients.IterBuffered() { + cclient := cls.Key + ccount := cls.Val.(int) + chancount := s.getClientCount(channel, cclient) + cl := s.getClient(cclient) + + if cl == nil { + continue + } if ccount < chancount { - s.channels[channel].Lock() - for i := ccount; i < chancount - 1; i++ { - s.clients[cclient].write(&irc.Message{s.getAnonymousPrefix(i), irc.JOIN, []string{channel}}) + for i := ccount; i < chancount; i++ { + cl.write(&irc.Message{s.getAnonymousPrefix(i), irc.JOIN, []string{channel}}) } - s.channels[channel].clients[cclient] = chancount - s.channels[channel].Unlock() + ch.clients.Set(cclient, chancount) } else if ccount > chancount { - s.channels[channel].Lock() - for i := ccount; i > chancount - 1; i-- { - s.clients[cclient].write(&irc.Message{s.getAnonymousPrefix(i), irc.PART, []string{channel}}) + for i := ccount; i > chancount; i-- { + cl.write(&irc.Message{s.getAnonymousPrefix(i - 1), irc.PART, []string{channel}}) } - - s.channels[channel].clients[cclient] = chancount - s.channels[channel].Unlock() + } else { + continue } + + ch.clients.Set(cclient, chancount) } } @@ -215,35 +240,49 @@ func (s *Server) sendTopic(channel string, client string, changed bool) { return } - if s.channels[channel].topic != "" { + ch := s.getChannel(channel) + cl := s.getClient(client) + + if ch == nil || cl == nil { + return + } + + if ch.topic != "" { tprefix := anonymous tcommand := irc.TOPIC if !changed { tprefix = anonirc tcommand = irc.RPL_TOPIC } - s.clients[client].write(&irc.Message{&tprefix, tcommand, []string{channel, s.channels[channel].topic}}) + cl.write(&irc.Message{&tprefix, tcommand, []string{channel, ch.topic}}) if !changed { - s.clients[client].write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_TOPICWHOTIME, s.clients[client].nick, channel, "Anonymous", fmt.Sprintf("%d", s.channels[channel].topictime)}, " "), nil}) + cl.write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_TOPICWHOTIME, cl.nick, channel, "Anonymous", fmt.Sprintf("%d", ch.topictime)}, " "), nil}) } } } func (s *Server) handleTopic(channel string, client string, topic string) { + ch := s.getChannel(channel) + cl := s.getClient(client) + + if ch == nil || cl == nil { + return + } + if !s.inChannel(channel, client) { - s.clients[client].sendNotice("Invalid use of TOPIC") + cl.sendNotice("Invalid use of TOPIC") return } if topic != "" { - s.channels[channel].Lock() - s.channels[channel].topic = topic - s.channels[channel].topictime = time.Now().Unix() - s.channels[channel].Unlock() + ch.Lock() + ch.topic = topic + ch.topictime = time.Now().Unix() + ch.Unlock() - for sclient := range s.channels[channel].clients { - s.sendTopic(channel, sclient, true) + for cls := range ch.clients.IterBuffered() { + s.sendTopic(channel, cls.Key, true) } } else { s.sendTopic(channel, client, false) @@ -257,34 +296,35 @@ func (s *Server) handleMode(c *Client, params []string) { } if params[0][0] == '#' { - if !s.channelExists(params[0]) { + ch := s.getChannel(params[0]) + + if ch == nil { return } - channel := s.channels[params[0]] if len(params) == 1 || params[1] == "" { - c.write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_CHANNELMODEIS, c.nick, params[0], channel.printModes(channel.modes, nil)}, " "), []string{}}) + c.write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_CHANNELMODEIS, c.nick, params[0], ch.printModes(ch.getModes(), nil)}, " "), []string{}}) // Send channel creation time - c.write(&irc.Message{&anonirc, strings.Join([]string{"329", c.nick, params[0], fmt.Sprintf("%d", int32(channel.created))}, " "), []string{}}) + c.write(&irc.Message{&anonirc, strings.Join([]string{"329", c.nick, params[0], fmt.Sprintf("%d", int32(ch.created))}, " "), []string{}}) } else if len(params) > 1 && len(params[1]) > 0 && (params[1][0] == '+' || params[1][0] == '-') { lastmodes := make(map[string]string) - for mode, modevalue := range channel.modes { - lastmodes[mode] = modevalue + for ms := range ch.modes.IterBuffered() { + lastmodes[ms.Key] = ms.Val.(string) } - channel.Lock() + ch.Lock() if params[1][0] == '+' { - channel.addModes(params[1][1:]) + ch.addModes(params[1][1:]) } else { - channel.removeModes(params[1][1:]) + ch.removeModes(params[1][1:]) } - channel.Unlock() + ch.Unlock() s.enforceModes(params[0]) - if !reflect.DeepEqual(channel.modes, lastmodes) { + if !reflect.DeepEqual(ch.modes.Items(), lastmodes) { // TODO: Check if local modes were set/unset, only send changes to local client - addedmodes, removedmodes := channel.diffModes(lastmodes) + addedmodes, removedmodes := ch.diffModes(lastmodes) resendusercount := false if _, ok := addedmodes["c"]; ok { @@ -295,28 +335,29 @@ func (s *Server) handleMode(c *Client, params []string) { } if len(addedmodes) == 0 && len(removedmodes) == 0 { - addedmodes = c.modes + addedmodes = c.getModes() } - for sclient := range channel.clients { - s.clients[sclient].write(&irc.Message{&anonymous, irc.MODE, []string{params[0], channel.printModes(addedmodes, removedmodes)}}) + for cls := range ch.clients.IterBuffered() { + cl := s.getClient(cls.Key) + + if cl != nil { + cl.write(&irc.Message{&anonymous, irc.MODE, []string{params[0], ch.printModes(addedmodes, removedmodes)}}) + } } if resendusercount { - s.updateClientCount(params[0], "") + s.updateClientCount(params[0]) } } } } else { if len(params) == 1 || params[1] == "" { - c.write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_UMODEIS, c.nick, c.printModes(c.modes, nil)}, " "), []string{}}) + c.write(&irc.Message{&anonirc, strings.Join([]string{irc.RPL_UMODEIS, c.nick, c.printModes(c.getModes(), nil)}, " "), []string{}}) return } - lastmodes := make(map[string]string) - for mode, modevalue := range c.modes { - lastmodes[mode] = modevalue - } + lastmodes := c.getModes() if len(params) > 1 && len(params[1]) > 0 && (params[1][0] == '+' || params[1][0] == '-') { c.Lock() @@ -340,14 +381,14 @@ func (s *Server) handleMode(c *Client, params []string) { } if len(addedmodes) == 0 && len(removedmodes) == 0 { - addedmodes = c.modes + addedmodes = c.getModes() } c.write(&irc.Message{&anonirc, strings.Join([]string{irc.MODE, c.nick}, " "), []string{c.printModes(addedmodes, removedmodes)}}) if resendusercount { for ch := range s.getChannels(c.identifier) { - s.updateClientCount(ch, c.identifier) + s.updateClientCount(ch) } } } @@ -359,9 +400,17 @@ func (s *Server) handlePrivmsg(channel string, client string, message string) { return // Not in channel } - for sclient := range s.channels[channel].clients { - if s.clients[sclient].identifier != client { - s.clients[sclient].write(&irc.Message{&anonymous, irc.PRIVMSG, []string{channel, message}}) + ch := s.getChannel(channel) + + if ch == nil { + return + } + + for cls := range ch.clients.IterBuffered() { + ccl := s.getClient(cls.Key) + + if ccl.identifier != client { + ccl.write(&irc.Message{&anonymous, irc.PRIVMSG, []string{channel, message}}) } } } @@ -369,7 +418,14 @@ func (s *Server) handlePrivmsg(channel string, client string, message string) { func (s *Server) handleRead(c *Client) { for { c.conn.SetDeadline(time.Now().Add(300 * time.Second)) - msg, err := s.clients[c.identifier].reader.Decode() + + cl := s.getClient(c.identifier) + + if cl == nil { + return + } + + msg, err := cl.reader.Decode() if err != nil { log.Println("Unable to read from client:", err) s.partAllChannels(c.identifier) @@ -419,18 +475,20 @@ func (s *Server) handleRead(c *Client) { s.joinChannel("#", c.identifier) } else if msg.Command == irc.LIST { - var ccount int chans := make(map[string]int) - for channelname, channel := range s.channels { - if !channel.hasMode("p") && !channel.hasMode("s") { - ccount = s.getClientCount(channelname, c.identifier) - chans[channelname] = ccount + for chs := range s.channels.IterBuffered() { + ch := s.getChannel(chs.Key) + + if ch != nil && !ch.hasMode("p") && !ch.hasMode("s") { + chans[chs.Key] = s.getClientCount(chs.Key, c.identifier) } } c.write(&irc.Message{&anonirc, irc.RPL_LISTSTART, []string{"Channel", "Users Name"}}) for _, pl := range sortMapByValues(chans) { - c.write(&irc.Message{&anonirc, irc.RPL_LIST, []string{pl.Key, strconv.Itoa(pl.Value), "[" + s.channels[pl.Key].printModes(s.channels[pl.Key].modes, nil) + "] " + s.channels[pl.Key].topic}}) + ch := s.getChannel(pl.Key) + + c.write(&irc.Message{&anonirc, irc.RPL_LIST, []string{pl.Key, strconv.Itoa(pl.Value), "[" + ch.printModes(ch.getModes(), nil) + "] " + ch.topic}}) } c.write(&irc.Message{&anonirc, irc.RPL_LISTEND, []string{"End of /LIST"}}) } else if msg.Command == irc.JOIN && len(msg.Params) > 0 && len(msg.Params[0]) > 0 && msg.Params[0][0] == '#' { @@ -474,6 +532,7 @@ func (s *Server) handleRead(c *Client) { s.partChannel(channel, c.identifier, "") } } else if msg.Command == irc.QUIT { + c.conn.Close() s.partAllChannels(c.identifier) } } @@ -481,23 +540,24 @@ func (s *Server) handleRead(c *Client) { func (s *Server) handleConnection(conn net.Conn, ssl bool) { defer conn.Close() - var identifier string + + s.Lock() + for { identifier = randomIdentifier() - if _, ok := s.clients[identifier]; !ok { + if !s.clients.Has(identifier) { break } } - client := Client{Entity{ENTITY_CLIENT, identifier, time.Now().Unix(), make(map[string]string), new(sync.RWMutex)}, ssl, "*", "", "", conn, make(chan *irc.Message), irc.NewDecoder(conn), irc.NewEncoder(conn), false} + client := &Client{Entity{ENTITY_CLIENT, identifier, time.Now().Unix(), cmap.New(), new(sync.RWMutex)}, ssl, "*", "", "", conn, make(chan *irc.Message), irc.NewDecoder(conn), irc.NewEncoder(conn), false} + s.clients.Set(client.identifier, client) - s.Lock() - s.clients[client.identifier] = &client s.Unlock() go client.handleWrite() - s.handleRead(&client) + s.handleRead(client) } func (s *Server) listenPlain() { @@ -571,8 +631,12 @@ func (s *Server) listenSSL() { func (s *Server) pingClients() { for { s.Lock() - for _, c := range s.clients { - c.write(&irc.Message{nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))}}) + for cls := range s.clients.IterBuffered() { + cl := s.getClient(cls.Key) + + if cl != nil { + cl.write(&irc.Message{nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))}}) + } } s.Unlock() time.Sleep(90 * time.Second)