2022-07-24 02:40:44 +02:00
package matrix
import (
"database/sql"
"encoding/json"
"log"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type sqlStore struct {
db * sql . DB
}
// CreateTables creates the store's schema (filter_ids, next_batch_tokens,
// rooms and room_members) if it does not exist yet. All statements run in a
// single transaction, which is rolled back on the first failing statement;
// the Exec error is returned and a failed rollback is only logged.
//
// The ON CONFLICT clauses on the primary keys are what let the plain
// "INSERT INTO ... VALUES" statements used elsewhere in this store behave as
// upserts (REPLACE) or as insert-once (IGNORE for room_members).
func (s sqlStore) CreateTables() error {
	tables := []string{
		`CREATE TABLE IF NOT EXISTS filter_ids (
			user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			filter_id TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS next_batch_tokens (
			user_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			next_batch_token TEXT NOT NULL
		);`,
		`CREATE TABLE IF NOT EXISTS rooms (
			room_id TEXT PRIMARY KEY ON CONFLICT REPLACE,
			encryption_event TEXT
		);`,
		`CREATE TABLE IF NOT EXISTS room_members (
			room_id TEXT,
			user_id TEXT,
			PRIMARY KEY (room_id, user_id) ON CONFLICT IGNORE
		);`,
	}

	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	for _, table := range tables {
		if _, err := tx.Exec(table); err != nil {
			// Report the rollback failure but return the original Exec error,
			// which is the one the caller can act on.
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Print(rbErr)
			}
			return err
		}
	}
	return tx.Commit()
}
// SaveFilterID stores filterID as the filter for userID. Because filter_ids
// is keyed on user_id with ON CONFLICT REPLACE, a later call overwrites the
// earlier value. Errors are logged, not returned.
func (s sqlStore) SaveFilterID(userID id.UserID, filterID string) {
	if _, execErr := s.db.Exec("INSERT INTO filter_ids VALUES(?, ?);", userID, filterID); execErr != nil {
		log.Print(execErr)
	}
}
// LoadFilterID returns the filter ID saved for userID. It returns "" when no
// row exists or the lookup fails for any other reason.
func (s sqlStore) LoadFilterID(userID id.UserID) (filterID string) {
	var stored string
	row := s.db.QueryRow("SELECT filter_id FROM filter_ids WHERE user_id = ?;", userID)
	if scanErr := row.Scan(&stored); scanErr != nil {
		return ""
	}
	return stored
}
// SaveNextBatch records nextBatchToken as the sync position for userID,
// replacing any previous token (next_batch_tokens.user_id is declared
// ON CONFLICT REPLACE). Errors are logged, not returned.
func (s sqlStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
	const insert = "INSERT INTO next_batch_tokens VALUES(?, ?);"
	if _, execErr := s.db.Exec(insert, userID, nextBatchToken); execErr != nil {
		log.Print(execErr)
	}
}
// LoadNextBatch returns the last saved sync token for userID. It returns ""
// when none is stored or the query fails.
func (s sqlStore) LoadNextBatch(userID id.UserID) (nextBatchToken string) {
	const query = "SELECT next_batch_token FROM next_batch_tokens WHERE user_id = ?;"
	var token string
	if scanErr := s.db.QueryRow(query, userID).Scan(&token); scanErr != nil {
		return ""
	}
	return token
}
// SaveRoom is a no-op. This store does not persist *mautrix.Room values;
// LoadRoom always hands back a freshly constructed room instead.
func (s sqlStore) SaveRoom(room *mautrix.Room) { /* nothing to persist */ }
// LoadRoom does not read from the database. It returns a new, empty room
// for roomID built by mautrix.NewRoom.
func (s sqlStore) LoadRoom(roomID id.RoomID) *mautrix.Room {
	fresh := mautrix.NewRoom(roomID)
	return fresh
}
// IsEncrypted reports whether roomID has a row in rooms whose
// encryption_event column is non-NULL. Unknown rooms and query errors
// both yield false.
func (s sqlStore) IsEncrypted(roomID id.RoomID) (isEncrypted bool) {
	var hasEvent bool
	row := s.db.QueryRow("SELECT encryption_event NOT NULL FROM rooms WHERE room_id = ?;", roomID)
	if scanErr := row.Scan(&hasEvent); scanErr != nil {
		return false
	}
	return hasEvent
}
// GetEncryptionEvent returns the stored encryption event content for roomID.
// It returns nil when the room is unknown, when its encryption_event column
// is NULL/empty (the room is not encrypted), or when the stored JSON cannot
// be decoded; decode failures are logged.
func (s sqlStore) GetEncryptionEvent(roomID id.RoomID) (encryptionEvent *event.EncryptionEventContent) {
	row := s.db.QueryRow("SELECT encryption_event FROM rooms WHERE room_id = ?;", roomID)
	var data []byte
	if err := row.Scan(&data); err != nil {
		return nil
	}
	// A NULL column scans into a nil slice: the room is known but unencrypted.
	if len(data) == 0 {
		return nil
	}
	// json.Unmarshal needs a non-nil pointer to decode into; the named result
	// starts out nil, so allocate the target first.
	encryptionEvent = &event.EncryptionEventContent{}
	if err := json.Unmarshal(data, encryptionEvent); err != nil {
		log.Print(err)
		return nil
	}
	return encryptionEvent
}
// FindSharedRooms returns the encrypted rooms (non-NULL encryption_event)
// that userID is a recorded member of and that have at least one other
// recorded member. It returns nil if the query or row iteration fails;
// rows that fail to scan are skipped.
func (s sqlStore) FindSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
	// Previously userID was ignored and every encrypted multi-member room was
	// returned; restrict the result to rooms the given user is actually in.
	rows, err := s.db.Query(`
		SELECT rooms.room_id
		FROM rooms
		JOIN room_members ON room_members.room_id = rooms.room_id
		WHERE room_members.user_id = ?
			AND rooms.encryption_event NOT NULL
			AND rooms.room_id IN (
				SELECT room_id FROM room_members GROUP BY room_id HAVING COUNT(*) > 1
			);`, userID)
	if err != nil {
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		var roomID string
		if err := rows.Scan(&roomID); err != nil {
			continue
		}
		sharedRooms = append(sharedRooms, id.RoomID(roomID))
	}
	if rows.Err() != nil {
		return nil
	}
	return sharedRooms
}
2022-07-24 23:19:52 +02:00
// FindAllSharedRooms returns every room (encrypted or not) that has more than
// one recorded member. It returns nil if the query or row iteration fails;
// rows that fail to scan are skipped.
//
// NOTE(review): userID is not used by the query, so the result does not
// depend on which user is asked about — confirm whether it should be
// filtered by userID's membership.
func (s sqlStore) FindAllSharedRooms(userID id.UserID) (sharedRooms []id.RoomID) {
	rows, err := s.db.Query("SELECT room_id FROM room_members GROUP BY room_id HAVING COUNT(*) > 1;")
	if err != nil {
		return nil
	}
	defer rows.Close()

	var found []id.RoomID
	for rows.Next() {
		var raw string
		if rows.Scan(&raw) == nil {
			found = append(found, id.RoomID(raw))
		}
	}
	if rows.Err() != nil {
		return nil
	}
	return found
}
// GetRoomMembers lists the user IDs recorded in room_members for roomID.
// It returns nil if the query or row iteration fails; rows that fail to
// scan are skipped.
func (s sqlStore) GetRoomMembers(roomID id.RoomID) (roomMembers []id.UserID) {
	rows, err := s.db.Query("SELECT user_id FROM room_members WHERE room_id = ?;", roomID)
	if err != nil {
		return nil
	}
	defer rows.Close()

	var members []id.UserID
	for rows.Next() {
		var member string
		if scanErr := rows.Scan(&member); scanErr == nil {
			members = append(members, id.UserID(member))
		}
	}
	if rows.Err() == nil {
		return members
	}
	return nil
}
2022-07-24 02:40:44 +02:00
// SetMembership records userID as a member of roomID for invite/join
// memberships. Repeat inserts are ignored because the (room_id, user_id)
// key is ON CONFLICT IGNORE. For leave/ban memberships it removes the row.
// Any other membership value is ignored. Errors are logged, not returned.
func (s sqlStore) SetMembership(roomID id.RoomID, userID string, membership event.Membership) {
	var err error
	switch {
	case membership.IsInviteOrJoin():
		_, err = s.db.Exec("INSERT INTO room_members VALUES(?, ?);", roomID, userID)
	case membership.IsLeaveOrBan():
		_, err = s.db.Exec("DELETE FROM room_members WHERE room_id = ? AND user_id = ?;", roomID, userID)
	default:
		return
	}
	if err != nil {
		log.Print(err)
	}
}
// SetEncryptionEvent stores encryptionEvent, JSON-encoded, as the encryption
// state of roomID, replacing any earlier row (rooms.room_id is declared
// ON CONFLICT REPLACE). A nil encryptionEvent stores SQL NULL, which
// IsEncrypted treats as "not encrypted". Errors are logged, not returned.
func (s sqlStore) SetEncryptionEvent(roomID id.RoomID, encryptionEvent *event.EncryptionEventContent) {
	// Left as an untyped nil so every driver binds SQL NULL. A nil []byte is
	// bound as NULL by some drivers but as an empty blob by others, and an
	// empty blob would pass IsEncrypted's "NOT NULL" check.
	var data interface{}
	if encryptionEvent != nil {
		encoded, err := json.Marshal(encryptionEvent)
		if err != nil {
			log.Print(err)
			return
		}
		data = encoded
	}
	if _, err := s.db.Exec("INSERT INTO rooms VALUES(?, ?);", roomID, data); err != nil {
		log.Print(err)
	}
}