package matrix

import (
	"database/sql"
	"encoding/json"
	"errors"
	"log"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// sqlStore persists Matrix sync state (filter IDs, next-batch tokens) and room
// state (encryption events, membership) in a SQL database.
//
// The schema and statements use SQLite syntax (ON CONFLICT REPLACE,
// INSERT OR IGNORE).
//
// NOTE(review): the method set looks like it is meant to satisfy mautrix's
// sync storer and crypto state-store interfaces — confirm against the mautrix
// version in use.
type sqlStore struct {
	db *sql.DB
}

// CreateTables creates every table used by the store if it does not exist yet.
//
// All statements run inside a single transaction: if any of them fails the
// transaction is rolled back and the failing statement's error is returned.
func (s sqlStore) CreateTables() error {
	tables := []string{
		`CREATE TABLE IF NOT EXISTS filter_ids (
			user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			filter_id TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS next_batch_tokens (
			user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			next_batch_token TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS rooms (
			room_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			encryption_event TEXT
		);`,
		`CREATE TABLE IF NOT EXISTS room_members (
			room_id TEXT,
			user_id TEXT,
			PRIMARY KEY (room_id, user_id)
		);`,
	}

	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	for _, table := range tables {
		if _, err := tx.Exec(table); err != nil {
			// The Exec error is the one the caller cares about; a rollback
			// failure is only logged.
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Print(rbErr)
			}
			return err
		}
	}

	return tx.Commit()
}

// SaveFilterID stores filterID for userID. The filter_ids primary key is
// declared ON CONFLICT REPLACE, so an existing row for the user is replaced.
// Failures are logged, not returned.
func (s sqlStore) SaveFilterID(userID id.UserID, filterID string) {
	if _, err := s.db.Exec("INSERT INTO filter_ids VALUES(?, ?);", userID, filterID); err != nil {
		log.Print(err)
	}
}

// LoadFilterID returns the filter ID stored for userID, or "" when none is
// stored or the lookup fails. Errors other than "no such row" are logged.
func (s sqlStore) LoadFilterID(userID id.UserID) (filterID string) {
	row := s.db.QueryRow("SELECT filter_id FROM filter_ids WHERE user_id = ?;", userID)
	if err := row.Scan(&filterID); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Print(err)
		}
		return ""
	}
	return filterID
}

// SaveNextBatch stores nextBatchToken for userID. The next_batch_tokens
// primary key is declared ON CONFLICT REPLACE, so an existing row for the user
// is replaced. Failures are logged, not returned.
func (s sqlStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
	if _, err := s.db.Exec("INSERT INTO next_batch_tokens VALUES(?, ?);", userID, nextBatchToken); err != nil {
		log.Print(err)
	}
}

// LoadNextBatch returns the next-batch token stored for userID, or "" when
// none is stored or the lookup fails. Errors other than "no such row" are
// logged.
func (s sqlStore) LoadNextBatch(userID id.UserID) (nextBatchToken string) {
	row := s.db.QueryRow("SELECT next_batch_token FROM next_batch_tokens WHERE user_id = ?;", userID)
	if err := row.Scan(&nextBatchToken); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Print(err)
		}
		return ""
	}
	return nextBatchToken
}

// SaveRoom is intentionally a no-op: this store does not persist
// *mautrix.Room values.
func (s sqlStore) SaveRoom(room *mautrix.Room) {
}

// LoadRoom returns a freshly constructed room for roomID; nothing is read from
// the database.
func (s sqlStore) LoadRoom(roomID id.RoomID) *mautrix.Room {
	return mautrix.NewRoom(roomID)
}

// IsEncrypted reports whether a non-NULL encryption event is stored for
// roomID. Unknown rooms and lookup failures yield false; errors other than
// "no such row" are logged.
func (s sqlStore) IsEncrypted(roomID id.RoomID) (isEncrypted bool) {
	row := s.db.QueryRow("SELECT encryption_event NOT NULL FROM rooms WHERE room_id = ?;", roomID)
	if err := row.Scan(&isEncrypted); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Print(err)
		}
		return false
	}
	return isEncrypted
}

// GetEncryptionEvent returns the encryption event content stored for roomID,
// or nil when the room is unknown, has no encryption event, or the stored
// value cannot be read or decoded.
func (s sqlStore) GetEncryptionEvent(roomID id.RoomID) (encryptionEvent *event.EncryptionEventContent) {
	row := s.db.QueryRow("SELECT encryption_event FROM rooms WHERE room_id = ?;", roomID)

	var data []byte
	if err := row.Scan(&data); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			log.Print(err)
		}
		return nil
	}
	// A NULL column scans to an empty slice: the room exists but is not
	// encrypted.
	if len(data) == 0 {
		return nil
	}

	// Decode into an allocated value. json.Unmarshal rejects a nil pointer
	// target, so the nil named result cannot be passed directly.
	var content event.EncryptionEventContent
	if err := json.Unmarshal(data, &content); err != nil {
		log.Print(err)
		return nil
	}
	return &content
}

// FindSharedRooms returns the IDs of rooms that have a stored encryption event
// and in which userID is recorded as a member. It returns nil when the query
// fails.
func (s sqlStore) FindSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
	rows, err := s.db.Query(
		"SELECT rooms.room_id FROM rooms "+
			"INNER JOIN room_members ON room_members.room_id = rooms.room_id "+
			"WHERE room_members.user_id = ? AND rooms.encryption_event IS NOT NULL;",
		userID,
	)
	if err != nil {
		log.Print(err)
		return nil
	}
	defer rows.Close()

	for rows.Next() {
		var roomID string
		if err := rows.Scan(&roomID); err != nil {
			// Skip the unreadable row but keep the rest of the result.
			log.Print(err)
			continue
		}
		sharedRooms = append(sharedRooms, id.RoomID(roomID))
	}
	if err := rows.Err(); err != nil {
		log.Print(err)
		return nil
	}
	return sharedRooms
}

// SetMembership records or removes userID's membership of roomID.
//
// Invite and join add a row to room_members; leave and ban remove it. Any
// other membership value is ignored. Failures are logged, not returned.
func (s sqlStore) SetMembership(roomID id.RoomID, userID string, membership event.Membership) {
	switch {
	case membership.IsInviteOrJoin():
		// OR IGNORE: a user who was invited and then joins (or any repeated
		// event) already has a row; that is not an error.
		if _, err := s.db.Exec("INSERT OR IGNORE INTO room_members VALUES(?, ?);", roomID, userID); err != nil {
			log.Print(err)
		}
	case membership.IsLeaveOrBan():
		if _, err := s.db.Exec("DELETE FROM room_members WHERE room_id = ? AND user_id = ?;", roomID, userID); err != nil {
			log.Print(err)
		}
	}
}

// SetEncryptionEvent stores encryptionEvent, JSON-encoded, as the encryption
// event of roomID. A nil encryptionEvent stores NULL. The rooms primary key is
// declared ON CONFLICT REPLACE, so an existing row for the room is replaced.
// Failures are logged, not returned.
func (s sqlStore) SetEncryptionEvent(roomID id.RoomID, encryptionEvent *event.EncryptionEventContent) {
	// data stays nil for a nil event, which the driver writes as NULL.
	var data []byte
	if encryptionEvent != nil {
		encoded, err := json.Marshal(encryptionEvent)
		if err != nil {
			log.Print(err)
			return
		}
		data = encoded
	}

	if _, err := s.db.Exec("INSERT INTO rooms VALUES(?, ?);", roomID, data); err != nil {
		log.Print(err)
	}
}