// medinet/database.go
//
// 683 lines
// 20 KiB
// Go

package main
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
)
// Schema version and size limits used by the database layer.
const (
	// databaseVersion is the schema version this build expects; Migrate
	// steps older databases forward until they reach it.
	databaseVersion = 1

	// accountKeyLength is the number of characters generateKey produces.
	accountKeyLength = 32 // Was using MD5 hashes

	// messageMaxLength is the maximum length, in bytes, of a stored session
	// message; addSession truncates longer messages.
	messageMaxLength = 4096
)

// googleOAuthURL is the Google userinfo endpoint; the caller appends the
// access token directly to this query string.
const googleOAuthURL = "https://www.googleapis.com/oauth2/v3/userinfo?alt=json&access_token="
// databaseTables maps each table name to its column definitions, in the
// order they are declared. CreateTables joins each list into a
// CREATE TABLE IF NOT EXISTS statement, rewriting AUTOINCREMENT for MySQL.
//
// TODO: Add indexes
var databaseTables = map[string][]string{
	"accounts": {
		"`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT",
		"`key` VARCHAR(145) NOT NULL DEFAULT ''",
		"`google_id` VARCHAR(200) NOT NULL DEFAULT ''",
		"`facebook_id` VARCHAR(200) NOT NULL DEFAULT ''",
		"`twitter_id` VARCHAR(200) NOT NULL DEFAULT ''",
		"`openid_id` VARCHAR(200) NOT NULL DEFAULT ''",
		"`email` VARCHAR(254) NOT NULL DEFAULT ''",
		"`name` VARCHAR(50) NOT NULL DEFAULT ''",
		"`registered` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`lastactive` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`streak` SMALLINT UNSIGNED NOT NULL DEFAULT 0",
		"`topstreak` SMALLINT UNSIGNED NOT NULL DEFAULT 0",
		"`streakend` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`streakbuffer` MEDIUMINT UNSIGNED NOT NULL DEFAULT 0",
		"`announcement` SMALLINT UNSIGNED NOT NULL DEFAULT 0",
		"`sessionspublic` TINYINT UNSIGNED NOT NULL DEFAULT 0",
		"`allowcontact` TINYINT UNSIGNED NOT NULL DEFAULT 0",
	},
	"announcements": {
		"`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT",
		"`posted` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`text` TEXT NOT NULL DEFAULT ''",
		"`active` TINYINT UNSIGNED NOT NULL DEFAULT 0",
	},
	"sessions": {
		"`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT",
		"`account` INTEGER NOT NULL DEFAULT 0",
		"`api` VARCHAR(145) NOT NULL DEFAULT ''",
		"`ip` VARCHAR(145) NOT NULL DEFAULT ''",
		"`market` VARCHAR(145) NOT NULL DEFAULT ''",
		"`app` VARCHAR(145) NOT NULL DEFAULT ''",
		"`posted` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`started` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`length` MEDIUMINT UNSIGNED NOT NULL DEFAULT 0",
		"`completed` INTEGER UNSIGNED NOT NULL DEFAULT 0",
		"`message` TEXT NOT NULL DEFAULT ''",
		"`streakday` SMALLINT UNSIGNED NOT NULL DEFAULT 0",
		"`modified` INTEGER UNSIGNED NOT NULL DEFAULT 0",
	},
	"meta": {
		"`key` VARCHAR(50) NOT NULL PRIMARY KEY",
		"`value` TEXT NOT NULL DEFAULT ''",
	},
}
// database holds the SQL handle used by every query in this file, plus the
// name of a SQL function (FuncGreatest) that is spliced into query text.
type database struct {
// db is the handle opened by connect.
db *sql.DB
// FuncGreatest names the SQL function updateTopStreak uses to keep the
// larger of `streak` and `topstreak`; connect sets it.
FuncGreatest string
}
// letters is the alphabet generateKey draws account-key characters from:
// the ASCII lower-case letters followed by the upper-case letters.
var letters = []rune("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ")
// account holds the subset of an accounts row that scanAccount reads.
// Columns the account queries here do not select (google_id, lastactive,
// streakend, sessionspublic, ...) are not represented.
type account struct {
ID int
Name string
Email string
// Registered is the Unix time the account was created (set by authenticate).
Registered int
// Key is the random string generated at sign-up (generateKey) and used to
// look the account up (accountByKey).
Key string
Streak int
// StreakBuffer is passed to the streak-window helpers (atWindowStart,
// beforeWindowStart). NOTE(review): its unit is not visible here — confirm.
StreakBuffer int
TopStreak int
AllowContact int
// Announcement is the id of the last announcement returned to this account
// by showAnnouncement.
Announcement int
}
// session is one row of the sessions table. The json tags define its JSON
// encoding, and field order determines the order of keys in that output.
type session struct {
ID int `json:"id"`
// Posted is the Unix time the row was first inserted (addSession).
Posted int `json:"posted"`
// Started is the start time in Unix seconds; together with the account it
// identifies a session (see getSessionByStarted).
Started int `json:"started"`
StreakDay int `json:"streakday"`
// Length is the duration; addSession sets Completed to Started+Length when
// Completed is 0, so it presumably shares Started's unit (seconds).
Length int `json:"length"`
Completed int `json:"completed"`
Message string `json:"message"`
// Modified is compared by addSession to decide whether an incoming copy
// replaces the stored one (the strictly newer value wins).
Modified int `json:"modified"`
}
// recentSession is a session together with its author's account fields, as
// returned by getRecentSessions.
type recentSession struct {
session
AccountID int
AccountName string
AccountEmail string
}
// group is one row of the groups table, as read by getAllGroups.
type group struct {
ID int `json:"id"`
// Creator is presumably the ID of the account that created the group — confirm.
Creator int `json:"creator"`
Name string `json:"name"`
Description string `json:"description"`
}
// generateKey returns a new account key of accountKeyLength characters, each
// chosen with rand.Intn from letters.
//
// NOTE(review): math/rand is not a cryptographically secure source (and
// before Go 1.20 it is deterministically seeded unless rand.Seed is called
// elsewhere). The result is used to look up accounts (accountByKey), so
// crypto/rand should be considered.
func generateKey() string {
	var sb strings.Builder
	sb.Grow(accountKeyLength)
	for n := 0; n < accountKeyLength; n++ {
		sb.WriteRune(letters[rand.Intn(len(letters))])
	}
	return sb.String()
}
// connect opens the database described by driver and dataSource, checks that
// it is reachable, creates any missing tables and runs schema migrations.
// On any failure the handle is closed and a nil *database is returned.
func connect(driver string, dataSource string) (*database, error) {
	conn, err := sql.Open(driver, dataSource)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	// sql.Open only validates its arguments. Ping establishes a connection so
	// an unreachable server is reported here, not as a table-creation error.
	if err = conn.Ping(); err != nil {
		_ = conn.Close() // best effort: already returning an error
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	d := &database{db: conn, FuncGreatest: greatestFunc(driver)}
	if err = d.CreateTables(); err != nil {
		_ = d.db.Close() // best effort: already returning an error
		return nil, fmt.Errorf("failed to create tables: %w", err)
	}
	if err = d.Migrate(); err != nil {
		_ = d.db.Close() // best effort: already returning an error
		return nil, fmt.Errorf("failed to migrate database: %w", err)
	}
	return d, nil
}

// greatestFunc returns the SQL scalar function that yields the largest of its
// arguments for driver. SQLite has no GREATEST; its multi-argument max() does
// the same job. MySQL and most other engines use GREATEST.
func greatestFunc(driver string) string {
	if strings.HasPrefix(driver, "sqlite") {
		return "MAX"
	}
	return "GREATEST"
}
// CreateTables issues CREATE TABLE IF NOT EXISTS for every entry in
// databaseTables. When config.DBDriver is "mysql", the column definitions are
// adapted to MySQL syntax and the InnoDB/utf8 table options are appended.
func (d *database) CreateTables() error {
	isMySQL := config.DBDriver == "mysql"
	createQueryExtra := ""
	if isMySQL {
		createQueryExtra = " ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE utf8_unicode_ci"
	}
	for tname, tcols := range databaseTables {
		tcolumns := strings.Join(tcols, ",")
		if isMySQL {
			tcolumns = strings.Replace(tcolumns, "AUTOINCREMENT", "AUTO_INCREMENT", -1)
			// MySQL does not accept a literal DEFAULT on TEXT/BLOB columns
			// (error 1101 under strict SQL mode), so drop it for that engine.
			tcolumns = strings.Replace(tcolumns, " TEXT NOT NULL DEFAULT ''", " TEXT NOT NULL", -1)
		}
		query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (%s)", tname, tcolumns) + createQueryExtra
		if _, err := d.db.Exec(query); err != nil {
			return fmt.Errorf("failed to create table %s: %w", tname, err)
		}
	}
	return nil
}
// Migrate records or upgrades the schema version stored in the meta table.
//
// A database with no version row, or with version 0, is treated as already
// having the current schema (CreateTables just built it): databaseVersion is
// recorded and no migration steps run. Otherwise every step from the stored
// version up to databaseVersion runs in order, and the new version is saved.
// An unparseable or negative stored version is reported as an error.
func (d *database) Migrate() error {
	var stored string
	err := d.db.QueryRow("SELECT `value` FROM meta WHERE `key`=?", "version").Scan(&stored)
	if err == sql.ErrNoRows {
		// The row has never been written, so it must be INSERTed; an UPDATE
		// would match nothing and silently leave the version unset.
		_, err = d.db.Exec("INSERT INTO meta (`key`, `value`) VALUES (?, ?)", "version", strconv.Itoa(databaseVersion))
		if err != nil {
			return fmt.Errorf("failed to save database version: %w", err)
		}
		return nil
	} else if err != nil {
		return fmt.Errorf("failed to fetch database version: %w", err)
	}

	version, err := strconv.Atoi(stored)
	if err != nil || version < 0 {
		return errors.New("unable to migrate database: database version unknown")
	}
	if version == 0 {
		_, err = d.db.Exec("UPDATE meta SET `value`=? WHERE `key`=?", strconv.Itoa(databaseVersion), "version")
		if err != nil {
			return fmt.Errorf("failed to save database version: %w", err)
		}
		return nil
	}
	if version >= databaseVersion {
		return nil
	}

	for v := version; v < databaseVersion; v++ {
		switch v {
		case 1:
			// databaseVersion 2 migration queries will go here
		}
	}
	_, err = d.db.Exec("UPDATE meta SET `value`=? WHERE `key`=?", strconv.Itoa(databaseVersion), "version")
	if err != nil {
		return fmt.Errorf("failed to save updated database version: %w", err)
	}
	return nil
}
// googleHTTPClient is used for Google userinfo lookups. Unlike
// http.DefaultClient it has a timeout, so a stalled request cannot block the
// caller indefinitely.
var googleHTTPClient = &http.Client{Timeout: 15 * time.Second}

// authenticate resolves a Google OAuth access token to a local account.
//
// It fetches the token's userinfo from Google and looks up the account by
// Google ID. On first sign-in it creates the account with a freshly generated
// key. It returns an error if the token does not yield a Google ID, if any
// lookup fails, or if the account cannot be loaded afterwards.
func (d *database) authenticate(token string) (*account, error) {
	resp, err := googleHTTPClient.Get(googleOAuthURL + token)
	if err != nil {
		return nil, fmt.Errorf("failed to get userinfo from Google: %w", err)
	}
	defer resp.Body.Close()
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read userinfo response from Google: %w", err)
	}

	// Decoding into a struct, rather than asserting map values to string,
	// turns an unexpected field type into an error instead of a panic.
	var userinfo struct {
		Sub   string `json:"sub"`
		Email string `json:"email"`
		Name  string `json:"name"`
	}
	if err = json.Unmarshal(data, &userinfo); err != nil {
		return nil, fmt.Errorf("failed to unmarshal userinfo response from Google: %w", err)
	}

	googleid := userinfo.Sub
	if googleid == "" || googleid == "0" {
		logDebugf("Userinfo: %s", data)
		return nil, errors.New("invalid access token")
	}
	// An email longer than 75 bytes is dropped (stored empty), not truncated.
	email := userinfo.Email
	if len(email) > 75 {
		email = ""
	}
	// Truncate by characters rather than bytes so a multi-byte character is
	// never split; accounts.name is VARCHAR(50).
	name := userinfo.Name
	if runes := []rune(name); len(runes) > 50 {
		name = string(runes[:50])
	}

	var key string
	err = d.db.QueryRow("SELECT `key` FROM accounts WHERE google_id=?", googleid).Scan(&key)
	if err == sql.ErrNoRows {
		key = generateKey()
		_, err = d.db.Exec("INSERT INTO accounts (`key`, `google_id`, `email`, `name`, `registered`) VALUES(?, ?, ?, ?, ?)", key, googleid, email, name, time.Now().Unix())
		if err != nil {
			return nil, fmt.Errorf("failed to insert account: %w", err)
		}
		stats.AccountsCreated++
	} else if err != nil {
		return nil, fmt.Errorf("failed to fetch account key: %w", err)
	}

	// Return lookup failures to the caller instead of aborting the process.
	a, err := d.accountByKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch account: %w", err)
	}
	if a == nil {
		return nil, errors.New("account not found after sign-in")
	}
	return a, nil
}
// scanAccount reads the current row of rows into a new account. The row's
// columns must be, in order: id, name, email, registered, key, streak,
// streakbuffer, topstreak, allowcontact, announcement.
func (d *database) scanAccount(rows *sql.Rows) (*account, error) {
	var a account
	dest := []interface{}{&a.ID, &a.Name, &a.Email, &a.Registered, &a.Key, &a.Streak, &a.StreakBuffer, &a.TopStreak, &a.AllowContact, &a.Announcement}
	if err := rows.Scan(dest...); err != nil {
		return nil, fmt.Errorf("failed to scan account: %s", err)
	}
	return &a, nil
}
// accountByID loads the account with the given id. It returns (nil, nil)
// when no such account exists.
func (d *database) accountByID(id int) (*account, error) {
	rows, err := d.db.Query("SELECT `id`, `name`, `email`, `registered`, `key`, `streak`, `streakbuffer`, `topstreak`, `allowcontact`, `announcement` FROM accounts WHERE `id`=?", id)
	if err != nil {
		return nil, fmt.Errorf("accountByID error: %w", err)
	}
	// Returning after the first row would otherwise leave the result set,
	// and the connection it holds, open.
	defer rows.Close()
	if !rows.Next() {
		// Next returns false on error as well as when there are no rows.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("accountByID error: %w", err)
		}
		return nil, nil
	}
	a, err := d.scanAccount(rows)
	if err != nil {
		return nil, fmt.Errorf("accountByID error: %w", err)
	}
	return a, nil
}
// accountByKey loads the account whose `key` column equals key. It returns
// (nil, nil) when no such account exists.
func (d *database) accountByKey(key string) (*account, error) {
	rows, err := d.db.Query("SELECT `id`, `name`, `email`, `registered`, `key`, `streak`, `streakbuffer`, `topstreak`, `allowcontact`, `announcement` FROM accounts WHERE `key`=?", key)
	if err != nil {
		return nil, fmt.Errorf("accountByKey error: %w", err)
	}
	// Returning after the first row would otherwise leave the result set,
	// and the connection it holds, open.
	defer rows.Close()
	if !rows.Next() {
		// Next returns false on error as well as when there are no rows.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("accountByKey error: %w", err)
		}
		return nil, nil
	}
	a, err := d.scanAccount(rows)
	if err != nil {
		return nil, fmt.Errorf("accountByKey error: %w", err)
	}
	return a, nil
}
// getStreak returns the account's current streak, its end time (Unix
// seconds, compared against time.Now().Unix()) and the account's top streak.
//
// A streak whose end time has passed is expired: it is reported as 0, 0.
// If the stored values are not already zero, it is also reset in the
// database. An unknown accountID yields an "invalid account ID" error.
func (d *database) getStreak(accountID int) (int64, int64, int64, error) {
	var streakDay, streakEnd, topStreak int64
	err := d.db.QueryRow("SELECT `streak`, `streakend`, `topstreak` FROM accounts WHERE `id`=?", accountID).Scan(&streakDay, &streakEnd, &topStreak)
	if err == sql.ErrNoRows {
		return 0, 0, 0, errors.New("invalid account ID")
	} else if err != nil {
		return 0, 0, 0, fmt.Errorf("getStreak error: %w", err)
	}
	// Expire streak
	if streakEnd <= time.Now().Unix() {
		// Only write when there is something to clear, so accounts without an
		// active streak don't cost an UPDATE on every call.
		if streakDay != 0 || streakEnd != 0 {
			_, err = d.db.Exec("UPDATE accounts SET `streak`=?, `streakend`=? WHERE `id`=?", 0, 0, accountID)
			if err != nil {
				return 0, 0, 0, fmt.Errorf("failed to expire streak: %w", err)
			}
		}
		streakDay, streakEnd = 0, 0
	}
	return streakDay, streakEnd, topStreak, nil
}
// updateLastActive sets the account's lastactive column to the current Unix time.
func (d *database) updateLastActive(accountID int) error {
	now := time.Now().Unix()
	if _, err := d.db.Exec("UPDATE accounts SET `lastactive`=? WHERE `id`=?", now, accountID); err != nil {
		return fmt.Errorf("failed to update last active: %s", err)
	}
	return nil
}
// updateTopStreak raises the account's topstreak to its current streak when
// the streak is larger. The comparison function comes from d.FuncGreatest.
func (d *database) updateTopStreak(accountID int) error {
	query := "UPDATE accounts SET `topstreak`=" + d.FuncGreatest + "(`streak`, `topstreak`) WHERE `id`=?"
	if _, err := d.db.Exec(query, accountID); err != nil {
		return fmt.Errorf("failed to update top streak: %s", err)
	}
	return nil
}
// updateStreakBuffer stores streakBuffer in the account's streakbuffer column.
func (d *database) updateStreakBuffer(accountID int, streakBuffer int) error {
	_, err := d.db.Exec("UPDATE accounts SET `streakbuffer`=? WHERE `id`=?", streakBuffer, accountID)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to update streak buffer: %s", err)
}
// calculateStreak counts consecutive days on which the account has at least
// one session, according to sessionExistsByDate. It walks backwards one day
// at a time from the current time in tz. When beforeWindowStart reports true
// for the current time, counting starts from the previous day.
func (d *database) calculateStreak(accountID int, streakBuffer int, tz *time.Location) (int, error) {
	day := time.Now().In(tz)
	logDebugf("calculate start %v", day)
	if beforeWindowStart(day, streakBuffer) {
		day = day.AddDate(0, 0, -1)
		logDebugf("calculate added %v", day)
	}
	streak := 0
	for {
		exists, err := d.sessionExistsByDate(day, accountID, streakBuffer)
		if err != nil {
			return 0, fmt.Errorf("failed to check if session exists for date: %s", err)
		}
		if !exists {
			break
		}
		streak++
		day = day.AddDate(0, 0, -1)
	}
	logDebugf("calculated streak as %d", streak)
	return streak, nil
}
// setStreak stores streakDay as the account's current streak and then
// refreshes its top streak.
//
// The streak end is the window start (atWindowStart) one day ahead of the
// current time in tz. When beforeWindowStart reports false for the current
// time, it is two days ahead instead.
func (d *database) setStreak(streakDay int, accountID int, streakBuffer int, tz *time.Location) error {
	now := time.Now().In(tz)
	daysAhead := 2
	if beforeWindowStart(now, streakBuffer) {
		daysAhead = 1
	}
	end := atWindowStart(now.AddDate(0, 0, daysAhead), streakBuffer)
	logDebugf("SETTING STREAK Account %d, Day %d, TZ %s, Streak end: %d", accountID, streakDay, tz.String(), end.Unix())
	if _, err := d.db.Exec("UPDATE accounts SET `streak`=?, `streakend`=? WHERE `id`=?", streakDay, end.Unix(), accountID); err != nil {
		return fmt.Errorf("failed to update streak: %s", err)
	}
	if err := d.updateTopStreak(accountID); err != nil {
		return fmt.Errorf("failed to update top streak: %s", err)
	}
	return nil
}
// setSessionStreakDay writes streakDay to the account's session whose
// `started` value equals started.
func (d *database) setSessionStreakDay(started int, streakDay int, accountID int) error {
	const query = "UPDATE sessions SET `streakday`=? WHERE `account`=? AND `started`=?"
	if _, err := d.db.Exec(query, streakDay, accountID, started); err != nil {
		return fmt.Errorf("failed to set session streak day: %s", err)
	}
	return nil
}
// scanSession reads the current row of rows into a new session. The row's
// columns must be, in order: id, posted, started, streakday, length,
// completed, message, modified.
func (d *database) scanSession(rows *sql.Rows) (*session, error) {
	var s session
	dest := []interface{}{&s.ID, &s.Posted, &s.Started, &s.StreakDay, &s.Length, &s.Completed, &s.Message, &s.Modified}
	if err := rows.Scan(dest...); err != nil {
		return nil, fmt.Errorf("failed to scan session: %s", err)
	}
	return &s, nil
}
// scanRecentSession reads the current row of rows into a new recentSession.
// The row's columns must be the eight scanSession columns followed by the
// author's account id, name and email.
func (d *database) scanRecentSession(rows *sql.Rows) (*recentSession, error) {
	var rs recentSession
	dest := []interface{}{
		&rs.ID, &rs.Posted, &rs.Started, &rs.StreakDay, &rs.Length, &rs.Completed, &rs.Message, &rs.Modified,
		&rs.AccountID, &rs.AccountName, &rs.AccountEmail,
	}
	if err := rows.Scan(dest...); err != nil {
		return nil, fmt.Errorf("failed to scan session: %s", err)
	}
	return &rs, nil
}
// addSession stores s for accountID, resolving conflicts by Modified.
//
// When updateSessionStarted is positive and differs from s.Started, the
// session stored at updateSessionStarted is the one being replaced. s is
// ignored, and false is returned, when either of these already has a
// Modified value at least as new as s.Modified:
//   - the stored session at s.Started, or
//   - the session being replaced.
//
// Otherwise s is inserted as a new row (when neither session exists) or
// written over the kept row, and true is returned.
func (d *database) addSession(s session, updateSessionStarted int, accountID int, appVer string, appMarket string) (bool, error) {
	existingSession, err := d.getSessionByStarted(s.Started, accountID)
	if err != nil {
		return false, fmt.Errorf("failed to fetch session: %w", err)
	}
	var updateSession *session
	if updateSessionStarted > 0 && updateSessionStarted != s.Started {
		updateSession, err = d.getSessionByStarted(updateSessionStarted, accountID)
		if err != nil {
			return false, fmt.Errorf("failed to fetch session: %w", err)
		}
	}
	// Only accept s if it is strictly newer than every stored copy it touches.
	if (existingSession != nil && existingSession.Modified >= s.Modified) || (updateSession != nil && updateSession.Modified >= s.Modified) {
		return false, nil
	}
	if len(s.Message) > messageMaxLength {
		// Cut at the byte limit, then drop invalid UTF-8 (such as a partial
		// character left at the end by the cut).
		s.Message = strings.ToValidUTF8(s.Message[:messageMaxLength], "")
	}
	// Fix zero completed from older versions of the app
	if s.Completed == 0 {
		s.Completed = s.Started + s.Length
	}

	if existingSession == nil && updateSession == nil {
		_, err = d.db.Exec("INSERT INTO sessions (`account`, `market`, `app`, `posted`, `started`, `streakday`, `length`, `completed`, `message`, `modified`) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", accountID, appMarket, appVer, time.Now().Unix(), s.Started, s.StreakDay, s.Length, s.Completed, s.Message, s.Modified)
		if err != nil {
			return false, fmt.Errorf("failed to add session: %w", err)
		}
		return true, nil
	}

	keepSession := updateSession
	if keepSession == nil {
		keepSession = existingSession
	} else if existingSession != nil {
		// keepSession is moving onto existingSession's start time; remove the
		// old row so the update below does not leave two rows there.
		if _, err = d.deleteSession(existingSession.Started, accountID); err != nil {
			return false, fmt.Errorf("failed to delete existing session: %w", err)
		}
	}
	_, err = d.db.Exec("UPDATE sessions SET `started`=?, `length`=?, `completed`=?, `message`=?, `modified`=? WHERE `account`=? AND `started`=?", s.Started, s.Length, s.Completed, s.Message, s.Modified, accountID, keepSession.Started)
	if err != nil {
		return false, fmt.Errorf("failed to update session: %w", err)
	}
	return true, nil
}
// getSessionByID returns the account's session with the given id, or
// (nil, nil) if there is none.
func (d *database) getSessionByID(sessionID int, accountID int) (*session, error) {
	rows, err := d.db.Query("SELECT `id`, `posted`, `started`, `streakday`, `length`, `completed`, `message`, `modified` FROM sessions WHERE `account`=? AND `id`=?", accountID, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch session: %w", err)
	}
	defer rows.Close()
	if rows.Next() {
		return d.scanSession(rows)
	}
	// Next returns false on error as well as when there are no rows.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to fetch session: %w", err)
	}
	return nil, nil
}
// getSessionByStarted returns the account's session whose `started` value
// equals started, or (nil, nil) if there is none.
func (d *database) getSessionByStarted(started int, accountID int) (*session, error) {
	rows, err := d.db.Query("SELECT `id`, `posted`, `started`, `streakday`, `length`, `completed`, `message`, `modified` FROM sessions WHERE `account`=? AND `started`=?", accountID, started)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch session: %w", err)
	}
	defer rows.Close()
	if rows.Next() {
		return d.scanSession(rows)
	}
	// Next returns false on error as well as when there are no rows.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to fetch session: %w", err)
	}
	return nil, nil
}
// sessionExistsByDate reports whether the account has a session that started
// in the streak window containing date. The window runs from
// atWindowStart(date) up to, but not including, the window start one day later.
func (d *database) sessionExistsByDate(date time.Time, accountID int, streakBuffer int) (bool, error) {
	start := atWindowStart(date, streakBuffer)
	end := atWindowStart(start.AddDate(0, 0, 1), streakBuffer)
	logDebugf("SESSION EXISTS %v - START %v END %v", date, start.Unix(), end.Unix())
	var id int
	row := d.db.QueryRow("SELECT `id` FROM sessions WHERE `account`=? AND `started`>=? AND `started`<? LIMIT 1", accountID, start.Unix(), end.Unix())
	switch err := row.Scan(&id); {
	case err == nil, err == sql.ErrNoRows:
		// On ErrNoRows id stays 0, so this reports false.
		return id > 0, nil
	default:
		return false, fmt.Errorf("sessionExistsByDate failed: %s", err)
	}
}
// getAllSessions returns every session of the account ordered by `completed`.
// The order is descending when sortDescending is true, ascending otherwise.
func (d *database) getAllSessions(accountID int, sortDescending bool) ([]*session, error) {
	// The direction comes from a bool, never from caller text, so splicing it
	// into the query is safe.
	order := "ASC"
	if sortDescending {
		order = "DESC"
	}
	rows, err := d.db.Query("SELECT `id`, `posted`, `started`, `streakday`, `length`, `completed`, `message`, `modified` FROM sessions WHERE `account`=? ORDER BY `completed` "+order, accountID)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch sessions: %w", err)
	}
	defer rows.Close()
	var sessions []*session
	for rows.Next() {
		s, err := d.scanSession(rows)
		if err != nil {
			// scanSession already prefixes "failed to scan session".
			return nil, err
		}
		sessions = append(sessions, s)
	}
	// Next returns false on error as well as at the end of the results.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to fetch sessions: %w", err)
	}
	return sessions, nil
}
// getRecentSessions returns up to 50 sessions, newest `completed` first, that
// have `length` > 110 and belong to accounts with sessionspublic = 1. Each is
// paired with its author's id, name and email.
func (d *database) getRecentSessions() ([]*recentSession, error) {
	rows, err := d.db.Query("SELECT `sessions`.`id`, `sessions`.`posted`, `sessions`.`started`, `sessions`.`streakday`, `sessions`.`length`, `sessions`.`completed`, `sessions`.`message`, `sessions`.`modified`, `accounts`.`id` AS `accountid`, `accounts`.`name`, `accounts`.`email` FROM `sessions` LEFT OUTER JOIN `accounts` ON `sessions`.`account` = `accounts`.`id` WHERE `accounts`.`sessionspublic` = 1 AND `sessions`.`length` > 110 ORDER BY `sessions`.`completed` DESC LIMIT 50")
	if err != nil {
		return nil, fmt.Errorf("failed to fetch recent sessions: %w", err)
	}
	defer rows.Close()
	var sessions []*recentSession
	for rows.Next() {
		rs, err := d.scanRecentSession(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan recent session: %w", err)
		}
		sessions = append(sessions, rs)
	}
	// Next returns false on error as well as at the end of the results.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to fetch recent sessions: %w", err)
	}
	return sessions, nil
}
// deleteSession removes the account's session whose `started` value equals
// started, and reports whether any row was deleted.
func (d *database) deleteSession(started int, accountID int) (bool, error) {
	res, err := d.db.Exec("DELETE FROM sessions WHERE `account`=? AND `started`=?", accountID, started)
	if err != nil {
		return false, fmt.Errorf("failed to delete session: %s", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return false, fmt.Errorf("failed to fetch number of deleted sessions: %s", err)
	}
	deleted := n > 0
	return deleted, nil
}
// getAllGroups returns every row of the groups table, ordered by name.
//
// NOTE(review): databaseTables has no `groups` table, so CreateTables does not
// create it; confirm it is created elsewhere.
func (d *database) getAllGroups() ([]group, error) {
	// `groups` is quoted because GROUPS is a reserved word in MySQL 8.0.2+.
	rows, err := d.db.Query("SELECT `id`, `name`, `creator`, `description` FROM `groups` ORDER BY `name` ASC")
	if err != nil {
		return nil, fmt.Errorf("failed to fetch groups: %w", err)
	}
	defer rows.Close()
	var groups []group
	for rows.Next() {
		var g group
		if err := rows.Scan(&g.ID, &g.Name, &g.Creator, &g.Description); err != nil {
			return nil, fmt.Errorf("failed to scan group: %w", err)
		}
		groups = append(groups, g)
	}
	// Next returns false on error as well as at the end of the results.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to fetch groups: %w", err)
	}
	return groups, nil
}
// groupMemberCount returns the number of groupmembers rows for groupID.
func (d *database) groupMemberCount(groupID int) (int, error) {
	// COUNT(*) always yields exactly one row, so QueryRow fits. It also closes
	// the result and reports query and scan errors together.
	var memberCount int
	err := d.db.QueryRow("SELECT COUNT(*) AS c FROM `groupmembers` WHERE `group` = ?", groupID).Scan(&memberCount)
	if err != nil {
		return 0, fmt.Errorf("failed to fetch member count: %w", err)
	}
	return memberCount, nil
}
// showAnnouncement returns the text of the active announcement with the
// lowest id greater than a.Announcement, or "" if there is none.
//
// When one is found, a.Announcement and the account's `announcement` column
// are set to its id, so the next call moves on to a later announcement.
func (d *database) showAnnouncement(a *account) (string, error) {
	var (
		id   int
		text string
	)
	// QueryRow closes its result before returning, so the UPDATE below does
	// not run while a result set is still open.
	err := d.db.QueryRow("SELECT `id`, `text` FROM announcements WHERE `active` = 1 AND `id` > ? ORDER BY `id` ASC LIMIT 1", a.Announcement).Scan(&id, &text)
	if err == sql.ErrNoRows {
		return "", nil
	} else if err != nil {
		return "", fmt.Errorf("showAnnouncement error: %w", err)
	}
	a.Announcement = id
	// No LIMIT on the UPDATE: `id` is the primary key, and SQLite accepts
	// UPDATE ... LIMIT only when built with SQLITE_ENABLE_UPDATE_DELETE_LIMIT.
	_, err = d.db.Exec("UPDATE accounts SET `announcement` = ? WHERE `id` = ?", a.Announcement, a.ID)
	if err != nil {
		return "", fmt.Errorf("failed to update account: %w", err)
	}
	return text, nil
}