Skip to content

Commit

Permalink
decrypt,send: msmsg decryption and bot message sending support (#615)
Browse files Browse the repository at this point in the history
  • Loading branch information
purpshell authored Jul 6, 2024
1 parent 1295ce2 commit deecc4d
Show file tree
Hide file tree
Showing 11 changed files with 476 additions and 33 deletions.
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ var (
ErrUnknownServer = errors.New("can't send message to unknown server")
ErrRecipientADJID = errors.New("message recipient must be a user JID with no device part")
ErrServerReturnedError = errors.New("server returned error")
ErrInvalidInlineBotID = errors.New("invalid inline bot ID")
)

type DownloadHTTPError struct {
Expand Down
61 changes: 61 additions & 0 deletions mdtest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"net/http"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -1022,6 +1023,66 @@ func handleCmd(cmd string, args []string) {
if err != nil {
log.Errorf("Error editing label: %v", err)
}
case "sendbotmsg":
if len(args) < 1 {
log.Errorf("Usage: sendBotMsg <inline jid (optional)> <text>")
return
}
var inlineJID types.JID
if len(args) > 1 {
var numbersRegex = regexp.MustCompile(`^[0-9]+$`)
jid, ok := parseJID(args[0])
if ok && numbersRegex.MatchString(jid.User) {
inlineJID = jid
} else {
inlineJID = types.EmptyJID
}
}

personaID := proto.String("867051314767696$760019659443059") // default meta bot personality: "Assistant"

var resp, err = whatsmeow.SendResponse{}, error(nil)
if !inlineJID.IsEmpty() {
text := fmt.Sprintf("@%s %s", types.MetaAIJID.User, strings.Join(args[1:], " "))
msg := &waE2E.Message{
ExtendedTextMessage: &waE2E.ExtendedTextMessage{
Text: &text,
ContextInfo: &waE2E.ContextInfo{
MentionedJID: []string{types.MetaAIJID.String()},
},
},
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: personaID,
},
},
}

resp, err = cli.SendMessage(context.Background(), inlineJID, msg, whatsmeow.SendRequestExtra{
InlineBotJID: types.MetaAIJID,
})
} else {
text := strings.Join(args, " ")
msg := &waE2E.Message{
Conversation: &text,
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: personaID,
},
},
}
resp, err = cli.SendMessage(context.Background(), types.MetaAIJID, msg)
}
if err != nil {
log.Errorf("Error sending bot message: %v", err)
} else {
log.Infof("Bot message sent (server timestamp: %s)", resp.Timestamp)
}
case "fetchbotprofiles":
list, _ := cli.GetBotListV2()
log.Infof("Bots list: %+v", list)
profiles, _ := cli.GetBotProfiles(list)
log.Infof("Bots profiles: %+v", profiles)
}
}

Expand Down
87 changes: 86 additions & 1 deletion message.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"runtime/debug"
"time"

"go.mau.fi/whatsmeow/proto/waE2E"

