package matrix import ( "context" "database/sql" "log" "sync" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type Client struct { client *mautrix.Client config *Config db *sql.DB greetedRooms []id.RoomID olmMachine *crypto.OlmMachine store *sqlStore syncer *mautrix.DefaultSyncer } type Config struct { AllowedRooms []string `json:"allowed_rooms"` DeviceName string `json:"device_name"` Greeting string LogLevel uint `json:"log_level"` HomeserverURL string `json:"homeserver_url"` UserIdentifier string `json:"user_identifier"` Password string PickleKey string `json:"pickle_key"` } func NewClient(config *Config, db *sql.DB) (*Client, error) { client, err := mautrix.NewClient(config.HomeserverURL, "", "") if err != nil { return nil, err } syncer := mautrix.NewDefaultSyncer() client.Syncer = syncer store := &sqlStore{db} client.Store = store err = store.CreateTables() if err != nil { return nil, err } c := &Client{ client: client, config: config, db: db, store: store, syncer: syncer, } syncer.OnEventType(event.StateMember, c.handleMemberEvent) return c, nil } func (c *Client) Login() error { userID, err := makeUserID(c.config.UserIdentifier, c.config.HomeserverURL) if err != nil { return err } deviceID := c.loadDeviceID(userID) _, err = c.client.Login(&mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: c.config.UserIdentifier, }, Password: c.config.Password, DeviceID: deviceID, InitialDeviceDisplayName: c.config.DeviceName, StoreCredentials: true, }) if err != nil { return err } log.Print("device ID: ", c.client.DeviceID) return nil } func (c *Client) Encrypt() error { sqlCryptoStore := crypto.NewSQLCryptoStore( c.db, "sqlite3", c.client.UserID.String(), c.client.DeviceID, []byte(c.config.PickleKey), logger{ Level: c.config.LogLevel, }, ) err := sqlCryptoStore.CreateTables() if err != nil { return err } c.olmMachine = crypto.NewOlmMachine(c.client, &logger{ Level: c.config.LogLevel, }, sqlCryptoStore, c.store) err = c.olmMachine.Load() if err != nil { return err } c.syncer.OnSync(c.olmMachine.ProcessSyncResponse) c.syncer.OnEventType(event.StateEncryption, c.handleEncryptionEvent) return nil } func (c *Client) Sync(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() for { select { case <-ctx.Done(): return default: err := c.client.SyncWithContext(ctx) if err != nil && err != ctx.Err() { log.Print(err) cancel() } } } } func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) error { content := event.Content{Parsed: message} if c.store.IsEncrypted(roomID) { encrypted, err := c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content) if err == crypto.NoGroupSession || err == crypto.SessionExpired || err == crypto.SessionNotShared { err = c.olmMachine.ShareGroupSession(roomID, c.store.GetRoomMembers(roomID)) if err != nil { return err } encrypted, err = c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content) } if err != nil { return err } _, err = c.client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) if err != nil { return err } } else { _, err := c.client.SendMessageEvent(roomID, event.EventMessage, &content) if err != nil { return err } } return nil } func (c *Client) Broadcast(message *event.MessageEventContent) (success bool) { for _, roomID := range c.store.FindAllSharedRooms(c.client.UserID) { allowed := false for _, room := range c.config.AllowedRooms { if room == roomID.String() { allowed = true break } } if !allowed { continue } err := c.Send(roomID, message) if err != nil { log.Print(err) } else { success = true } } return success } func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) { if c.olmMachine != nil { c.olmMachine.HandleMemberEvent(evt) } c.store.SetMembership(evt.RoomID, evt.GetStateKey(), evt.Content.AsMember().Membership) if evt.GetStateKey() == c.client.UserID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { allowed := false for _, room := range c.config.AllowedRooms { if room == evt.RoomID.String() { allowed = true break } } if allowed { _, err := c.client.JoinRoomByID(evt.RoomID) if err != nil { log.Print(err) } for _, room := range c.greetedRooms { if room == evt.RoomID { return } } message := event.MessageEventContent{ MsgType: event.MsgText, Body: c.config.Greeting, } err = c.Send(evt.RoomID, &message) if err != nil { log.Print(err) return } c.greetedRooms = append(c.greetedRooms, evt.RoomID) } else { _, err := c.client.LeaveRoom(evt.RoomID) if err != nil { log.Print(err) } } } } func (c *Client) handleEncryptionEvent(source mautrix.EventSource, evt *event.Event) { c.store.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption()) } func (c *Client) loadDeviceID(accountID id.UserID) (deviceID id.DeviceID) { row := c.db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID) err := row.Scan(&deviceID) if err != nil { return "" } return deviceID }