Add accounts, server/channel moderation
parent
1aeaeeffd9
commit
3c9d37f7fb
|
@ -7,6 +7,42 @@
|
|||
revision = "b26d9c308763d68093482582cea63d69be07a0f0"
|
||||
version = "v0.3.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/gorilla/securecookie"
|
||||
packages = ["."]
|
||||
revision = "667fe4e3466a040b780561fe9b51a83a3753eefc"
|
||||
version = "v1.1"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/jessevdk/go-flags"
|
||||
packages = ["."]
|
||||
revision = "96dc06278ce32a0e9d957d590bb987c81ee66407"
|
||||
version = "v1.3.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/mattn/go-sqlite3"
|
||||
packages = ["."]
|
||||
revision = "ed69081a91fd053f17672236b0dd52ba7485e1a3"
|
||||
version = "v1.4.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/pkg/errors"
|
||||
packages = ["."]
|
||||
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
|
||||
version = "v0.8.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
packages = ["sha3"]
|
||||
revision = "94eea52f7b742c7cbe0b03b22f0c4c8631ece122"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/net"
|
||||
packages = ["context"]
|
||||
revision = "a8b9294777976932365dabb6640cf1468d95c70f"
|
||||
|
||||
[[projects]]
|
||||
branch = "v2"
|
||||
name = "gopkg.in/sorcix/irc.v2"
|
||||
|
@ -16,6 +52,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "0be9e577dd19c1613689669c289fb1cd6f32b0a47db9d3448a25d405cc1d456d"
|
||||
inputs-digest = "da08912cf0f9aa88d93fa8fb8cf01a421f16b7977841c275f883c0c6821adcca"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
97
anonircd.go
97
anonircd.go
|
@ -17,20 +17,22 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jessevdk/go-flags"
|
||||
irc "gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
var anonymous = irc.Prefix{"Anonymous", "Anon", "IRC"}
|
||||
var anonirc = irc.Prefix{Name: "AnonIRC"}
|
||||
var prefixAnonymous = irc.Prefix{"Anonymous", "Anon", "IRC"}
|
||||
var prefixAnonIRC = irc.Prefix{Name: "AnonIRC"}
|
||||
|
||||
const motd = `
|
||||
_|_| _|_|_| _|_|_| _|_|_|
|
||||
|
@ -42,63 +44,68 @@ _| _| _| _| _|_| _| _| _|_|_| _| _| _|_|_|
|
|||
const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
const writebuffersize = 10
|
||||
|
||||
type Pair struct {
|
||||
Key string
|
||||
Value int
|
||||
}
|
||||
const (
|
||||
PERMISSION_USER = 0
|
||||
PERMISSION_SUPERADMIN = 1
|
||||
PERMISSION_ADMIN = 2
|
||||
PERMISSION_MODERATOR = 3
|
||||
PERMISSION_VIP = 4
|
||||
)
|
||||
|
||||
type PairList []Pair
|
||||
|
||||
func (p PairList) Len() int {
|
||||
return len(p)
|
||||
}
|
||||
func (p PairList) Less(i, j int) bool {
|
||||
return p[i].Value < p[j].Value
|
||||
}
|
||||
func (p PairList) Swap(i, j int) {
|
||||
p[i], p[j] = p[j], p[i]
|
||||
}
|
||||
|
||||
func sortMapByValues(m map[string]int) PairList {
|
||||
pl := make(PairList, len(m))
|
||||
i := 0
|
||||
for k, v := range m {
|
||||
pl[i] = Pair{k, v}
|
||||
i++
|
||||
}
|
||||
sort.Sort(sort.Reverse(pl))
|
||||
return pl
|
||||
}
|
||||
|
||||
func randomIdentifier() string {
|
||||
b := make([]byte, 10)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
var debugMode = false
|
||||
var verbose = false
|
||||
|
||||
func main() {
|
||||
rand.Seed(time.Now().UTC().UnixNano())
|
||||
|
||||
server := NewServer()
|
||||
server.loadConfig()
|
||||
var opts struct {
|
||||
ConfigFile string `short:"c" long:"config" description:"Configuration file"`
|
||||
Debug int `short:"d" long:"debug" description:"Serve pprof data on specified port"`
|
||||
BareLog bool `short:"b" long:"bare-log" description:"Don't add current date/time to log entries"`
|
||||
Verbose bool `short:"v" long:"verbose" description:"Log verbosely"`
|
||||
}
|
||||
|
||||
_, err := flags.Parse(&opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if opts.Debug > 0 {
|
||||
debugMode = true
|
||||
log.Printf("WARNING: Running in debug mode. pprof data is available at http://localhost:%d/debug/pprof/", opts.Debug)
|
||||
go http.ListenAndServe(fmt.Sprintf("localhost:%d", opts.Debug), nil)
|
||||
}
|
||||
|
||||
if opts.BareLog {
|
||||
log.SetFlags(0)
|
||||
}
|
||||
|
||||
verbose = opts.Verbose
|
||||
|
||||
s := NewServer(opts.ConfigFile)
|
||||
err = s.loadConfig()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.connectDatabase()
|
||||
defer s.closeDatabase()
|
||||
|
||||
sighup := make(chan os.Signal, 1)
|
||||
signal.Notify(sighup,
|
||||
syscall.SIGHUP)
|
||||
go func() {
|
||||
<-sighup
|
||||
server.reload()
|
||||
err := s.reload()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
server.odyssey, err = os.Open("ODYSSEY")
|
||||
s.odyssey, err = os.Open("ODYSSEY")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer server.odyssey.Close()
|
||||
defer s.odyssey.Close()
|
||||
|
||||
go server.startProfiling()
|
||||
server.listen()
|
||||
s.listen()
|
||||
}
|
||||
|
|
96
channel.go
96
channel.go
|
@ -1,16 +1,40 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
Entity
|
||||
|
||||
clients *sync.Map
|
||||
logs []*ChannelLog
|
||||
|
||||
topic string
|
||||
topictime int64
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type ChannelLog struct {
|
||||
Timestamp int64
|
||||
Client string
|
||||
IP string
|
||||
Action string
|
||||
Message string
|
||||
}
|
||||
|
||||
const CHANNEL_LOGS_PER_PAGE = 25
|
||||
|
||||
func (cl *ChannelLog) Identifier(index int) string {
|
||||
return fmt.Sprintf("%03d%02d", index+1, cl.Timestamp%100)
|
||||
}
|
||||
|
||||
func (cl *ChannelLog) Print(index int, channel string) string {
|
||||
return strings.TrimSpace(fmt.Sprintf("%s %s %5s %4s %s", time.Unix(0, cl.Timestamp).Format(time.Stamp), channel, cl.Identifier(index), cl.Action, cl.Message))
|
||||
}
|
||||
|
||||
func NewChannel(identifier string) *Channel {
|
||||
|
@ -21,3 +45,75 @@ func NewChannel(identifier string) *Channel {
|
|||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Channel) Log(client *Client, action string, message string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// TODO: Log size limiting, max capacity will be 998 entries
|
||||
// Log hash of IP address which is used later when connecting/joining
|
||||
|
||||
c.logs = append(c.logs, &ChannelLog{Timestamp: time.Now().UTC().UnixNano(), Client: client.identifier, IP: client.ip, Action: action, Message: message})
|
||||
}
|
||||
|
||||
func (c *Channel) RevealLog(page int) []string {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
// TODO:
|
||||
// Trim old channel logs periodically
|
||||
// Add pagination
|
||||
var ls []string
|
||||
logsRemain := false
|
||||
j := 0
|
||||
for i, l := range c.logs {
|
||||
if page == -1 || i >= (CHANNEL_LOGS_PER_PAGE*(page-1)) {
|
||||
if page > -1 && j == CHANNEL_LOGS_PER_PAGE {
|
||||
logsRemain = true
|
||||
break
|
||||
}
|
||||
ls = append(ls, l.Print(i, c.identifier))
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
if len(ls) == 0 {
|
||||
ls = append(ls, "No log entries match criteria")
|
||||
} else {
|
||||
filterType := "all"
|
||||
if page > -1 {
|
||||
filterType = fmt.Sprintf("page %d", page)
|
||||
}
|
||||
ls = append([]string{fmt.Sprintf("Revealing %s (%s)", c.identifier, filterType)}, ls...)
|
||||
|
||||
finishedMessage := fmt.Sprintf("Finished revealing %s", c.identifier)
|
||||
if logsRemain {
|
||||
finishedMessage = fmt.Sprintf("Additional log entries on page %d", page+1)
|
||||
}
|
||||
ls = append(ls, finishedMessage)
|
||||
}
|
||||
|
||||
return ls
|
||||
}
|
||||
|
||||
func (c *Channel) RevealHash(identifier string) string {
|
||||
if len(identifier) != 5 {
|
||||
return ""
|
||||
}
|
||||
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
for i, l := range c.logs {
|
||||
if l.Identifier(i) == identifier {
|
||||
return l.IP
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Channel) HasClient(client string) bool {
|
||||
_, ok := c.clients.Load(client)
|
||||
return ok
|
||||
}
|
||||
|
|
57
client.go
57
client.go
|
@ -3,48 +3,93 @@ package main
|
|||
import (
|
||||
"net"
|
||||
|
||||
"sync"
|
||||
|
||||
irc "gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Entity
|
||||
ip string
|
||||
|
||||
ssl bool
|
||||
nick string
|
||||
user string
|
||||
host string
|
||||
ssl bool
|
||||
nick string
|
||||
user string
|
||||
host string
|
||||
account int
|
||||
|
||||
conn net.Conn
|
||||
writebuffer chan *irc.Message
|
||||
terminate chan bool
|
||||
|
||||
reader *irc.Decoder
|
||||
writer *irc.Encoder
|
||||
|
||||
capHostInNames bool
|
||||
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewClient(identifier string, conn net.Conn, ssl bool) *Client {
|
||||
c := &Client{}
|
||||
c.Initialize(ENTITY_CLIENT, identifier)
|
||||
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.ip = generateHash(ip)
|
||||
// TODO: Check bans, return nil
|
||||
|
||||
c.ssl = ssl
|
||||
c.nick = "*"
|
||||
c.conn = conn
|
||||
c.writebuffer = make(chan *irc.Message, writebuffersize)
|
||||
c.terminate = make(chan bool)
|
||||
c.reader = irc.NewDecoder(conn)
|
||||
c.writer = irc.NewEncoder(conn)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) registered() bool {
|
||||
// TODO
|
||||
return c.account > 0
|
||||
}
|
||||
|
||||
func (c *Client) getPrefix() *irc.Prefix {
|
||||
return &irc.Prefix{Name: c.nick, User: c.user, Host: c.host}
|
||||
}
|
||||
|
||||
func (c *Client) write(msg *irc.Message) {
|
||||
if c.state == ENTITY_STATE_TERMINATING {
|
||||
return
|
||||
}
|
||||
|
||||
c.writebuffer <- msg
|
||||
}
|
||||
|
||||
func (c *Client) sendNotice(notice string) {
|
||||
c.write(&irc.Message{&anonirc, irc.NOTICE, []string{c.nick, "*** " + notice}})
|
||||
func (c *Client) writeMessage(command string, params []string) {
|
||||
c.write(&irc.Message{&prefixAnonIRC, command, params})
|
||||
}
|
||||
|
||||
func (c *Client) sendMessage(message string) {
|
||||
c.writeMessage(irc.PRIVMSG, []string{c.nick, message})
|
||||
}
|
||||
|
||||
func (c *Client) sendPasswordIncorrect() {
|
||||
c.writeMessage(irc.ERR_PASSWDMISMATCH, []string{"Password incorrect"})
|
||||
}
|
||||
|
||||
func (c *Client) sendError(message string) {
|
||||
c.sendMessage("Error! " + message)
|
||||
}
|
||||
|
||||
func (c *Client) sendNotice(message string) {
|
||||
c.sendMessage("*** " + message)
|
||||
}
|
||||
|
||||
func (c *Client) accessDenied() {
|
||||
c.sendNotice("Access denied")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,387 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const DATABASE_VERSION = 1
|
||||
|
||||
var ErrAccountExists = errors.New("account exists")
|
||||
var ErrChannelExists = errors.New("channel exists")
|
||||
|
||||
var tables = map[string][]string{
|
||||
"meta": {
|
||||
"`key` TEXT NULL PRIMARY KEY",
|
||||
"`value` TEXT NULL"},
|
||||
"accounts": {
|
||||
"`id` INTEGER PRIMARY KEY AUTOINCREMENT",
|
||||
"`username` TEXT NULL",
|
||||
"`password` TEXT NULL"},
|
||||
"channels": {
|
||||
"`channel` TEXT PRIMARY KEY",
|
||||
"`topic` TEXT NULL",
|
||||
"`topictime` INTEGER NULL",
|
||||
"`password` TEXT NULL"},
|
||||
"permissions": {
|
||||
"`channel` TEXT NULL",
|
||||
"`account` INTEGER NULL",
|
||||
"`permission` INTEGER NULL"},
|
||||
"bans": {
|
||||
"`channel` TEXT NULL",
|
||||
"`type` INTEGER NULL",
|
||||
"`target` TEXT NULL",
|
||||
"`expires` INTEGER NULL",
|
||||
"`reason` TEXT NULL"}}
|
||||
|
||||
type DBAccount struct {
|
||||
ID int
|
||||
Username string
|
||||
Permission int
|
||||
}
|
||||
|
||||
type DBChannel struct {
|
||||
Channel string
|
||||
Topic string
|
||||
TopicTime int
|
||||
Password string
|
||||
}
|
||||
|
||||
type DBPermission struct {
|
||||
Channel string
|
||||
Account int
|
||||
Permission int
|
||||
}
|
||||
|
||||
type DBBan struct {
|
||||
Channel string
|
||||
Type int
|
||||
Target string
|
||||
Expires int
|
||||
Reason string
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (d *Database) Connect(driver string, dataSource string) error {
|
||||
var err error
|
||||
d.db, err = sql.Open(driver, dataSource)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s database", driver)
|
||||
}
|
||||
|
||||
err = d.CreateTables()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create tables")
|
||||
}
|
||||
|
||||
err = d.Migrate()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to migrate database")
|
||||
}
|
||||
|
||||
err = d.Initialize()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to initialize database")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) CreateTables() error {
|
||||
for tname, tcolumns := range tables {
|
||||
_, err := d.db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (%s)", tname, strings.Join(tcolumns, ",")))
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create %s table", tname)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Migrate() error {
|
||||
rows, err := d.db.Query("SELECT `value` FROM meta WHERE `key`=? LIMIT 1", "version")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch database version")
|
||||
}
|
||||
|
||||
version := 0
|
||||
for rows.Next() {
|
||||
v := ""
|
||||
err = rows.Scan(&v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch database version")
|
||||
}
|
||||
|
||||
version, err = strconv.Atoi(v)
|
||||
if err != nil {
|
||||
version = -1
|
||||
}
|
||||
}
|
||||
|
||||
if version == -1 {
|
||||
panic("Unable to migrate database: database version unknown")
|
||||
} else if version == 0 {
|
||||
_, err := d.db.Exec("UPDATE meta SET `value`=? WHERE `key`=?", strconv.Itoa(DATABASE_VERSION), "version")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to save database version")
|
||||
}
|
||||
} else if version < DATABASE_VERSION {
|
||||
// DATABASE_VERSION 2 migration queries will go here
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Initialize() error {
|
||||
username := ""
|
||||
err := d.db.QueryRow("SELECT username FROM accounts").Scan(&username)
|
||||
if err == sql.ErrNoRows {
|
||||
err := d.AddAccount("admin", "password")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create first account")
|
||||
}
|
||||
|
||||
ac := &DBChannel{Channel: "&", Topic: "Secret Area of VIP Quality"}
|
||||
d.AddChannel(1, ac)
|
||||
|
||||
uc := &DBChannel{Channel: "#", Topic: "Welcome to AnonIRC"}
|
||||
d.AddChannel(1, uc)
|
||||
} else if err != nil {
|
||||
return errors.Wrap(err, "failed to check for first account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Close() error {
|
||||
err := d.db.Close()
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "failed to close database")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Accounts
|
||||
|
||||
func (d *Database) Account(id int) (*DBAccount, error) {
|
||||
rows, err := d.db.Query("SELECT id, username FROM accounts WHERE id=?", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch account")
|
||||
}
|
||||
|
||||
var a *DBAccount
|
||||
for rows.Next() {
|
||||
a = new(DBAccount)
|
||||
err = rows.Scan(&a.ID, &a.Username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan account")
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *Database) AccountU(username string) (*DBAccount, error) {
|
||||
rows, err := d.db.Query("SELECT id, username FROM accounts WHERE username=?", generateHash(username))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch account by username")
|
||||
}
|
||||
|
||||
var a *DBAccount
|
||||
for rows.Next() {
|
||||
a = new(DBAccount)
|
||||
err = rows.Scan(&a.ID, &a.Username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan account")
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// TODO: Lockout on too many failed attempts
|
||||
func (d *Database) Auth(username string, password string) (int, error) {
|
||||
// TODO: Salt in config
|
||||
rows, err := d.db.Query("SELECT id FROM accounts WHERE username=? AND password=?", generateHash(username), generateHash(username+"-"+password))
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
|
||||
accountid := 0
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&accountid)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
}
|
||||
|
||||
return accountid, nil
|
||||
}
|
||||
|
||||
func (d *Database) GenerateToken() string {
|
||||
return base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(64))
|
||||
}
|
||||
|
||||
func (d *Database) AddAccount(username string, password string) error {
|
||||
ex, err := d.AccountU(username)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing account while adding account")
|
||||
} else if ex != nil {
|
||||
return ErrAccountExists
|
||||
}
|
||||
|
||||
_, err = d.db.Exec("INSERT INTO accounts (username, password) VALUES (?, ?)", generateHash(username), generateHash(username+"-"+password))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to add account")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) SetUsername(accountid int, username string, password string) error {
|
||||
ex, err := d.AccountU(username)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing account while setting username")
|
||||
} else if ex != nil {
|
||||
return ErrAccountExists
|
||||
}
|
||||
|
||||
_, err = d.db.Exec("UPDATE accounts SET username=?, password=? WHERE id=?", generateHash(username), generateHash(username+"-"+password), accountid)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set username")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) SetPassword(accountid int, username string, password string) error {
|
||||
_, err := d.db.Exec("UPDATE accounts SET password=? WHERE id=?", generateHash(username+"-"+password), accountid)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Channels
|
||||
|
||||
func (d *Database) ChannelID(id int) (*DBChannel, error) {
|
||||
rows, err := d.db.Query("SELECT channel, topic FROM channels WHERE id=?", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch channel")
|
||||
}
|
||||
|
||||
var c *DBChannel
|
||||
for rows.Next() {
|
||||
c = new(DBChannel)
|
||||
err = rows.Scan(&c.Channel, &c.Topic)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan channel")
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *Database) Channel(channel string) (*DBChannel, error) {
|
||||
rows, err := d.db.Query("SELECT channel, topic FROM channels WHERE channel=?", generateHash(channel))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch channel by key")
|
||||
}
|
||||
|
||||
var c *DBChannel
|
||||
for rows.Next() {
|
||||
c = new(DBChannel)
|
||||
err = rows.Scan(&c.Channel, &c.Topic)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan channel")
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *Database) AddChannel(accountid int, channel *DBChannel) error {
|
||||
ex, err := d.Channel(channel.Channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing channel while adding channel")
|
||||
} else if ex != nil {
|
||||
return ErrChannelExists
|
||||
}
|
||||
|
||||
chch := channel.Channel
|
||||
channel.Channel = generateHash(strings.ToLower(channel.Channel))
|
||||
_, err = d.db.Exec("INSERT INTO channels (channel, topic, topictime, password) VALUES (?, ?, ?, ?)", channel.Channel, channel.Topic, channel.TopicTime, channel.Password)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to add channel")
|
||||
}
|
||||
|
||||
err = d.SetPermission(accountid, chch, PERMISSION_SUPERADMIN)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission on newly added channel")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (d *Database) GetPermission(accountid int, channel string) (int, error) {
|
||||
rows, err := d.db.Query("SELECT permission FROM permissions WHERE account=? AND channel=?", accountid, generateHash(channel))
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
|
||||
permission := PERMISSION_USER
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&permission)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
func (d *Database) SetPermission(accountid int, channel string, permission int) error {
|
||||
acc, err := d.Account(accountid)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if acc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, err := d.Channel(channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch channel while setting permission")
|
||||
} else if ch == nil {
|
||||
return nil
|
||||
}
|
||||
chh := generateHash(channel)
|
||||
|
||||
rows, err := d.db.Query("SELECT permission FROM permissions WHERE account=? AND channel=?", accountid, chh)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
_, err = d.db.Exec("INSERT INTO permissions (channel, account, permission) VALUES (?, ?, ?)", chh, accountid, permission)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
} else {
|
||||
_, err = d.db.Exec("UPDATE permissions SET permission=? WHERE account=? AND channel=?", permission, accountid, chh)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -94,8 +94,8 @@ func (e *Entity) diffModes(lastmodes map[string]string) (map[string]string, map[
|
|||
|
||||
removedmodes := make(map[string]string)
|
||||
for mode := range lastmodes {
|
||||
if m, ok := e.modes.Load(mode); ok {
|
||||
removedmodes[mode] = m.(string)
|
||||
if _, ok := e.modes.Load(mode); !ok {
|
||||
removedmodes[mode] = mode
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"crypto/md5"
|
||||
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
type Pair struct {
|
||||
Key string
|
||||
Value int
|
||||
}
|
||||
|
||||
type PairList []Pair
|
||||
|
||||
func (p PairList) Len() int {
|
||||
return len(p)
|
||||
}
|
||||
func (p PairList) Less(i, j int) bool {
|
||||
return p[i].Value < p[j].Value
|
||||
}
|
||||
func (p PairList) Swap(i, j int) {
|
||||
p[i], p[j] = p[j], p[i]
|
||||
}
|
||||
|
||||
func sortMapByValues(m map[string]int) PairList {
|
||||
pl := make(PairList, len(m))
|
||||
i := 0
|
||||
for k, v := range m {
|
||||
pl[i] = Pair{k, v}
|
||||
i++
|
||||
}
|
||||
sort.Sort(sort.Reverse(pl))
|
||||
return pl
|
||||
}
|
||||
|
||||
func randomIdentifier() string {
|
||||
b := make([]byte, 10)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func validChannelPrefix(channel string) bool {
|
||||
return channel[0] == '&' || channel[0] == '#'
|
||||
}
|
||||
|
||||
func generateHash(s string) string {
|
||||
sha512 := sha3.New512()
|
||||
_, err := sha512.Write([]byte(strings.Join([]string{s, fmt.Sprintf("%x", md5.Sum([]byte(s))), s}, "-")))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(sha512.Sum(nil))
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
language: go
|
||||
sudo: false
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- go: 1.3
|
||||
- go: 1.4
|
||||
- go: 1.5
|
||||
- go: 1.6
|
||||
- go: tip
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
script:
|
||||
- go get -t -v ./...
|
||||
- diff -u <(echo -n) <(gofmt -d .)
|
||||
- go vet $(go list ./... | grep -v /vendor/)
|
||||
- go test -v -race ./...
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,76 @@
|
|||
securecookie
|
||||
============
|
||||
[](https://godoc.org/github.com/gorilla/securecookie) [](https://travis-ci.org/gorilla/securecookie)
|
||||
|
||||
securecookie encodes and decodes authenticated and optionally encrypted
|
||||
cookie values.
|
||||
|
||||
Secure cookies can't be forged, because their values are validated using HMAC.
|
||||
When encrypted, the content is also inaccessible to malicious eyes. It is still
|
||||
recommended that sensitive data not be stored in cookies, and that HTTPS be used
|
||||
to prevent cookie [replay attacks](https://en.wikipedia.org/wiki/Replay_attack).
|
||||
|
||||
## Examples
|
||||
|
||||
To use it, first create a new SecureCookie instance:
|
||||
|
||||
```go
|
||||
// Hash keys should be at least 32 bytes long
|
||||
var hashKey = []byte("very-secret")
|
||||
// Block keys should be 16 bytes (AES-128) or 32 bytes (AES-256) long.
|
||||
// Shorter keys may weaken the encryption used.
|
||||
var blockKey = []byte("a-lot-secret")
|
||||
var s = securecookie.New(hashKey, blockKey)
|
||||
```
|
||||
|
||||
The hashKey is required, used to authenticate the cookie value using HMAC.
|
||||
It is recommended to use a key with 32 or 64 bytes.
|
||||
|
||||
The blockKey is optional, used to encrypt the cookie value -- set it to nil
|
||||
to not use encryption. If set, the length must correspond to the block size
|
||||
of the encryption algorithm. For AES, used by default, valid lengths are
|
||||
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||
|
||||
Strong keys can be created using the convenience function GenerateRandomKey().
|
||||
|
||||
Once a SecureCookie instance is set, use it to encode a cookie value:
|
||||
|
||||
```go
|
||||
func SetCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||
value := map[string]string{
|
||||
"foo": "bar",
|
||||
}
|
||||
if encoded, err := s.Encode("cookie-name", value); err == nil {
|
||||
cookie := &http.Cookie{
|
||||
Name: "cookie-name",
|
||||
Value: encoded,
|
||||
Path: "/",
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Later, use the same SecureCookie instance to decode and validate a cookie
|
||||
value:
|
||||
|
||||
```go
|
||||
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if cookie, err := r.Cookie("cookie-name"); err == nil {
|
||||
value := make(map[string]string)
|
||||
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil {
|
||||
fmt.Fprintf(w, "The value of foo is %q", value["foo"])
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
We stored a map[string]string, but secure cookies can hold any value that
|
||||
can be encoded using `encoding/gob`. To store custom types, they must be
|
||||
registered first using gob.Register(). For basic types this is not needed;
|
||||
it works out of the box. An optional JSON encoder that uses `encoding/json` is
|
||||
available for types compatible with JSON.
|
||||
|
||||
## License
|
||||
|
||||
BSD licensed. See the LICENSE file for details.
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package securecookie encodes and decodes authenticated and optionally
|
||||
encrypted cookie values.
|
||||
|
||||
Secure cookies can't be forged, because their values are validated using HMAC.
|
||||
When encrypted, the content is also inaccessible to malicious eyes.
|
||||
|
||||
To use it, first create a new SecureCookie instance:
|
||||
|
||||
var hashKey = []byte("very-secret")
|
||||
var blockKey = []byte("a-lot-secret")
|
||||
var s = securecookie.New(hashKey, blockKey)
|
||||
|
||||
The hashKey is required, used to authenticate the cookie value using HMAC.
|
||||
It is recommended to use a key with 32 or 64 bytes.
|
||||
|
||||
The blockKey is optional, used to encrypt the cookie value -- set it to nil
|
||||
to not use encryption. If set, the length must correspond to the block size
|
||||
of the encryption algorithm. For AES, used by default, valid lengths are
|
||||
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||
|
||||
Strong keys can be created using the convenience function GenerateRandomKey().
|
||||
|
||||
Once a SecureCookie instance is set, use it to encode a cookie value:
|
||||
|
||||
func SetCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||
value := map[string]string{
|
||||
"foo": "bar",
|
||||
}
|
||||
if encoded, err := s.Encode("cookie-name", value); err == nil {
|
||||
cookie := &http.Cookie{
|
||||
Name: "cookie-name",
|
||||
Value: encoded,
|
||||
Path: "/",
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
Later, use the same SecureCookie instance to decode and validate a cookie
|
||||
value:
|
||||
|
||||
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if cookie, err := r.Cookie("cookie-name"); err == nil {
|
||||
value := make(map[string]string)
|
||||
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil {
|
||||
fmt.Fprintf(w, "The value of foo is %q", value["foo"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
We stored a map[string]string, but secure cookies can hold any value that
|
||||
can be encoded using encoding/gob. To store custom types, they must be
|
||||
registered first using gob.Register(). For basic types this is not needed;
|
||||
it works out of the box.
|
||||
*/
|
||||
package securecookie
|
|
@ -0,0 +1,25 @@
|
|||
// +build gofuzz
|
||||
|
||||
package securecookie
|
||||
|
||||
var hashKey = []byte("very-secret12345")
|
||||
var blockKey = []byte("a-lot-secret1234")
|
||||
var s = New(hashKey, blockKey)
|
||||
|
||||
type Cookie struct {
|
||||
B bool
|
||||
I int
|
||||
S string
|
||||
}
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
datas := string(data)
|
||||
var c Cookie
|
||||
if err := s.Decode("fuzz", datas, &c); err != nil {
|
||||
return 0
|
||||
}
|
||||
if _, err := s.Encode("fuzz", c); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return 1
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
)
|
||||
|
||||
var hashKey = []byte("very-secret12345")
|
||||
var blockKey = []byte("a-lot-secret1234")
|
||||
var s = securecookie.New(hashKey, blockKey)
|
||||
|
||||
type Cookie struct {
|
||||
B bool
|
||||
I int
|
||||
S string
|
||||
}
|
||||
|
||||
func main() {
|
||||
var c Cookie
|
||||
t := reflect.TypeOf(c)
|
||||
rnd := rand.New(rand.NewSource(0))
|
||||
for i := 0; i < 100; i++ {
|
||||
v, ok := quick.Value(t, rnd)
|
||||
if !ok {
|
||||
panic("couldn't generate value")
|
||||
}
|
||||
encoded, err := s.Encode("fuzz", v.Interface())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
f, err := os.Create(fmt.Sprintf("corpus/%d.sc", i))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = io.WriteString(f, encoded)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,646 @@
|
|||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package securecookie
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Error is the interface of all errors returned by functions in this library.
|
||||
type Error interface {
|
||||
error
|
||||
|
||||
// IsUsage returns true for errors indicating the client code probably
|
||||
// uses this library incorrectly. For example, the client may have
|
||||
// failed to provide a valid hash key, or may have failed to configure
|
||||
// the Serializer adequately for encoding value.
|
||||
IsUsage() bool
|
||||
|
||||
// IsDecode returns true for errors indicating that a cookie could not
|
||||
// be decoded and validated. Since cookies are usually untrusted
|
||||
// user-provided input, errors of this type should be expected.
|
||||
// Usually, the proper action is simply to reject the request.
|
||||
IsDecode() bool
|
||||
|
||||
// IsInternal returns true for unexpected errors occurring in the
|
||||
// securecookie implementation.
|
||||
IsInternal() bool
|
||||
|
||||
// Cause, if it returns a non-nil value, indicates that this error was
|
||||
// propagated from some underlying library. If this method returns nil,
|
||||
// this error was raised directly by this library.
|
||||
//
|
||||
// Cause is provided principally for debugging/logging purposes; it is
|
||||
// rare that application logic should perform meaningfully different
|
||||
// logic based on Cause. See, for example, the caveats described on
|
||||
// (MultiError).Cause().
|
||||
Cause() error
|
||||
}
|
||||
|
||||
// errorType is a bitmask giving the error type(s) of an cookieError value.
|
||||
type errorType int
|
||||
|
||||
const (
|
||||
usageError = errorType(1 << iota)
|
||||
decodeError
|
||||
internalError
|
||||
)
|
||||
|
||||
type cookieError struct {
|
||||
typ errorType
|
||||
msg string
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e cookieError) IsUsage() bool { return (e.typ & usageError) != 0 }
|
||||
func (e cookieError) IsDecode() bool { return (e.typ & decodeError) != 0 }
|
||||
func (e cookieError) IsInternal() bool { return (e.typ & internalError) != 0 }
|
||||
|
||||
func (e cookieError) Cause() error { return e.cause }
|
||||
|
||||
func (e cookieError) Error() string {
|
||||
parts := []string{"securecookie: "}
|
||||
if e.msg == "" {
|
||||
parts = append(parts, "error")
|
||||
} else {
|
||||
parts = append(parts, e.msg)
|
||||
}
|
||||
if c := e.Cause(); c != nil {
|
||||
parts = append(parts, " - caused by: ", c.Error())
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
var (
|
||||
errGeneratingIV = cookieError{typ: internalError, msg: "failed to generate random iv"}
|
||||
|
||||
errNoCodecs = cookieError{typ: usageError, msg: "no codecs provided"}
|
||||
errHashKeyNotSet = cookieError{typ: usageError, msg: "hash key is not set"}
|
||||
errBlockKeyNotSet = cookieError{typ: usageError, msg: "block key is not set"}
|
||||
errEncodedValueTooLong = cookieError{typ: usageError, msg: "the value is too long"}
|
||||
|
||||
errValueToDecodeTooLong = cookieError{typ: decodeError, msg: "the value is too long"}
|
||||
errTimestampInvalid = cookieError{typ: decodeError, msg: "invalid timestamp"}
|
||||
errTimestampTooNew = cookieError{typ: decodeError, msg: "timestamp is too new"}
|
||||
errTimestampExpired = cookieError{typ: decodeError, msg: "expired timestamp"}
|
||||
errDecryptionFailed = cookieError{typ: decodeError, msg: "the value could not be decrypted"}
|
||||
errValueNotByte = cookieError{typ: decodeError, msg: "value not a []byte."}
|
||||
|
||||
// ErrMacInvalid indicates that cookie decoding failed because the HMAC
|
||||
// could not be extracted and verified. Direct use of this error
|
||||
// variable is deprecated; it is public only for legacy compatibility,
|
||||
// and may be privatized in the future, as it is rarely useful to
|
||||
// distinguish between this error and other Error implementations.
|
||||
ErrMacInvalid = cookieError{typ: decodeError, msg: "the value is not valid"}
|
||||
)
|
||||
|
||||
// Codec defines an interface to encode and decode cookie values.
|
||||
type Codec interface {
|
||||
Encode(name string, value interface{}) (string, error)
|
||||
Decode(name, value string, dst interface{}) error
|
||||
}
|
||||
|
||||
// New returns a new SecureCookie.
|
||||
//
|
||||
// hashKey is required, used to authenticate values using HMAC. Create it using
|
||||
// GenerateRandomKey(). It is recommended to use a key with 32 or 64 bytes.
|
||||
//
|
||||
// blockKey is optional, used to encrypt values. Create it using
|
||||
// GenerateRandomKey(). The key length must correspond to the block size
|
||||
// of the encryption algorithm. For AES, used by default, valid lengths are
|
||||
// 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||
// The default encoder used for cookie serialization is encoding/gob.
|
||||
//
|
||||
// Note that keys created using GenerateRandomKey() are not automatically
|
||||
// persisted. New keys will be created when the application is restarted, and
|
||||
// previously issued cookies will not be able to be decoded.
|
||||
func New(hashKey, blockKey []byte) *SecureCookie {
|
||||
s := &SecureCookie{
|
||||
hashKey: hashKey,
|
||||
blockKey: blockKey,
|
||||
hashFunc: sha256.New,
|
||||
maxAge: 86400 * 30,
|
||||
maxLength: 4096,
|
||||
sz: GobEncoder{},
|
||||
}
|
||||
if hashKey == nil {
|
||||
s.err = errHashKeyNotSet
|
||||
}
|
||||
if blockKey != nil {
|
||||
s.BlockFunc(aes.NewCipher)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SecureCookie encodes and decodes authenticated and optionally encrypted
|
||||
// cookie values.
|
||||
type SecureCookie struct {
|
||||
hashKey []byte
|
||||
hashFunc func() hash.Hash
|
||||
blockKey []byte
|
||||
block cipher.Block
|
||||
maxLength int
|
||||
maxAge int64
|
||||
minAge int64
|
||||
err error
|
||||
sz Serializer
|
||||
// For testing purposes, the function that returns the current timestamp.
|
||||
// If not set, it will use time.Now().UTC().Unix().
|
||||
timeFunc func() int64
|
||||
}
|
||||
|
||||
// Serializer provides an interface for providing custom serializers for cookie
|
||||
// values.
|
||||
type Serializer interface {
|
||||
Serialize(src interface{}) ([]byte, error)
|
||||
Deserialize(src []byte, dst interface{}) error
|
||||
}
|
||||
|
||||
// GobEncoder encodes cookie values using encoding/gob. This is the simplest
|
||||
// encoder and can handle complex types via gob.Register.
|
||||
type GobEncoder struct{}
|
||||
|
||||
// JSONEncoder encodes cookie values using encoding/json. Users who wish to
|
||||
// encode complex types need to satisfy the json.Marshaller and
|
||||
// json.Unmarshaller interfaces.
|
||||
type JSONEncoder struct{}
|
||||
|
||||
// NopEncoder does not encode cookie values, and instead simply accepts a []byte
|
||||
// (as an interface{}) and returns a []byte. This is particularly useful when
|
||||
// you encoding an object upstream and do not wish to re-encode it.
|
||||
type NopEncoder struct{}
|
||||
|
||||
// MaxLength restricts the maximum length, in bytes, for the cookie value.
|
||||
//
|
||||
// Default is 4096, which is the maximum value accepted by Internet Explorer.
|
||||
func (s *SecureCookie) MaxLength(value int) *SecureCookie {
|
||||
s.maxLength = value
|
||||
return s
|
||||
}
|
||||
|
||||
// MaxAge restricts the maximum age, in seconds, for the cookie value.
|
||||
//
|
||||
// Default is 86400 * 30. Set it to 0 for no restriction.
|
||||
func (s *SecureCookie) MaxAge(value int) *SecureCookie {
|
||||
s.maxAge = int64(value)
|
||||
return s
|
||||
}
|
||||
|
||||
// MinAge restricts the minimum age, in seconds, for the cookie value.
|
||||
//
|
||||
// Default is 0 (no restriction).
|
||||
func (s *SecureCookie) MinAge(value int) *SecureCookie {
|
||||
s.minAge = int64(value)
|
||||
return s
|
||||
}
|
||||
|
||||
// HashFunc sets the hash function used to create HMAC.
|
||||
//
|
||||
// Default is crypto/sha256.New.
|
||||
func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie {
|
||||
s.hashFunc = f
|
||||
return s
|
||||
}
|
||||
|
||||
// BlockFunc sets the encryption function used to create a cipher.Block.
|
||||
//
|
||||
// Default is crypto/aes.New.
|
||||
func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie {
|
||||
if s.blockKey == nil {
|
||||
s.err = errBlockKeyNotSet
|
||||
} else if block, err := f(s.blockKey); err == nil {
|
||||
s.block = block
|
||||
} else {
|
||||
s.err = cookieError{cause: err, typ: usageError}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Encoding sets the encoding/serialization method for cookies.
|
||||
//
|
||||
// Default is encoding/gob. To encode special structures using encoding/gob,
|
||||
// they must be registered first using gob.Register().
|
||||
func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
|
||||
s.sz = sz
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Encode encodes a cookie value.
|
||||
//
|
||||
// It serializes, optionally encrypts, signs with a message authentication code,
|
||||
// and finally encodes the value.
|
||||
//
|
||||
// The name argument is the cookie name. It is stored with the encoded value.
|
||||
// The value argument is the value to be encoded. It can be any value that can
|
||||
// be encoded using the currently selected serializer; see SetSerializer().
|
||||
//
|
||||
// It is the client's responsibility to ensure that value, when encoded using
|
||||
// the current serialization/encryption settings on s and then base64-encoded,
|
||||
// is shorter than the maximum permissible length.
|
||||
func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
|
||||
if s.err != nil {
|
||||
return "", s.err
|
||||
}
|
||||
if s.hashKey == nil {
|
||||
s.err = errHashKeyNotSet
|
||||
return "", s.err
|
||||
}
|
||||
var err error
|
||||
var b []byte
|
||||
// 1. Serialize.
|
||||
if b, err = s.sz.Serialize(value); err != nil {
|
||||
return "", cookieError{cause: err, typ: usageError}
|
||||
}
|
||||
// 2. Encrypt (optional).
|
||||
if s.block != nil {
|
||||
if b, err = encrypt(s.block, b); err != nil {
|
||||
return "", cookieError{cause: err, typ: usageError}
|
||||
}
|
||||
}
|
||||
b = encode(b)
|
||||
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
|
||||
b = []byte(fmt.Sprintf("%s|%d|%s|", name, s.timestamp(), b))
|
||||
mac := createMac(hmac.New(s.hashFunc, s.hashKey), b[:len(b)-1])
|
||||
// Append mac, remove name.
|
||||
b = append(b, mac...)[len(name)+1:]
|
||||
// 4. Encode to base64.
|
||||
b = encode(b)
|
||||
// 5. Check length.
|
||||
if s.maxLength != 0 && len(b) > s.maxLength {
|
||||
return "", errEncodedValueTooLong
|
||||
}
|
||||
// Done.
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// Decode decodes a cookie value.
|
||||
//
|
||||
// It decodes, verifies a message authentication code, optionally decrypts and
|
||||
// finally deserializes the value.
|
||||
//
|
||||
// The name argument is the cookie name. It must be the same name used when
|
||||
// it was stored. The value argument is the encoded cookie value. The dst
|
||||
// argument is where the cookie will be decoded. It must be a pointer.
|
||||
func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
if s.hashKey == nil {
|
||||
s.err = errHashKeyNotSet
|
||||
return s.err
|
||||
}
|
||||
// 1. Check length.
|
||||
if s.maxLength != 0 && len(value) > s.maxLength {
|
||||
return errValueToDecodeTooLong
|
||||
}
|
||||
// 2. Decode from base64.
|
||||
b, err := decode([]byte(value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 3. Verify MAC. Value is "date|value|mac".
|
||||
parts := bytes.SplitN(b, []byte("|"), 3)
|
||||
if len(parts) != 3 {
|
||||
return ErrMacInvalid
|
||||
}
|
||||
h := hmac.New(s.hashFunc, s.hashKey)
|
||||
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...)
|
||||
if err = verifyMac(h, b, parts[2]); err != nil {
|
||||
return err
|
||||
}
|
||||
// 4. Verify date ranges.
|
||||
var t1 int64
|
||||
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
|
||||
return errTimestampInvalid
|
||||
}
|
||||
t2 := s.timestamp()
|
||||
if s.minAge != 0 && t1 > t2-s.minAge {
|
||||
return errTimestampTooNew
|
||||
}
|
||||
if s.maxAge != 0 && t1 < t2-s.maxAge {
|
||||
return errTimestampExpired
|
||||
}
|
||||
// 5. Decrypt (optional).
|
||||
b, err = decode(parts[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.block != nil {
|
||||
if b, err = decrypt(s.block, b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 6. Deserialize.
|
||||
if err = s.sz.Deserialize(b, dst); err != nil {
|
||||
return cookieError{cause: err, typ: decodeError}
|
||||
}
|
||||
// Done.
|
||||
return nil
|
||||
}
|
||||
|
||||
// timestamp returns the current timestamp, in seconds.
|
||||
//
|
||||
// For testing purposes, the function that generates the timestamp can be
|
||||
// overridden. If not set, it will return time.Now().UTC().Unix().
|
||||
func (s *SecureCookie) timestamp() int64 {
|
||||
if s.timeFunc == nil {
|
||||
return time.Now().UTC().Unix()
|
||||
}
|
||||
return s.timeFunc()
|
||||
}
|
||||
|
||||
// Authentication -------------------------------------------------------------
|
||||
|
||||
// createMac creates a message authentication code (MAC).
|
||||
func createMac(h hash.Hash, value []byte) []byte {
|
||||
h.Write(value)
|
||||