Use new (Go 1.9+) concurrent sync.Map

main
Trevor Slocum 2017-09-12 22:51:51 -07:00
parent 66b6c5630c
commit 1aebbcefd5
2 changed files with 87 additions and 65 deletions

View File

@ -1,11 +1,13 @@
package main
import "github.com/orcaman/concurrent-map"
import (
"sync"
)
type Channel struct {
Entity
clients cmap.ConcurrentMap
clients *sync.Map
topic string
topictime int64

146
server.go
View File

@ -16,7 +16,7 @@ import (
"time"
"github.com/BurntSushi/toml"
cmap "github.com/orcaman/concurrent-map"
"github.com/orcaman/concurrent-map"
irc "gopkg.in/sorcix/irc.v2"
)
@ -29,8 +29,8 @@ type Config struct {
type Server struct {
config *Config
created int64
clients cmap.ConcurrentMap
channels cmap.ConcurrentMap
clients *sync.Map
channels *sync.Map
odyssey *os.File
odysseymutex *sync.RWMutex
@ -43,8 +43,8 @@ type Server struct {
func NewServer() *Server {
s := &Server{}
s.created = time.Now().Unix()
s.clients = cmap.New()
s.channels = cmap.New()
s.clients = new(sync.Map)
s.channels = new(sync.Map)
s.restartplain = make(chan bool, 1)
s.restartssl = make(chan bool, 1)
@ -62,7 +62,7 @@ func (s *Server) getAnonymousPrefix(i int) *irc.Prefix {
}
func (s *Server) getChannel(channel string) *Channel {
if ch, ok := s.channels.Get(channel); ok {
if ch, ok := s.channels.Load(channel); ok {
return ch.(*Channel)
}
@ -71,16 +71,21 @@ func (s *Server) getChannel(channel string) *Channel {
func (s *Server) getChannels(client string) map[string]*Channel {
channels := make(map[string]*Channel)
for chs := range s.channels.IterBuffered() {
if s.inChannel(chs.Key, client) {
channels[chs.Key] = chs.Val.(*Channel)
s.channels.Range(func(k, v interface{}) bool {
key := k.(string)
channel := v.(*Channel)
if s.inChannel(key, client) {
channels[key] = channel
}
}
return true
})
return channels
}
func (s *Server) getClient(client string) *Client {
if cl, ok := s.clients.Get(client); ok {
if cl, ok := s.channels.Load(client); ok {
return cl.(*Client)
}
@ -92,15 +97,20 @@ func (s *Server) getClients(channel string) map[string]*Client {
ch := s.getChannel(channel)
for cls := range ch.clients.IterBuffered() {
clients[cls.Key] = cls.Val.(*Client)
}
ch.clients.Range(func(k, v interface{}) bool {
cl := s.getClient(k.(string))
if cl != nil {
clients[cl.identifier] = cl
}
return true
})
return clients
}
func (s *Server) channelExists(channel string) bool {
if _, ok := s.channels.Get(channel); ok {
if _, ok := s.channels.Load(channel); ok {
return true
}
@ -110,7 +120,8 @@ func (s *Server) channelExists(channel string) bool {
func (s *Server) inChannel(channel string, client string) bool {
ch := s.getChannel(channel)
if ch != nil {
return ch.clients.Has(client)
_, ok := ch.clients.Load(client)
return ok
}
return false
@ -129,14 +140,14 @@ func (s *Server) joinChannel(channel string, client string) {
}
if ch == nil {
ch = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), ENTITY_STATE_NORMAL, cmap.New(), new(sync.RWMutex)}, cmap.New(), "", 0}
s.channels.Set(channel, ch)
ch = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), ENTITY_STATE_NORMAL, cmap.New(), new(sync.RWMutex)}, new(sync.Map), "", 0}
s.channels.Store(channel, ch)
} else if ch.hasMode("z") && !cl.ssl {
cl.sendNotice("Unable to join " + channel + ": SSL connections only (channel mode +z)")
cl.sendNotice("Unable to join " + channel + ": SSL connections only (channel mode +z)")///
return
}
ch.clients.Set(client, s.getClientCount(channel, client)+1)
ch.clients.Store(client, s.getClientCount(channel, client)+1)
cl.write(&irc.Message{cl.getPrefix(), irc.JOIN, []string{channel}})
s.sendNames(channel, client)
@ -153,7 +164,7 @@ func (s *Server) partChannel(channel string, client string, reason string) {
}
cl.write(&irc.Message{cl.getPrefix(), irc.PART, []string{channel, reason}})
ch.clients.Remove(client)
ch.clients.Delete(client)
s.updateClientCount(channel, client)
}
@ -184,7 +195,11 @@ func (s *Server) getClientCount(channel string, client string) int {
return 0
}
ccount := ch.clients.Count()
ccount := 0
ch.clients.Range(func(k, v interface{}) bool {
ccount++
return true
})
if (ch.hasMode("c") || cl.hasMode("c")) && ccount >= 2 {
return 2
@ -200,38 +215,36 @@ func (s *Server) updateClientCount(channel string, client string) {
return
}
for cls := range ch.clients.IterBuffered() {
cclient := cls.Key
ccount := cls.Val.(int)
if client != "" && ch.hasMode("D") && cclient != client {
continue
}
cl := s.getClient(cclient)
ch.clients.Range(func(k, v interface{}) bool {
cl := s.getClient(k.(string))
ccount := v.(int)
if cl == nil {
continue
return true
} else if client != "" && ch.hasMode("D") && cl.identifier != client {
return true
}
chancount := s.getClientCount(channel, cclient)
chancount := s.getClientCount(channel, cl.identifier)
if ccount < chancount {
for i := ccount; i < chancount; i++ {
cl.write(&irc.Message{s.getAnonymousPrefix(i), irc.JOIN, []string{channel}})
}
ch.clients.Set(cclient, chancount)
ch.clients.Store(cl.identifier, chancount)
} else if ccount > chancount {
for i := ccount; i > chancount; i-- {
cl.write(&irc.Message{s.getAnonymousPrefix(i - 1), irc.PART, []string{channel}})
}
} else {
continue
return true
}
ch.clients.Set(cclient, chancount)
}
ch.clients.Store(cl.identifier, chancount)
return true
})
}
func (s *Server) sendNames(channel string, clientname string) {
@ -307,9 +320,10 @@ func (s *Server) handleTopic(channel string, client string, topic string) {
ch.topic = topic
ch.topictime = time.Now().Unix()
for cls := range ch.clients.IterBuffered() {
s.sendTopic(channel, cls.Key, true)
}
ch.clients.Range(func(k, v interface{}) bool {
s.sendTopic(channel, k.(string), true)
return true
})
} else {
s.sendTopic(channel, client, false)
}
@ -365,13 +379,14 @@ func (s *Server) handleMode(c *Client, params []string) {
addedmodes = c.getModes()
}
for cls := range ch.clients.IterBuffered() {
cl := s.getClient(cls.Key)
ch.clients.Range(func(k, v interface{}) bool {
cl := s.getClient(k.(string))
if cl != nil {
cl.write(&irc.Message{&anonymous, irc.MODE, []string{params[0], ch.printModes(addedmodes, removedmodes)}})
}
}
return true
})
if resendusercount {
s.updateClientCount(params[0], c.identifier)
@ -436,20 +451,21 @@ func (s *Server) handlePrivmsg(channel string, client string, message string) {
s.updateClientCount(channel, "")
for cls := range ch.clients.IterBuffered() {
ccl := s.getClient(cls.Key)
if ccl != nil && ccl.identifier != client {
ccl.write(&irc.Message{&anonymous, irc.PRIVMSG, []string{channel, message}})
ch.clients.Range(func(k, v interface{}) bool {
cl := s.getClient(k.(string))
if cl != nil && cl.identifier != client {
cl.write(&irc.Message{&anonymous, irc.PRIVMSG, []string{channel, message}})
}
}
return true
})
}
func (s *Server) handleRead(c *Client) {
for {
c.conn.SetDeadline(time.Now().Add(300 * time.Second))
if !s.clients.Has(c.identifier) {
if _, ok := s.clients.Load(c.identifier); !ok {
s.killClient(c)
return
}
@ -533,13 +549,16 @@ func (s *Server) handleRead(c *Client) {
}
} else if msg.Command == irc.LIST {
chans := make(map[string]int)
for chs := range s.channels.IterBuffered() {
ch := s.getChannel(chs.Key)
s.channels.Range(func(k, v interface{}) bool {
key := k.(string)
ch := v.(*Channel)
if ch != nil && !ch.hasMode("p") && !ch.hasMode("s") {
chans[chs.Key] = s.getClientCount(chs.Key, c.identifier)
chans[key] = s.getClientCount(key, c.identifier)
}
}
return true
})
c.write(&irc.Message{&anonirc, irc.RPL_LISTSTART, []string{"Channel", "Users Name"}})
for _, pl := range sortMapByValues(chans) {
@ -624,20 +643,20 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
for {
identifier = randomIdentifier()
if !s.clients.Has(identifier) {
if _, ok := s.clients.Load(identifier); !ok {
break
}
}
client := &Client{Entity{ENTITY_CLIENT, identifier, time.Now().Unix(), ENTITY_STATE_NORMAL, cmap.New(), new(sync.RWMutex)}, ssl, "*", "", "", conn, make(chan *irc.Message, writebuffersize), irc.NewDecoder(conn), irc.NewEncoder(conn), false}
s.clients.Set(client.identifier, client)
s.clients.Store(client.identifier, client)
go s.handleWrite(client)
s.handleRead(client)
s.killClient(client)
close(client.writebuffer)
s.clients.Remove(identifier)
s.clients.Delete(identifier)
}
func (s *Server) killClient(c *Client) {
@ -648,7 +667,7 @@ func (s *Server) killClient(c *Client) {
c.write(nil)
c.conn.Close()
if s.clients.Has(c.identifier) {
if _, ok := s.clients.Load(c.identifier); ok {
s.partAllChannels(c.identifier)
}
}
@ -723,13 +742,14 @@ func (s *Server) listenSSL() {
func (s *Server) pingClients() {
for {
for cls := range s.clients.IterBuffered() {
cl := s.getClient(cls.Key)
s.clients.Range(func(k, v interface{}) bool {
cl := v.(*Client)
if cl != nil {
cl.write(&irc.Message{nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))}})
}
}
return true
})
time.Sleep(90 * time.Second)
}
}