235 lines
5.1 KiB
Go
235 lines
5.1 KiB
Go
package matrix
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"log"
	"sync"

	"git.luj0ga.de/franconian/matrix-pretix/internal/config"
	"git.luj0ga.de/franconian/matrix-pretix/internal/database"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/crypto"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)
|
|
|
|
type Client struct {
|
|
client *mautrix.Client
|
|
config *config.MatrixConfig
|
|
db *sql.DB
|
|
deviceID id.DeviceID
|
|
greetedRooms []id.RoomID
|
|
olmMachine *crypto.OlmMachine
|
|
store *sqlStore
|
|
syncer *mautrix.DefaultSyncer
|
|
userID id.UserID
|
|
}
|
|
|
|
func NewClient(config *config.MatrixConfig, db *sql.DB) (*Client, error) {
|
|
client, err := mautrix.NewClient(config.HomeserverURL, "", "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
syncer := mautrix.NewDefaultSyncer()
|
|
client.Syncer = syncer
|
|
|
|
store := &sqlStore{db}
|
|
client.Store = store
|
|
|
|
err = store.CreateTables()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userID, err := makeUserID(config.UserIdentifier, config.HomeserverURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
deviceID := loadDeviceID(db, userID)
|
|
|
|
return &Client{client, config, db, deviceID, nil, nil, store, syncer, userID}, nil
|
|
}
|
|
|
|
func (c *Client) Login() error {
|
|
c.syncer.OnEventType(event.StateMember, c.handleMemberEvent)
|
|
|
|
resp, err := c.client.Login(&mautrix.ReqLogin{
|
|
Type: mautrix.AuthTypePassword,
|
|
Identifier: mautrix.UserIdentifier{
|
|
Type: mautrix.IdentifierTypeUser,
|
|
User: c.config.UserIdentifier,
|
|
},
|
|
Password: c.config.Password,
|
|
DeviceID: c.deviceID,
|
|
InitialDeviceDisplayName: c.config.DisplayName,
|
|
StoreCredentials: true,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.deviceID = resp.DeviceID
|
|
log.Print("device ID: ", c.deviceID)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) Encrypt() error {
|
|
sqlCryptoStore := crypto.NewSQLCryptoStore(
|
|
c.db,
|
|
database.DBDriverName,
|
|
c.userID.String(),
|
|
c.deviceID,
|
|
[]byte(c.config.PickleKey),
|
|
logger{},
|
|
)
|
|
|
|
err := sqlCryptoStore.CreateTables()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.olmMachine = crypto.NewOlmMachine(c.client, &logger{}, sqlCryptoStore, c.store)
|
|
|
|
err = c.olmMachine.Load()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.syncer.OnSync(c.olmMachine.ProcessSyncResponse)
|
|
|
|
c.syncer.OnEventType(event.StateEncryption, c.handleEncryptionEvent)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) Sync(ctx context.Context, wg *sync.WaitGroup) {
|
|
wg.Add(1)
|
|
defer wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
err := c.client.SyncWithContext(ctx)
|
|
if err != nil && err != ctx.Err() {
|
|
log.Print(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) Send(roomID id.RoomID, message *event.MessageEventContent) {
|
|
content := event.Content{Parsed: message}
|
|
|
|
if c.store.IsEncrypted(roomID) {
|
|
encrypted, err := c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content)
|
|
if err == crypto.NoGroupSession || err == crypto.SessionExpired || err == crypto.SessionNotShared {
|
|
err = c.olmMachine.ShareGroupSession(roomID, c.store.GetRoomMembers(roomID))
|
|
if err != nil {
|
|
log.Print(err)
|
|
return
|
|
}
|
|
|
|
encrypted, err = c.olmMachine.EncryptMegolmEvent(roomID, event.EventMessage, &content)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Print(err)
|
|
return
|
|
}
|
|
|
|
_, err = c.client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
|
if err != nil {
|
|
log.Print(err)
|
|
}
|
|
} else {
|
|
_, err := c.client.SendMessageEvent(roomID, event.EventMessage, &content)
|
|
if err != nil {
|
|
log.Print(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) Broadcast(message *event.MessageEventContent) {
|
|
for _, roomID := range c.store.FindAllSharedRooms(c.userID) {
|
|
allowed := false
|
|
for _, room := range c.config.AllowedRooms {
|
|
if room == roomID.String() {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
continue
|
|
}
|
|
|
|
c.Send(roomID, message)
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleMemberEvent(source mautrix.EventSource, evt *event.Event) {
|
|
if c.olmMachine != nil {
|
|
c.olmMachine.HandleMemberEvent(evt)
|
|
}
|
|
|
|
c.store.SetMembership(evt.RoomID, evt.GetStateKey(), evt.Content.AsMember().Membership)
|
|
|
|
if evt.GetStateKey() == c.userID.String() && evt.Content.AsMember().Membership == event.MembershipInvite {
|
|
allowed := false
|
|
for _, room := range c.config.AllowedRooms {
|
|
if room == evt.RoomID.String() {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if allowed {
|
|
_, err := c.client.JoinRoomByID(evt.RoomID)
|
|
if err != nil {
|
|
log.Print(err)
|
|
}
|
|
|
|
for _, room := range c.greetedRooms {
|
|
if room == evt.RoomID {
|
|
return
|
|
}
|
|
}
|
|
|
|
message := event.MessageEventContent{
|
|
MsgType: event.MsgText,
|
|
Body: "Ich heiße Marvin \U0001f41e und werde ab jetzt neue Ticketbestellungen verkünden.",
|
|
}
|
|
c.Send(evt.RoomID, &message)
|
|
|
|
c.greetedRooms = append(c.greetedRooms, evt.RoomID)
|
|
} else {
|
|
_, err := c.client.LeaveRoom(evt.RoomID)
|
|
if err != nil {
|
|
log.Print(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleEncryptionEvent(source mautrix.EventSource, evt *event.Event) {
|
|
c.store.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption())
|
|
}
|
|
|
|
func loadDeviceID(db *sql.DB, accountID id.UserID) (deviceID id.DeviceID) {
|
|
row := db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = ?;", accountID)
|
|
|
|
err := row.Scan(&deviceID)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
return deviceID
|
|
}
|