230 lines
5.2 KiB
Go
230 lines
5.2 KiB
Go
|
// Copyright (c) 2020 Tulir Asokan
|
||
|
//
|
||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||
|
|
||
|
package crypto
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"time"
|
||
|
|
||
|
"maunium.net/go/mautrix/crypto/olm"
|
||
|
"maunium.net/go/mautrix/event"
|
||
|
|
||
|
"maunium.net/go/mautrix/id"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
SessionNotShared = errors.New("session has not been shared")
|
||
|
SessionExpired = errors.New("session has expired")
|
||
|
)
|
||
|
|
||
|
// OlmSessionList is a list of OlmSessions.
|
||
|
// It implements sort.Interface so that the session with recent successful decryptions comes first.
|
||
|
type OlmSessionList []*OlmSession
|
||
|
|
||
|
func (o OlmSessionList) Len() int {
|
||
|
return len(o)
|
||
|
}
|
||
|
|
||
|
func (o OlmSessionList) Less(i, j int) bool {
|
||
|
return o[i].LastDecryptedTime.After(o[j].LastEncryptedTime)
|
||
|
}
|
||
|
|
||
|
func (o OlmSessionList) Swap(i, j int) {
|
||
|
o[i], o[j] = o[j], o[i]
|
||
|
}
|
||
|
|
||
|
type OlmSession struct {
|
||
|
Internal olm.Session
|
||
|
ExpirationMixin
|
||
|
id id.SessionID
|
||
|
}
|
||
|
|
||
|
func (session *OlmSession) ID() id.SessionID {
|
||
|
if session.id == "" {
|
||
|
session.id = session.Internal.ID()
|
||
|
}
|
||
|
return session.id
|
||
|
}
|
||
|
|
||
|
func (session *OlmSession) Describe() string {
|
||
|
return session.Internal.Describe()
|
||
|
}
|
||
|
|
||
|
func wrapSession(session *olm.Session) *OlmSession {
|
||
|
return &OlmSession{
|
||
|
Internal: *session,
|
||
|
ExpirationMixin: ExpirationMixin{
|
||
|
TimeMixin: TimeMixin{
|
||
|
CreationTime: time.Now(),
|
||
|
LastEncryptedTime: time.Now(),
|
||
|
LastDecryptedTime: time.Now(),
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) {
|
||
|
session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
_ = account.Internal.RemoveOneTimeKeys(session)
|
||
|
return wrapSession(session), nil
|
||
|
}
|
||
|
|
||
|
func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
||
|
session.LastEncryptedTime = time.Now()
|
||
|
return session.Internal.Encrypt(plaintext)
|
||
|
}
|
||
|
|
||
|
func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]byte, error) {
|
||
|
msg, err := session.Internal.Decrypt(ciphertext, msgType)
|
||
|
if err == nil {
|
||
|
session.LastDecryptedTime = time.Now()
|
||
|
}
|
||
|
return msg, err
|
||
|
}
|
||
|
|
||
|
type InboundGroupSession struct {
|
||
|
Internal olm.InboundGroupSession
|
||
|
|
||
|
SigningKey id.Ed25519
|
||
|
SenderKey id.Curve25519
|
||
|
RoomID id.RoomID
|
||
|
|
||
|
ForwardingChains []string
|
||
|
|
||
|
id id.SessionID
|
||
|
}
|
||
|
|
||
|
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
|
||
|
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &InboundGroupSession{
|
||
|
Internal: *igs,
|
||
|
SigningKey: signingKey,
|
||
|
SenderKey: senderKey,
|
||
|
RoomID: roomID,
|
||
|
ForwardingChains: nil,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (igs *InboundGroupSession) ID() id.SessionID {
|
||
|
if igs.id == "" {
|
||
|
igs.id = igs.Internal.ID()
|
||
|
}
|
||
|
return igs.id
|
||
|
}
|
||
|
|
||
|
type OGSState int
|
||
|
|
||
|
const (
|
||
|
OGSNotShared OGSState = iota
|
||
|
OGSAlreadyShared
|
||
|
OGSIgnored
|
||
|
)
|
||
|
|
||
|
type UserDevice struct {
|
||
|
UserID id.UserID
|
||
|
DeviceID id.DeviceID
|
||
|
}
|
||
|
|
||
|
type OutboundGroupSession struct {
|
||
|
Internal olm.OutboundGroupSession
|
||
|
|
||
|
ExpirationMixin
|
||
|
MaxMessages int
|
||
|
MessageCount int
|
||
|
|
||
|
Users map[UserDevice]OGSState
|
||
|
RoomID id.RoomID
|
||
|
Shared bool
|
||
|
|
||
|
id id.SessionID
|
||
|
content *event.RoomKeyEventContent
|
||
|
}
|
||
|
|
||
|
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession {
|
||
|
ogs := &OutboundGroupSession{
|
||
|
Internal: *olm.NewOutboundGroupSession(),
|
||
|
ExpirationMixin: ExpirationMixin{
|
||
|
TimeMixin: TimeMixin{
|
||
|
CreationTime: time.Now(),
|
||
|
LastEncryptedTime: time.Now(),
|
||
|
},
|
||
|
MaxAge: 7 * 24 * time.Hour,
|
||
|
},
|
||
|
MaxMessages: 100,
|
||
|
Shared: false,
|
||
|
Users: make(map[UserDevice]OGSState),
|
||
|
RoomID: roomID,
|
||
|
}
|
||
|
if encryptionContent != nil {
|
||
|
if encryptionContent.RotationPeriodMillis != 0 {
|
||
|
ogs.MaxAge = time.Duration(encryptionContent.RotationPeriodMillis) * time.Millisecond
|
||
|
}
|
||
|
if encryptionContent.RotationPeriodMessages != 0 {
|
||
|
ogs.MaxMessages = encryptionContent.RotationPeriodMessages
|
||
|
}
|
||
|
}
|
||
|
return ogs
|
||
|
}
|
||
|
|
||
|
func (ogs *OutboundGroupSession) ShareContent() event.Content {
|
||
|
if ogs.content == nil {
|
||
|
ogs.content = &event.RoomKeyEventContent{
|
||
|
Algorithm: id.AlgorithmMegolmV1,
|
||
|
RoomID: ogs.RoomID,
|
||
|
SessionID: ogs.ID(),
|
||
|
SessionKey: ogs.Internal.Key(),
|
||
|
}
|
||
|
}
|
||
|
return event.Content{Parsed: ogs.content}
|
||
|
}
|
||
|
|
||
|
func (ogs *OutboundGroupSession) ID() id.SessionID {
|
||
|
if ogs.id == "" {
|
||
|
ogs.id = ogs.Internal.ID()
|
||
|
}
|
||
|
return ogs.id
|
||
|
}
|
||
|
|
||
|
func (ogs *OutboundGroupSession) Expired() bool {
|
||
|
return ogs.MessageCount >= ogs.MaxMessages || ogs.ExpirationMixin.Expired()
|
||
|
}
|
||
|
|
||
|
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||
|
if !ogs.Shared {
|
||
|
return nil, SessionNotShared
|
||
|
} else if ogs.Expired() {
|
||
|
return nil, SessionExpired
|
||
|
}
|
||
|
ogs.MessageCount++
|
||
|
ogs.LastEncryptedTime = time.Now()
|
||
|
return ogs.Internal.Encrypt(plaintext), nil
|
||
|
}
|
||
|
|
||
|
type TimeMixin struct {
|
||
|
CreationTime time.Time
|
||
|
LastEncryptedTime time.Time
|
||
|
LastDecryptedTime time.Time
|
||
|
}
|
||
|
|
||
|
type ExpirationMixin struct {
|
||
|
TimeMixin
|
||
|
MaxAge time.Duration
|
||
|
}
|
||
|
|
||
|
func (exp *ExpirationMixin) Expired() bool {
|
||
|
if exp.MaxAge == 0 {
|
||
|
return false
|
||
|
}
|
||
|
return exp.CreationTime.Add(exp.MaxAge).Before(time.Now())
|
||
|
}
|