diff --git a/internal/config/config.go b/internal/config/config.go index 5ee60cc..534d8e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,8 @@ type DatabaseConfig struct { type MatrixConfig struct { AllowedRooms []string `json:"allowed_rooms"` - DisplayName string `json:"display_name"` + DeviceName string `json:"device_name"` + Greeting string LogLevel uint `json:"log_level"` HomeserverURL string `json:"homeserver_url"` UserIdentifier string `json:"user_identifier"` diff --git a/internal/matrix/client.go b/internal/matrix/client.go index a4fd52e..8c9aa1d 100644 --- a/internal/matrix/client.go +++ b/internal/matrix/client.go @@ -19,12 +19,10 @@ type Client struct { client *mautrix.Client config *config.MatrixConfig db *sql.DB - deviceID id.DeviceID greetedRooms []id.RoomID olmMachine *crypto.OlmMachine store *sqlStore - syncer *mautrix.DefaultSyncer - userID id.UserID + syncer mautrix.Syncer } func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) { @@ -44,36 +42,43 @@ func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) { return nil, err } - userID, err := makeUserID(config.UserIdentifier, config.HomeserverURL) - if err != nil { - return nil, err + c := &Client{ + client: client, + config: config, + db: db, + store: store, + syncer: syncer, } - deviceID := loadDeviceID(db, userID) + syncer.OnEventType(event.StateMember, c.handleMemberEvent) - return &Client{client, config, db, deviceID, nil, nil, store, syncer, userID}, nil + return c, nil } func (c *Client) Login() error { - c.syncer.OnEventType(event.StateMember, c.handleMemberEvent) + userID, err := makeUserID(c.config.UserIdentifier, c.config.HomeserverURL) + if err != nil { + return err + } - resp, err := c.client.Login(&mautrix.ReqLogin{ + deviceID := c.loadDeviceID(userID) + + _, err := c.client.Login(&mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: c.config.UserIdentifier, }, Password: c.config.Password, - DeviceID: c.deviceID, - InitialDeviceDisplayName: c.config.DisplayName, + DeviceID: deviceID, + InitialDeviceDisplayName: c.config.DeviceName, StoreCredentials: true, }) if err != nil { return err } - c.deviceID = resp.DeviceID - log.Print("device ID: ", c.deviceID) + log.Print("device ID: ", c.client.DeviceID) return nil } @@ -82,8 +87,8 @@ func (c *Client) Encrypt() error { sqlCryptoStore := crypto.NewSQLCryptoStore( c.db, database.DBDriverName, - c.userID.String(), - c.deviceID, + c.client.UserID.String(), + c.client.DeviceID, []byte(c.config.PickleKey), logger{}, ) @@ -107,7 +112,7 @@ func (c *Client) Encrypt() error { return nil } -func (c *Client) Sync(ctx context.Context, wg *sync.WaitGroup) { +func (c *Client) Sync(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() @@ -119,12 +124,13 @@ func (c *Client) Sync(ctx context.Context, wg *sync.WaitGroup) { err := c.client.SyncWithContext(ctx) if err != nil && err != ctx.Err() { log.Print(err) + cancel() } } } } -func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) { +func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) error { content := event.Content{Parsed: message} if c.store.IsEncrypted(roomID) { @@ -132,32 +138,32 @@ func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) { if err == crypto.NoGroupSession || err == crypto.SessionExpired || err == crypto.SessionNotShared { err = c.olmMachine.ShareGroupSession(roomID, c.store.GetRoomMembers(roomID)) if err != nil { - log.Print(err) - return + return err } encrypted, err = c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content) } if err != nil { - log.Print(err) - return + return err } _, err = c.client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) if err != nil { - log.Print(err) + return err } } else { _, err := c.client.SendMessageEvent(roomID, event.EventMessage, &content) if err != nil { - log.Print(err) + return err } } + + return nil } -func (c *Client) Broadcast(message *event.MessageEventContent) { - for _, roomID := range c.store.FindAllSharedRooms(c.userID) { +func (c *Client) Broadcast(message *event.MessageEventContent) (success bool) { + for _, roomID := range c.store.FindAllSharedRooms(c.client.UserID) { allowed := false for _, room := range c.config.AllowedRooms { if room == roomID.String() { @@ -170,7 +176,12 @@ func (c *Client) Broadcast(message *event.MessageEventContent) { continue } - c.Send(roomID, message) + err := c.Send(roomID, message) + if err != nil { + log.Print(err) + } else { + success = true + } } } @@ -181,7 +192,7 @@ func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) c.store.SetMembership(evt.RoomID, evt.GetStateKey(), evt.Content.AsMember().Membership) - if evt.GetStateKey() == c.userID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { + if evt.GetStateKey() == c.client.UserID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { allowed := false for _, room := range c.config.AllowedRooms { if room == evt.RoomID.String() { @@ -204,9 +215,14 @@ func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) message := event.MessageEventContent{ MsgType: event.MsgText, - Body: "Ich heiße Marvin \U0001f41e und werde ab jetzt neue Ticketbestellungen verkünden.", + Body: c.config.Greeting, + } + + err = c.Send(evt.RoomID, &message) + if err != nil { + log.Print(err) + return } - c.Send(evt.RoomID, &message) c.greetedRooms = append(c.greetedRooms, evt.RoomID) } else { @@ -222,8 +238,8 @@ func (c *Client) handleEncryptionEvent(source mautrix.EventSource, evt *event.Ev c.store.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption()) } -func loadDeviceID(db *sql.DB, accountID id.UserID) (deviceID id.DeviceID) { - row := db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID) +func (c *Client) loadDeviceID(accountID id.UserID) (deviceID id.DeviceID) { + row := c.db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID) err := row.Scan(&deviceID) if err != nil { diff --git a/internal/pretix/order_placed.go b/internal/pretix/order_placed.go index 663ca51..9a3c7e1 100644 --- a/internal/pretix/order_placed.go +++ b/internal/pretix/order_placed.go @@ -152,7 +152,10 @@ func (s Server) orderPlaced(w http.ResponseWriter, r *http.Request) { Body: fmt.Sprint("\U0001f389 Es ", verb, free, conjunction, paid, noun, total, "bestellt."), } - s.matrix.Broadcast(&message) + success := s.matrix.Broadcast(&message) + if !success { + writeStatus(w, http.StatusInternalServerError) + } } } diff --git a/main.go b/main.go index 01a1b43..a569965 100644 --- a/main.go +++ b/main.go @@ -58,7 +58,7 @@ func main() { stop() }() - go client.Sync(ctx, &wg) + go client.Sync(ctx, stop, &wg) server := pretix.NewServer(&config.Server, db, client)