189 lines
4.0 KiB
Go
189 lines
4.0 KiB
Go
|
package matrix
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"encoding/json"
|
||
|
"log"
|
||
|
|
||
|
"maunium.net/go/mautrix"
|
||
|
"maunium.net/go/mautrix/event"
|
||
|
"maunium.net/go/mautrix/id"
|
||
|
)
|
||
|
|
||
|
// sqlStore keeps client and room state in a SQL database reached through db.
// NOTE(review): the schema and queries in this file use SQLite-only syntax
// (e.g. ON CONFLICT REPLACE), so db is presumably a SQLite handle — confirm.
type sqlStore struct {
	db *sql.DB // handle shared by every query issued by this store
}
|
||
|
|
||
|
func (s sqlStore) CreateTables() error {
|
||
|
tables := []string{
|
||
|
`CREATE TABLE IF NOT EXISTS filter_ids (
|
||
|
user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
|
||
|
filter_id TEXT NOT NULL
|
||
|
);
|
||
|
`,
|
||
|
`CREATE TABLE IF NOT EXISTS next_batch_tokens (
|
||
|
user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
|
||
|
next_batch_token TEXT NOT NULL
|
||
|
);
|
||
|
`,
|
||
|
`CREATE TABLE IF NOT EXISTS rooms (
|
||
|
room_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
|
||
|
encryption_event TEXT
|
||
|
);
|
||
|
`,
|
||
|
`CREATE TABLE IF NOT EXISTS room_members (
|
||
|
room_id TEXT,
|
||
|
user_id TEXT,
|
||
|
|
||
|
PRIMARY KEY (room_id, user_id)
|
||
|
);
|
||
|
`,
|
||
|
}
|
||
|
|
||
|
tx, err := s.db.Begin()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for _, table := range tables {
|
||
|
_, err := tx.Exec(table)
|
||
|
if err != nil {
|
||
|
if err := tx.Rollback(); err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err = tx.Commit()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) SaveFilterID(userID id.UserID, filterID string) {
|
||
|
_, err := s.db.Exec("INSERT INTO filter_ids VALUES(?, ?);", userID, filterID)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) LoadFilterID(userID id.UserID) (filterID string) {
|
||
|
row := s.db.QueryRow("SELECT filter_id FROM filter_ids WHERE user_id = ?;", userID)
|
||
|
|
||
|
err := row.Scan(&filterID)
|
||
|
if err != nil {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
return filterID
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||
|
_, err := s.db.Exec("INSERT INTO next_batch_tokens VALUES(?, ?);", userID, nextBatchToken)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) LoadNextBatch(userID id.UserID) (nextBatchToken string) {
|
||
|
row := s.db.QueryRow("SELECT next_batch_token FROM next_batch_tokens WHERE user_id = ?;", userID)
|
||
|
|
||
|
err := row.Scan(&nextBatchToken)
|
||
|
if err != nil {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
return nextBatchToken
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) SaveRoom(room *mautrix.Room) {
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) LoadRoom(roomID id.RoomID) *mautrix.Room {
|
||
|
return mautrix.NewRoom(roomID)
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) IsEncrypted(roomID id.RoomID) (isEncrypted bool) {
|
||
|
row := s.db.QueryRow("SELECT encryption_event NOT NULL FROM rooms WHERE room_id = ?;", roomID)
|
||
|
|
||
|
err := row.Scan(&isEncrypted)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return isEncrypted
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) GetEncryptionEvent(roomID id.RoomID) (encryptionEvent *event.EncryptionEventContent) {
|
||
|
row := s.db.QueryRow("SELECT encryption_event FROM rooms WHERE room_id = ?;", roomID)
|
||
|
|
||
|
var data []byte
|
||
|
if err := row.Scan(&data); err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
err := json.Unmarshal(data, encryptionEvent)
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return encryptionEvent
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) FindSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
|
||
|
rows, err := s.db.Query("SELECT rooms.room_id FROM rooms, (SELECT room_id FROM room_members GROUP BY room_id HAVING COUNT(*) > 1) shared_rooms WHERE shared_rooms.room_id = rooms.room_id AND encryption_event NOT NULL;")
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
|
||
|
for rows.Next() {
|
||
|
var roomID string
|
||
|
if err := rows.Scan(&roomID); err != nil {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
sharedRooms = append(sharedRooms, id.RoomID(roomID))
|
||
|
}
|
||
|
if rows.Err() != nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return sharedRooms
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) SetMembership(roomID id.RoomID, userID string, membership event.Membership) {
|
||
|
if membership.IsInviteOrJoin() {
|
||
|
_, err := s.db.Exec("INSERT INTO room_members VALUES(?, ?);", roomID, userID)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
} else if membership.IsLeaveOrBan() {
|
||
|
_, err := s.db.Exec("DELETE FROM room_members WHERE room_id = ? AND user_id = ?;", roomID, userID)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s sqlStore) SetEncryptionEvent(roomID id.RoomID, encryptionEvent *event.EncryptionEventContent) {
|
||
|
var data []byte
|
||
|
if encryptionEvent != nil {
|
||
|
var err error
|
||
|
data, err = json.Marshal(encryptionEvent)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
_, err := s.db.Exec("INSERT INTO rooms VALUES(?, ?);", roomID, data)
|
||
|
if err != nil {
|
||
|
log.Print(err)
|
||
|
}
|
||
|
}
|