matrix/store.go

233 lines
4.9 KiB
Go

package matrix
import (
"database/sql"
"encoding/json"
"log"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// sqlStore keeps Matrix client state in a SQL database: per-user filter IDs
// and next-batch sync tokens, per-room encryption events, and room membership.
type sqlStore struct {
	// db is the handle every query and statement in this store runs against.
	db *sql.DB
}
// CreateTables creates the filter_ids, next_batch_tokens, rooms and
// room_members tables if they do not already exist.
//
// All statements run inside one transaction. If any statement fails, the
// transaction is rolled back and that statement's error is returned; a
// failure of the rollback itself is only logged. Otherwise the result of
// committing the transaction is returned.
func (s sqlStore) CreateTables() error {
	schema := [...]string{
		`CREATE TABLE IF NOT EXISTS filter_ids (
user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
filter_id TEXT NOT NULL
);
`,
		`CREATE TABLE IF NOT EXISTS next_batch_tokens (
user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
next_batch_token TEXT NOT NULL
);
`,
		`CREATE TABLE IF NOT EXISTS rooms (
room_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
encryption_event TEXT
);
`,
		`CREATE TABLE IF NOT EXISTS room_members (
room_id TEXT,
user_id TEXT,
PRIMARY KEY (room_id, user_id) ON CONFLICT IGNORE
);
`,
	}

	tx, beginErr := s.db.Begin()
	if beginErr != nil {
		return beginErr
	}
	for _, stmt := range schema {
		if _, execErr := tx.Exec(stmt); execErr != nil {
			// Report the statement failure to the caller; a rollback failure
			// is secondary and only logged.
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Print(rbErr)
			}
			return execErr
		}
	}
	return tx.Commit()
}
// SaveFilterID stores filterID for userID. Any failure is logged rather than
// returned.
func (s sqlStore) SaveFilterID(userID id.UserID, filterID string) {
	const insertFilter = "INSERT INTO filter_ids VALUES(?, ?);"
	if _, err := s.db.Exec(insertFilter, userID, filterID); err != nil {
		log.Print(err)
	}
}
// LoadFilterID returns the filter ID stored for userID. It returns the empty
// string when no row exists or the lookup fails for any other reason.
func (s sqlStore) LoadFilterID(userID id.UserID) (filterID string) {
	scanErr := s.db.
		QueryRow("SELECT filter_id FROM filter_ids WHERE user_id = ?;", userID).
		Scan(&filterID)
	if scanErr == nil {
		return filterID
	}
	return ""
}
// SaveNextBatch stores the sync next-batch token for userID. Any failure is
// logged rather than returned.
func (s sqlStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
	const insertToken = "INSERT INTO next_batch_tokens VALUES(?, ?);"
	if _, err := s.db.Exec(insertToken, userID, nextBatchToken); err != nil {
		log.Print(err)
	}
}
// LoadNextBatch returns the next-batch token stored for userID. It returns the
// empty string when no row exists or the lookup fails for any other reason.
func (s sqlStore) LoadNextBatch(userID id.UserID) (nextBatchToken string) {
	scanErr := s.db.
		QueryRow("SELECT next_batch_token FROM next_batch_tokens WHERE user_id = ?;", userID).
		Scan(&nextBatchToken)
	if scanErr == nil {
		return nextBatchToken
	}
	return ""
}
// SaveRoom intentionally does nothing: this store does not persist
// *mautrix.Room values.
func (s sqlStore) SaveRoom(room *mautrix.Room) {
	// No-op by design.
}
// LoadRoom returns a newly constructed *mautrix.Room for roomID. Nothing is
// read from the database.
func (s sqlStore) LoadRoom(roomID id.RoomID) *mautrix.Room {
	room := mautrix.NewRoom(roomID)
	return room
}
// IsEncrypted reports whether roomID has a row in rooms whose encryption_event
// is not NULL. Unknown rooms and query failures report false.
func (s sqlStore) IsEncrypted(roomID id.RoomID) (isEncrypted bool) {
	row := s.db.QueryRow("SELECT encryption_event NOT NULL FROM rooms WHERE room_id = ?;", roomID)
	if row.Scan(&isEncrypted) != nil {
		return false
	}
	return isEncrypted
}
// GetEncryptionEvent returns the encryption event content stored for roomID
// (as JSON, by SetEncryptionEvent).
//
// It returns nil in three cases: the room has no row, the stored value is
// NULL or empty, or the stored JSON cannot be decoded.
func (s sqlStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
	row := s.db.QueryRow("SELECT encryption_event FROM rooms WHERE room_id = ?;", roomID)
	var data []byte
	if err := row.Scan(&data); err != nil {
		return nil
	}
	// database/sql scans a NULL column into a nil slice; that means no
	// encryption event was recorded for this room.
	if len(data) == 0 {
		return nil
	}
	// json.Unmarshal needs a non-nil pointer to decode into. Decoding into
	// the zero-value (nil) result pointer always fails with
	// InvalidUnmarshalError, so allocate the target first.
	content := &event.EncryptionEventContent{}
	if err := json.Unmarshal(data, content); err != nil {
		return nil
	}
	return content
}
// FindSharedRooms returns the encrypted rooms that userID is recorded as a
// member of in room_members. A room counts as encrypted when its rooms row
// has a non-NULL encryption_event.
//
// The previous query ignored userID and returned every encrypted room with
// more than one member. It now filters on userID's membership.
//
// Rows that fail to scan are skipped. If the query or row iteration fails,
// nil is returned.
func (s sqlStore) FindSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
	rows, err := s.db.Query("SELECT room_members.room_id FROM room_members JOIN rooms ON rooms.room_id = room_members.room_id WHERE room_members.user_id = ? AND rooms.encryption_event NOT NULL;", userID)
	if err != nil {
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		var roomID string
		if err := rows.Scan(&roomID); err != nil {
			continue
		}
		sharedRooms = append(sharedRooms, id.RoomID(roomID))
	}
	if rows.Err() != nil {
		return nil
	}
	return sharedRooms
}
// FindAllSharedRooms returns every room, encrypted or not, that userID is
// recorded as a member of in room_members.
//
// The previous query ignored userID and returned every room with more than
// one member. It now filters on userID's membership. The (room_id, user_id)
// primary key guarantees each room appears at most once.
//
// Rows that fail to scan are skipped. If the query or row iteration fails,
// nil is returned.
func (s sqlStore) FindAllSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
	rows, err := s.db.Query("SELECT room_id FROM room_members WHERE user_id = ?;", userID)
	if err != nil {
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		var roomID string
		if err := rows.Scan(&roomID); err != nil {
			continue
		}
		sharedRooms = append(sharedRooms, id.RoomID(roomID))
	}
	if rows.Err() != nil {
		return nil
	}
	return sharedRooms
}
// GetRoomMembers returns the user IDs recorded in room_members for roomID.
// Rows that fail to scan are skipped. If the query or row iteration fails,
// nil is returned.
func (s sqlStore) GetRoomMembers(roomID id.RoomID) (roomMembers []id.UserID) {
	rows, queryErr := s.db.Query("SELECT user_id FROM room_members WHERE room_id = ?;", roomID)
	if queryErr != nil {
		return nil
	}
	defer rows.Close()

	var members []id.UserID
	for rows.Next() {
		var member string
		if rows.Scan(&member) == nil {
			members = append(members, id.UserID(member))
		}
	}
	if rows.Err() != nil {
		return nil
	}
	return members
}
// SetMembership updates room_members for a membership change:
//   - invite or join adds the (roomID, userID) pair;
//   - leave or ban removes it;
//   - any other membership is ignored.
//
// Database errors are logged rather than returned.
func (s sqlStore) SetMembership(roomID id.RoomID, userID string, membership event.Membership) {
	switch {
	case membership.IsInviteOrJoin():
		if _, err := s.db.Exec("INSERT INTO room_members VALUES(?, ?);", roomID, userID); err != nil {
			log.Print(err)
		}
	case membership.IsLeaveOrBan():
		if _, err := s.db.Exec("DELETE FROM room_members WHERE room_id = ? AND user_id = ?;", roomID, userID); err != nil {
			log.Print(err)
		}
	}
}
// SetEncryptionEvent stores encryptionEvent, JSON-encoded, as the encryption
// event for roomID, inserting into the rooms table.
//
// A nil encryptionEvent stores SQL NULL. IsEncrypted and FindSharedRooms
// treat NULL as "not encrypted". Marshal and database errors are logged
// rather than returned; on a marshal error nothing is written.
func (s sqlStore) SetEncryptionEvent(roomID id.RoomID, encryptionEvent *event.EncryptionEventContent) {
	// Bind an untyped nil when there is no event. How a driver binds a nil
	// []byte is driver-specific (it could become an empty blob, which is
	// NOT NULL); a nil interface value is always bound as SQL NULL.
	var data interface{}
	if encryptionEvent != nil {
		encoded, err := json.Marshal(encryptionEvent)
		if err != nil {
			log.Print(err)
			return
		}
		data = encoded
	}
	if _, err := s.db.Exec("INSERT INTO rooms VALUES(?, ?);", roomID, data); err != nil {
		log.Print(err)
	}
}