"go.mau.fi/libsignal/groups"
"go.mau.fi/libsignal/protocol"
"go.mau.fi/libsignal/session"
Expand Down Expand Up @@ -88,6 +90,16 @@ func (cli *Client) parseMessageSource(node *waBinary.Node, requireParticipant bo
} else {
source.Chat = from.ToNonAD()
}
} else if from.IsBot() {
source.Sender = from
meta := node.GetChildByTag("meta")
ag = meta.AttrGetter()
targetChatJID := ag.OptionalJID("target_chat_jid")
if targetChatJID != nil {
source.Chat = targetChatJID.ToNonAD()
} else {
source.Chat = from
}
} else {
source.Chat = from.ToNonAD()
source.Sender = from
Expand All @@ -96,6 +108,32 @@ func (cli *Client) parseMessageSource(node *waBinary.Node, requireParticipant bo
return
}

func (cli *Client) parseMsgBotInfo(node waBinary.Node) (botInfo types.MsgBotInfo, err error) {
botNode := node.GetChildByTag("bot")

ag := botNode.AttrGetter()
botInfo.EditType = types.BotEditType(ag.String("edit"))
if botInfo.EditType == types.EditTypeInner || botInfo.EditType == types.EditTypeLast {
botInfo.EditTargetID = types.MessageID(ag.String("edit_target_id"))
botInfo.EditSenderTimestampMS = ag.UnixMilli("sender_timestamp_ms")
}
err = ag.Error()
return
}

func (cli *Client) parseMsgMetaInfo(node waBinary.Node) (metaInfo types.MsgMetaInfo, err error) {
metaNode := node.GetChildByTag("meta")

ag := metaNode.AttrGetter()
metaInfo.TargetID = types.MessageID(ag.String("target_id"))
targetSenderJID := ag.OptionalJIDOrEmpty("target_sender_jid")
if targetSenderJID.User != "" {
metaInfo.TargetSender = targetSenderJID
}
err = ag.Error()
return
}

func (cli *Client) parseMessageInfo(node *waBinary.Node) (*types.MessageInfo, error) {
var info types.MessageInfo
var err error
Expand Down Expand Up @@ -124,6 +162,16 @@ func (cli *Client) parseMessageInfo(node *waBinary.Node) (*types.MessageInfo, er
if err != nil {
cli.Log.Warnf("Failed to parse verified_name node in %s: %v", info.ID, err)
}
case "bot":
info.MsgBotInfo, err = cli.parseMsgBotInfo(child)
if err != nil {
cli.Log.Warnf("Failed to parse <bot> node in %s: %v", info.ID, err)
}
case "meta":
info.MsgMetaInfo, err = cli.parseMsgMetaInfo(child)
if err != nil {
cli.Log.Warnf("Failed to parse <meta> node in %s: %v", info.ID, err)
}
case "franking":
// TODO
case "trace":
Expand Down Expand Up @@ -200,10 +248,47 @@ func (cli *Client) decryptMessages(info *types.MessageInfo, node *waBinary.Node)
containsDirectMsg = true
} else if info.IsGroup && encType == "skmsg" {
decrypted, err = cli.decryptGroupMsg(&child, info.Sender, info.Chat)
} else if encType == "msmsg" && info.Sender.IsBot() {
// Meta AI / other bots (biz?):

// step 1: get message secret
targetSenderJID := info.MsgMetaInfo.TargetSender
if targetSenderJID.User == "" {
// if no targetSenderJID in <meta> this must be ourselves (one-one-one mode)
targetSenderJID = cli.getOwnID()
}

messageSecret, err := cli.Store.MsgSecrets.GetMessageSecret(info.Chat, targetSenderJID, info.MsgMetaInfo.TargetID)
if err != nil || messageSecret == nil {
cli.Log.Warnf("Error getting message secret for bot msg with id %s", node.AttrGetter().String("id"))
continue
}

// step 2: get MessageSecretMessage
byteContents := child.Content.([]byte) // <enc> contents
var msMsg waE2E.MessageSecretMessage

err = proto.Unmarshal(byteContents, &msMsg)
if err != nil {
cli.Log.Warnf("Error decoding MessageSecretMesage protobuf %v", err)
continue
}

// step 3: determine best message id for decryption
var messageID string
if info.MsgBotInfo.EditType == types.EditTypeInner || info.MsgBotInfo.EditType == types.EditTypeLast {
messageID = info.MsgBotInfo.EditTargetID
} else {
messageID = info.ID
}

// step 4: decrypt and voila
decrypted, err = cli.decryptBotMessage(messageSecret, &msMsg, messageID, targetSenderJID, info)
} else {
cli.Log.Warnf("Unhandled encrypted message (type %s) from %s", encType, info.SourceString())
continue
}

if err != nil {
cli.Log.Warnf("Error decrypting message from %s: %v", info.SourceString(), err)
isUnavailable := encType == "skmsg" && !containsDirectMsg && errors.Is(err, signalerror.ErrNoSenderKeyForUser)
Expand All @@ -220,7 +305,7 @@ func (cli *Client) decryptMessages(info *types.MessageInfo, node *waBinary.Node)
cli.cancelDelayedRequestFromPhone(info.ID)
}

var msg waProto.Message
var msg waE2E.Message
switch ag.Int("v") {
case 2:
err = proto.Unmarshal(decrypted, &msg)
Expand Down
40 changes: 30 additions & 10 deletions msgsecret.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"fmt"
"time"

"go.mau.fi/whatsmeow/proto/waCommon"
"go.mau.fi/whatsmeow/proto/waE2E"

"go.mau.fi/util/random"
"google.golang.org/protobuf/proto"

Expand All @@ -26,8 +29,13 @@ type MsgSecretType string
const (
EncSecretPollVote MsgSecretType = "Poll Vote"
EncSecretReaction MsgSecretType = "Enc Reaction"
EncSecretBotMsg MsgSecretType = "Bot Message"
)

func applyBotMessageHKDF(messageSecret []byte) []byte {
return hkdfutil.SHA256(messageSecret, nil, []byte(EncSecretBotMsg), 32)
}

func generateMsgSecretKey(
modificationType MsgSecretType, modificationSender types.JID,
origMsgID types.MessageID, origMsgSender types.JID, origMsgSecret []byte,
Expand All @@ -47,7 +55,7 @@ func generateMsgSecretKey(
return secretKey, additionalData
}

func getOrigSenderFromKey(msg *events.Message, key *waProto.MessageKey) (types.JID, error) {
func getOrigSenderFromKey(msg *events.Message, key *waCommon.MessageKey) (types.JID, error) {
if key.GetFromMe() {
// fromMe always means the poll and vote were sent by the same user
return msg.Info.Sender, nil
Expand All @@ -74,18 +82,18 @@ type messageEncryptedSecret interface {
GetEncPayload() []byte
}

func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waProto.MessageKey) ([]byte, error) {
func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waCommon.MessageKey) ([]byte, error) {
pollSender, err := getOrigSenderFromKey(msg, origMsgKey)
if err != nil {
return nil, err
}
baseEncKey, err := cli.Store.MsgSecrets.GetMessageSecret(msg.Info.Chat, pollSender, origMsgKey.GetId())
baseEncKey, err := cli.Store.MsgSecrets.GetMessageSecret(msg.Info.Chat, pollSender, origMsgKey.GetID())
if err != nil {
return nil, fmt.Errorf("failed to get original message secret key: %w", err)
} else if baseEncKey == nil {
return nil, ErrOriginalMessageSecretNotFound
}
secretKey, additionalData := generateMsgSecretKey(useCase, msg.Info.Sender, origMsgKey.GetId(), pollSender, baseEncKey)
secretKey, additionalData := generateMsgSecretKey(useCase, msg.Info.Sender, origMsgKey.GetID(), pollSender, baseEncKey)
plaintext, err := gcmutil.Decrypt(secretKey, encrypted.GetEncIV(), encrypted.GetEncPayload(), additionalData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secret message: %w", err)
Expand Down Expand Up @@ -115,6 +123,18 @@ func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.
return ciphertext, iv, nil
}

func (cli *Client) decryptBotMessage(messageSecret []byte, msMsg messageEncryptedSecret, messageID types.MessageID, targetSenderJID types.JID, info *types.MessageInfo) ([]byte, error) {
// gcm decrypt key generation
newKey, additionalData := generateMsgSecretKey("", info.Sender, messageID, targetSenderJID, applyBotMessageHKDF(messageSecret))

plaintext, err := gcmutil.Decrypt(newKey, msMsg.GetEncIV(), msMsg.GetEncPayload(), additionalData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secret message: %w", err)
}

return plaintext, nil
}

// DecryptReaction decrypts a reaction update message. This form of reactions hasn't been rolled out yet,
// so this function is likely not of much use.
//
Expand All @@ -126,7 +146,7 @@ func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.
// }
// fmt.Printf("Reaction message: %+v\n", reaction)
// }
func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionMessage, error) {
func (cli *Client) DecryptReaction(reaction *events.Message) (*waE2E.ReactionMessage, error) {
encReaction := reaction.Message.GetEncReactionMessage()
if encReaction == nil {
return nil, ErrNotEncryptedReactionMessage
Expand All @@ -135,7 +155,7 @@ func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionM
if err != nil {
return nil, fmt.Errorf("failed to decrypt reaction: %w", err)
}
var msg waProto.ReactionMessage
var msg waE2E.ReactionMessage
err = proto.Unmarshal(plaintext, &msg)
if err != nil {
return nil, fmt.Errorf("failed to decode reaction protobuf: %w", err)
Expand All @@ -156,7 +176,7 @@ func (cli *Client) DecryptReaction(reaction *events.Message) (*waProto.ReactionM
// fmt.Printf("- %X\n", hash)
// }
// }
func (cli *Client) DecryptPollVote(vote *events.Message) (*waProto.PollVoteMessage, error) {
func (cli *Client) DecryptPollVote(vote *events.Message) (*waE2E.PollVoteMessage, error) {
pollUpdate := vote.Message.GetPollUpdateMessage()
if pollUpdate == nil {
return nil, ErrNotPollUpdateMessage
Expand All @@ -165,16 +185,16 @@ func (cli *Client) DecryptPollVote(vote *events.Message) (*waProto.PollVoteMessa
if err != nil {
return nil, fmt.Errorf("failed to decrypt poll vote: %w", err)
}
var msg waProto.PollVoteMessage
var msg waE2E.PollVoteMessage
err = proto.Unmarshal(plaintext, &msg)
if err != nil {
return nil, fmt.Errorf("failed to decode poll vote protobuf: %w", err)
}
return &msg, nil
}

func getKeyFromInfo(msgInfo *types.MessageInfo) *waProto.MessageKey {
creationKey := &waProto.MessageKey{
func getKeyFromInfo(msgInfo *types.MessageInfo) *waCommon.MessageKey {
creationKey := &waCommon.MessageKey{
RemoteJID: proto.String(msgInfo.Chat.String()),
FromMe: proto.Bool(msgInfo.IsFromMe),
ID: proto.String(msgInfo.ID),
Expand Down
29 changes: 22 additions & 7 deletions prekeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,34 @@ func nodeToPreKeyBundle(deviceID uint32, node waBinary.Node) (*prekey.Bundle, er
}
identityKeyPub := *(*[32]byte)(identityKeyRaw)

preKey, err := nodeToPreKey(keysNode.GetChildByTag("key"))
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
preKeyNode, ok := keysNode.GetOptionalChildByTag("key")
preKey := &keys.PreKey{}
if ok {
var err error
preKey, err = nodeToPreKey(preKeyNode)
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
}
}

signedPreKey, err := nodeToPreKey(keysNode.GetChildByTag("skey"))
if err != nil {
return nil, fmt.Errorf("invalid signed prekey in prekey response: %w", err)
}

return prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub))), nil
var bundle *prekey.Bundle
if ok {
bundle = prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
} else {
bundle = prekey.NewBundle(registrationID, deviceID, optional.NewEmptyUint32(), signedPreKey.KeyID,
nil, ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
}

return bundle, nil
}

func nodeToPreKey(node waBinary.Node) (*keys.PreKey, error) {
Expand Down
2 changes: 1 addition & 1 deletion retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
}
var content []waBinary.Node
if msg.wa != nil {
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity)
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity, nil)
} else {
content = []waBinary.Node{
*encrypted,
Expand Down
Loading

0 comments on commit deecc4d

Please sign in to comment.