Compare commits

..

No commits in common. "bdb42f9699f307a3fe12cd8ec3e3b330ddc2f9a2" and "4fb9fa20603b606a68201c128c244b91c26e388b" have entirely different histories.

4 changed files with 34 additions and 56 deletions

View File

@ -17,8 +17,7 @@ type DatabaseConfig struct {
type MatrixConfig struct { type MatrixConfig struct {
AllowedRooms []string `json:"allowed_rooms"` AllowedRooms []string `json:"allowed_rooms"`
DeviceName string `json:"device_name"` DisplayName string `json:"display_name"`
Greeting string
LogLevel uint `json:"log_level"` LogLevel uint `json:"log_level"`
HomeserverURL string `json:"homeserver_url"` HomeserverURL string `json:"homeserver_url"`
UserIdentifier string `json:"user_identifier"` UserIdentifier string `json:"user_identifier"`

View File

@ -19,10 +19,12 @@ type Client struct {
client *mautrix.Client client *mautrix.Client
config *config.MatrixConfig config *config.MatrixConfig
db *sql.DB db *sql.DB
deviceID id.DeviceID
greetedRooms []id.RoomID greetedRooms []id.RoomID
olmMachine *crypto.OlmMachine olmMachine *crypto.OlmMachine
store *sqlStore store *sqlStore
syncer *mautrix.DefaultSyncer syncer *mautrix.DefaultSyncer
userID id.UserID
} }
func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) { func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) {
@ -42,43 +44,36 @@ func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) {
return nil, err return nil, err
} }
c := &Client{ userID, err := makeUserID(config.UserIdentifier, config.HomeserverURL)
client: client, if err != nil {
config: config, return nil, err
db: db,
store: store,
syncer: syncer,
} }
syncer.OnEventType(event.StateMember, c.handleMemberEvent) deviceID := loadDeviceID(db, userID)
return c, nil return &Client{client, config, db, deviceID, nil, nil, store, syncer, userID}, nil
} }
func (c *Client) Login() error { func (c *Client) Login() error {
userID, err := makeUserID(c.config.UserIdentifier, c.config.HomeserverURL) c.syncer.OnEventType(event.StateMember, c.handleMemberEvent)
if err != nil {
return err
}
deviceID := c.loadDeviceID(userID) resp, err := c.client.Login(&mautrix.ReqLogin{
_, err = c.client.Login(&mautrix.ReqLogin{
Type: mautrix.AuthTypePassword, Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{ Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser, Type: mautrix.IdentifierTypeUser,
User: c.config.UserIdentifier, User: c.config.UserIdentifier,
}, },
Password: c.config.Password, Password: c.config.Password,
DeviceID: deviceID, DeviceID: c.deviceID,
InitialDeviceDisplayName: c.config.DeviceName, InitialDeviceDisplayName: c.config.DisplayName,
StoreCredentials: true, StoreCredentials: true,
}) })
if err != nil { if err != nil {
return err return err
} }
log.Print("device ID: ", c.client.DeviceID) c.deviceID = resp.DeviceID
log.Print("device ID: ", c.deviceID)
return nil return nil
} }
@ -87,8 +82,8 @@ func (c *Client) Encrypt() error {
sqlCryptoStore := crypto.NewSQLCryptoStore( sqlCryptoStore := crypto.NewSQLCryptoStore(
c.db, c.db,
database.DBDriverName, database.DBDriverName,
c.client.UserID.String(), c.userID.String(),
c.client.DeviceID, c.deviceID,
[]byte(c.config.PickleKey), []byte(c.config.PickleKey),
logger{}, logger{},
) )
@ -112,7 +107,7 @@ func (c *Client) Encrypt() error {
return nil return nil
} }
func (c *Client) Sync(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup) { func (c *Client) Sync(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1) wg.Add(1)
defer wg.Done() defer wg.Done()
@ -124,13 +119,12 @@ func (c *Client) Sync(ctx context.Context, cancel context.CancelFunc, wg *sync.W
err := c.client.SyncWithContext(ctx) err := c.client.SyncWithContext(ctx)
if err != nil && err != ctx.Err() { if err != nil && err != ctx.Err() {
log.Print(err) log.Print(err)
cancel()
} }
} }
} }
} }
func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) error { func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) {
content := event.Content{Parsed: message} content := event.Content{Parsed: message}
if c.store.IsEncrypted(roomID) { if c.store.IsEncrypted(roomID) {
@ -138,32 +132,32 @@ func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) erro
if err == crypto.NoGroupSession || err == crypto.SessionExpired || err == crypto.SessionNotShared { if err == crypto.NoGroupSession || err == crypto.SessionExpired || err == crypto.SessionNotShared {
err = c.olmMachine.ShareGroupSession(roomID, c.store.GetRoomMembers(roomID)) err = c.olmMachine.ShareGroupSession(roomID, c.store.GetRoomMembers(roomID))
if err != nil { if err != nil {
return err log.Print(err)
return
} }
encrypted, err = c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content) encrypted, err = c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content)
} }
if err != nil { if err != nil {
return err log.Print(err)
return
} }
_, err = c.client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) _, err = c.client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
if err != nil { if err != nil {
return err log.Print(err)
} }
} else { } else {
_, err := c.client.SendMessageEvent(roomID, event.EventMessage, &content) _, err := c.client.SendMessageEvent(roomID, event.EventMessage, &content)
if err != nil { if err != nil {
return err log.Print(err)
} }
} }
return nil
} }
func (c *Client) Broadcast(message *event.MessageEventContent) (success bool) { func (c *Client) Broadcast(message *event.MessageEventContent) {
for _, roomID := range c.store.FindAllSharedRooms(c.client.UserID) { for _, roomID := range c.store.FindAllSharedRooms(c.userID) {
allowed := false allowed := false
for _, room := range c.config.AllowedRooms { for _, room := range c.config.AllowedRooms {
if room == roomID.String() { if room == roomID.String() {
@ -176,15 +170,8 @@ func (c *Client) Broadcast(message *event.MessageEventContent) (success bool) {
continue continue
} }
err := c.Send(roomID, message) c.Send(roomID, message)
if err != nil {
log.Print(err)
} else {
success = true
} }
}
return success
} }
func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) { func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) {
@ -194,7 +181,7 @@ func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event)
c.store.SetMembership(evt.RoomID, evt.GetStateKey(), evt.Content.AsMember().Membership) c.store.SetMembership(evt.RoomID, evt.GetStateKey(), evt.Content.AsMember().Membership)
if evt.GetStateKey() == c.client.UserID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { if evt.GetStateKey() == c.userID.String() && evt.Content.AsMember().Membership == event.MembershipInvite {
allowed := false allowed := false
for _, room := range c.config.AllowedRooms { for _, room := range c.config.AllowedRooms {
if room == evt.RoomID.String() { if room == evt.RoomID.String() {
@ -217,14 +204,9 @@ func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event)
message := event.MessageEventContent{ message := event.MessageEventContent{
MsgType: event.MsgText, MsgType: event.MsgText,
Body: c.config.Greeting, Body: "Ich heiße Marvin \U0001f41e und werde ab jetzt neue Ticketbestellungen verkünden.",
}
err = c.Send(evt.RoomID, &message)
if err != nil {
log.Print(err)
return
} }
c.Send(evt.RoomID, &message)
c.greetedRooms = append(c.greetedRooms, evt.RoomID) c.greetedRooms = append(c.greetedRooms, evt.RoomID)
} else { } else {
@ -240,8 +222,8 @@ func (c *Client) handleEncryptionEvent(source mautrix.EventSource, evt *event.Ev
c.store.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption()) c.store.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption())
} }
func (c *Client) loadDeviceID(accountID id.UserID) (deviceID id.DeviceID) { func loadDeviceID(db *sql.DB, accountID id.UserID) (deviceID id.DeviceID) {
row := c.db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID) row := db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID)
err := row.Scan(&deviceID) err := row.Scan(&deviceID)
if err != nil { if err != nil {

View File

@ -152,10 +152,7 @@ func (s Server) orderPlaced(w http.ResponseWriter, r *http.Request) {
Body: fmt.Sprint("\U0001f389 Es ", verb, free, conjunction, paid, noun, total, "bestellt."), Body: fmt.Sprint("\U0001f389 Es ", verb, free, conjunction, paid, noun, total, "bestellt."),
} }
success := s.matrix.Broadcast(&message) s.matrix.Broadcast(&message)
if !success {
writeStatus(w, http.StatusInternalServerError)
}
} }
} }

View File

@ -58,7 +58,7 @@ func main() {
stop() stop()
}() }()
go client.Sync(ctx, stop, &wg) go client.Sync(ctx, &wg)
server := pretix.NewServer(&config.Server, db, client) server := pretix.NewServer(&config.Server, db, client)