nyxmuks/pkg/hicli/hicli.go
2025-03-04 22:54:20 +02:00

283 lines
8.1 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
package hicli
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/gomuks/pkg/hicli/database"
)
// HiClient bundles the databases, the mautrix client, the crypto machinery and
// the sync bookkeeping that New wires together into one client instance.
type HiClient struct {
// DB is the hicli database wrapper that New creates from rawDB.
DB *database.Database
// CryptoDB is only set by New when a crypto database distinct from the main
// one was supplied; Stop closes it separately when it is non-nil.
CryptoDB *dbutil.Database
// Account is the account loaded (or stored) by Start; nil means not logged in.
Account *database.Account
// Client is the mautrix client built in New, using this HiClient as its
// syncer, store and crypto helper.
Client *mautrix.Client
// Crypto is the Olm machine created in New on top of CryptoStore and ClientStore.
Crypto *crypto.OlmMachine
// CryptoStore is the SQL crypto store; Start fills in its account and device IDs.
CryptoStore *crypto.SQLCryptoStore
// ClientStore is the state store shared by Client and Crypto.
ClientStore *database.ClientStateStore
// Log is the base logger; New and Sync derive component loggers from it.
Log zerolog.Logger
// Verified holds the result of checkIsCurrentDeviceVerified as recorded by Start.
Verified bool
// KeyBackupVersion and KeyBackupKey are not assigned in this part of the file.
// NOTE(review): presumably populated by loadPrivateKeys — confirm.
KeyBackupVersion id.KeyBackupVersion
KeyBackupKey *backup.MegolmBackupKey
// PushRules is not assigned in this part of the file.
// NOTE(review): presumably filled by LoadPushRules — confirm.
PushRules atomic.Pointer[pushrules.PushRuleset]
// SyncStatus is set to syncWaiting by New and again when Sync stops without
// a fatal error.
SyncStatus atomic.Pointer[SyncStatus]
// syncErrors and lastSync are not touched in this part of the file.
// NOTE(review): presumably maintained by the sync error/status helpers — confirm.
syncErrors int
lastSync time.Time
// ToDeviceInSync is not touched in this part of the file.
ToDeviceInSync atomic.Bool
// EventHandler is the callback passed to New as evtHandler.
EventHandler func(evt any)
// LogoutFunc is not assigned in this part of the file.
LogoutFunc func(context.Context) error
// firstSyncReceived is not touched in this part of the file.
// NOTE(review): likely related to the post-first-sync timeout change mentioned in New — confirm.
firstSyncReceived bool
// syncingID is incremented on every Sync call and logged as "sync_id".
syncingID int
// syncLock is held for the whole duration of Sync; Stop acquires it to wait
// for a running Sync to return.
syncLock sync.Mutex
// stopSync points at the cancel function registered by Sync; Stop swaps it
// back to nil, and IsSyncing checks it for non-nil.
stopSync atomic.Pointer[context.CancelFunc]
// encryptLock is not used in this part of the file.
encryptLock sync.Mutex
// requestQueueWakeup is created by New with a buffer of one element.
// NOTE(review): presumably used to wake RunRequestQueue — confirm.
requestQueueWakeup chan struct{}
// jsonRequests is initialised empty by New.
// NOTE(review): presumably guarded by jsonRequestsLock — confirm.
jsonRequestsLock sync.Mutex
jsonRequests map[int64]context.CancelCauseFunc
// paginationInterrupter is initialised empty by New.
// NOTE(review): presumably guarded by paginationInterrupterLock — confirm.
paginationInterrupterLock sync.Mutex
paginationInterrupter map[id.RoomID]context.CancelCauseFunc
}
// ErrTimelineReset is a sentinel error carrying the message
// "got limited timeline sync response". Its users are outside this part of the file.
var ErrTimelineReset = errors.New("got limited timeline sync response")
// New builds a HiClient on top of rawDB.
//
// cryptoDB may be nil, in which case rawDB is also used for the crypto store.
// Only a crypto database that differs from rawDB is remembered in
// HiClient.CryptoDB. If rawDB has no owner yet it is claimed as "hicli" with
// foreign tables ignored, and if it has no logger it gets one derived from log.
// pickleKey is handed to the SQL crypto store, and evtHandler becomes the
// client's EventHandler.
func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient {
	// Decide up front whether the caller supplied a dedicated crypto database.
	hasOwnCryptoDB := cryptoDB != nil && cryptoDB != rawDB
	if cryptoDB == nil {
		cryptoDB = rawDB
	}

	if rawDB.Owner == "" {
		rawDB.Owner = "hicli"
		rawDB.IgnoreForeignTables = true
	}
	if rawDB.Log == nil {
		rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
	}

	wrappedDB := database.New(rawDB)
	cli := &HiClient{
		DB:          wrappedDB,
		Log:         log,
		ClientStore: &database.ClientStateStore{Database: wrappedDB},

		requestQueueWakeup:    make(chan struct{}, 1),
		jsonRequests:          make(map[int64]context.CancelCauseFunc),
		paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc),

		EventHandler: evtHandler,
	}
	if hasOwnCryptoDB {
		cli.CryptoDB = cryptoDB
	}
	cli.SyncStatus.Store(syncWaiting)

	dialer := &net.Dialer{Timeout: 10 * time.Second}
	transport := &http.Transport{
		DialContext: dialer.DialContext,
		// This needs to be relatively high to allow initial syncs,
		// it's lowered after the first sync in postProcessSyncResponse
		ResponseHeaderTimeout: 300 * time.Second,

		// Default settings from http.DefaultTransport
		Proxy:                 http.ProxyFromEnvironment,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          5,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	httpClient := &http.Client{
		Transport: transport,
		Timeout:   300 * time.Second,
	}
	cli.Client = &mautrix.Client{
		UserAgent:  mautrix.DefaultUserAgent,
		Client:     httpClient,
		Syncer:     (*hiSyncer)(cli),
		Store:      (*hiStore)(cli),
		StateStore: cli.ClientStore,
		Log:        log.With().Str("component", "mautrix client").Logger(),
	}

	cryptoDBLog := dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger())
	cli.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, cryptoDBLog, "", "", pickleKey)

	olmLog := log.With().Str("component", "crypto").Logger()
	olm := crypto.NewOlmMachine(cli.Client, &olmLog, cli.CryptoStore, cli.ClientStore)
	olm.SessionReceived = cli.handleReceivedMegolmSession
	olm.DisableRatchetTracking = true
	olm.DisableDecryptKeyFetching = true
	cli.Crypto = olm

	cli.Client.Crypto = (*hiCryptoHelper)(cli)
	return cli
}
// tempClient returns a bare mautrix client pointed at homeserverURL. It reuses
// the main client's user agent and HTTP client but carries no credentials,
// store or syncer. An error is returned when homeserverURL cannot be parsed.
func (h *HiClient) tempClient(homeserverURL string) (*mautrix.Client, error) {
	hsURL, err := url.Parse(homeserverURL)
	if err != nil {
		return nil, err
	}
	tmp := &mautrix.Client{
		HomeserverURL: hsURL,
		UserAgent:     h.Client.UserAgent,
		Client:        h.Client.Client,
		Log:           h.Log.With().Str("component", "temp mautrix client").Logger(),
	}
	return tmp, nil
}
// IsLoggedIn reports whether an account is attached to the client, i.e.
// whether h.Account is non-nil. Start sets Account when it finds or stores one.
func (h *HiClient) IsLoggedIn() bool {
return h.Account != nil
}
// Start upgrades both database schemas and, if an account for userID is
// available, prepares the client with its credentials.
//
// When expectedAccount is non-nil and no account is stored yet, it is saved
// and used. When an account is already stored, its device ID must match
// expectedAccount's or an error is returned. With no stored account and no
// expectedAccount, Start returns nil after the upgrades without logging in.
//
// For a usable account, Start checks the homeserver's spec versions, loads the
// Olm machine and checks whether the current device is verified. Only a
// verified device gets its private keys loaded and a background Sync started.
//
// Start panics if expectedAccount is non-nil and its user ID differs from userID.
func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error {
	if expectedAccount != nil && userID != expectedAccount.UserID {
		panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID"))
	}

	if err := h.DB.Upgrade(ctx); err != nil {
		return fmt.Errorf("failed to upgrade hicli db: %w", err)
	}
	if err := h.CryptoStore.DB.Upgrade(ctx); err != nil {
		return fmt.Errorf("failed to upgrade crypto db: %w", err)
	}

	acc, err := h.DB.Account.Get(ctx, userID)
	switch {
	case err != nil:
		return err
	case expectedAccount == nil:
		// Nothing to reconcile; use whatever the database returned (possibly nil).
	case acc == nil:
		// First start for this account: persist the expected credentials.
		if err = h.DB.Account.Put(ctx, expectedAccount); err != nil {
			return err
		}
		acc = expectedAccount
	case expectedAccount.DeviceID != acc.DeviceID:
		return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, acc.DeviceID)
	}
	if acc == nil {
		return nil
	}

	zerolog.Ctx(ctx).Debug().Stringer("user_id", acc.UserID).Msg("Preparing client with existing credentials")
	h.Account = acc
	h.CryptoStore.AccountID = acc.UserID.String()
	h.CryptoStore.DeviceID = acc.DeviceID
	h.Client.UserID = acc.UserID
	h.Client.DeviceID = acc.DeviceID
	h.Client.AccessToken = acc.AccessToken
	h.Client.HomeserverURL, err = url.Parse(acc.HomeserverURL)
	if err != nil {
		return err
	}

	if err = h.CheckServerVersions(ctx); err != nil {
		return err
	}
	if err = h.Crypto.Load(ctx); err != nil {
		return fmt.Errorf("failed to load olm machine: %w", err)
	}

	h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
	if err != nil {
		return err
	}
	zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
	if !h.Verified {
		return nil
	}

	if err = h.loadPrivateKeys(ctx); err != nil {
		return err
	}
	go h.Sync()
	return nil
}
var (
	// ErrFailedToCheckServerVersions is returned (paired with the underlying
	// error) when the homeserver's versions request fails.
	ErrFailedToCheckServerVersions = errors.New("failed to check server versions")
	// ErrOutdatedServer is wrapped in the error returned when the homeserver
	// does not advertise MinimumSpecVersion.
	ErrOutdatedServer = errors.New("homeserver is outdated")
	// MinimumSpecVersion is the spec version the homeserver must advertise
	// for the server version check to pass.
	MinimumSpecVersion = mautrix.SpecV11
)
// CheckServerVersions runs the server version check against the homeserver
// configured on h.Client and returns whatever checkServerVersions returns.
func (h *HiClient) CheckServerVersions(ctx context.Context) error {
return h.checkServerVersions(ctx, h.Client)
}
// checkServerVersions asks the homeserver behind cli for its supported spec
// versions. It returns ErrFailedToCheckServerVersions paired with the request
// error when the request fails, an error wrapping ErrOutdatedServer when
// MinimumSpecVersion is not among the advertised versions, and nil otherwise.
func (h *HiClient) checkServerVersions(ctx context.Context, cli *mautrix.Client) error {
	supported, err := cli.Versions(ctx)
	if err != nil {
		return exerrors.NewDualError(ErrFailedToCheckServerVersions, err)
	}
	if supported.Contains(MinimumSpecVersion) {
		return nil
	}
	return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, supported.GetLatest())
}
// IsSyncing reports whether a sync cancel function is currently registered in
// stopSync. Sync registers one when it starts and Stop swaps it back to nil.
func (h *HiClient) IsSyncing() bool {
return h.stopSync.Load() != nil
}
// Sync runs the sync loop and blocks until it ends.
//
// Any sync that is already running is asked to stop first (both through the
// mautrix client and through the previously registered cancel function), and
// the new run only begins once the previous one has released syncLock.
// The request queue and the push rule loader are started as goroutines bound
// to the same context, so they are cancelled when this call returns.
//
// If the loop ends with an error while the context is still alive, the error
// is recorded via markSyncErrored and logged as fatal; otherwise the sync
// status is reset to syncWaiting.
func (h *HiClient) Sync() {
	h.Client.StopSync()
	if fn := h.stopSync.Load(); fn != nil {
		(*fn)()
	}
	h.syncLock.Lock()
	defer h.syncLock.Unlock()

	h.syncingID++
	syncingID := h.syncingID
	log := h.Log.With().
		Str("action", "sync").
		Int("sync_id", syncingID).
		Logger()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	h.stopSync.Store(&cancel)
	// Unregister this run's cancel function on the way out so that IsSyncing
	// does not keep reporting true after the loop has ended (for example on a
	// fatal sync error). CompareAndSwap leaves the pointer alone if Stop has
	// already cleared it or another run has replaced it.
	defer h.stopSync.CompareAndSwap(&cancel, nil)

	go h.RunRequestQueue(h.Log.WithContext(ctx))
	go h.LoadPushRules(h.Log.WithContext(ctx))
	ctx = log.WithContext(ctx)
	log.Info().Msg("Starting syncing")
	err := h.Client.SyncWithContext(ctx)
	if err != nil && ctx.Err() == nil {
		h.markSyncErrored(err, true)
		log.Err(err).Msg("Fatal error in syncer")
	} else {
		h.SyncStatus.Store(syncWaiting)
		log.Info().Msg("Syncing stopped")
	}
}
// Stop shuts the client down: it asks the sync loop to stop, cancels and
// unregisters the sync context, waits for a running Sync to return, and then
// closes the main database followed by the separate crypto database when one
// is configured. Close failures are logged rather than returned.
func (h *HiClient) Stop() {
	h.Client.StopSync()
	if cancelSync := h.stopSync.Swap(nil); cancelSync != nil {
		(*cancelSync)()
	}

	h.syncLock.Lock()
	//lint:ignore SA2001 just acquire the lock to make sure Sync is done
	h.syncLock.Unlock()

	if err := h.DB.Close(); err != nil {
		h.Log.Err(err).Msg("Failed to close database cleanly")
	}
	if h.CryptoDB == nil {
		return
	}
	if err := h.CryptoDB.Close(); err != nil {
		h.Log.Err(err).Msg("Failed to close crypto database cleanly")
	}
}