package mmtls import ( "crypto/elliptic" "crypto/rand" "errors" "fmt" "sync/atomic" "golang.org/x/net/proxy" "xiawan/wx/clientsdk/baseutils" "github.com/micro/go-micro/util/log" "github.com/wsddn/go-ecdh" ) // CreateNewMMInfo 创建新的MMInfo func CreateNewMMInfo() *MMInfo { mmInfo := &MMInfo{} mmInfo.ShortHost = "szshort.weixin.qq.com" mmInfo.LongHost = "szlong.weixin.qq.com" mmInfo.LONGPort = 8080 mmInfo.LONGClientSeq = 1 mmInfo.LONGServerSeq = 1 // 从全局配置设置代理配置 mmInfo.LongConnTimeout = GlobalProxyConfig.LongConnTimeout mmInfo.LongConnReadTimeout = GlobalProxyConfig.LongConnReadTimeout mmInfo.LongConnRetryTimes = GlobalProxyConfig.LongConnRetryTimes mmInfo.LongConnRetryInterval = GlobalProxyConfig.LongConnRetryInterval mmInfo.ShortConnTimeout = GlobalProxyConfig.ShortConnTimeout mmInfo.AllowDirectOnProxyFail = GlobalProxyConfig.AllowDirectOnProxyFail return mmInfo } // InitMMTLSInfoShort 如果使用MMTLS,每次登陆前都需要初始化MMTLSInfo信息 // pskList: 之前握手服务端返回的,要保存起来,后面握手时使用, 第一次握手传空数组即可 func InitMMTLSInfoShort(dialer proxy.Dialer, hostName string, pskList []*Psk) *MMInfo { // 初始化MMInfo mmInfo := CreateNewMMInfo() mmInfo.ShortPskList = pskList // 随机生成ClientEcdhKeys mmInfo.ClientEcdhKeys = CreateClientEcdhKeys() mmInfo.ShortHost = hostName mmInfo.Dialer = dialer // 握手 mmInfo, err := MMHandShakeByShortLink(mmInfo, hostName) // 如果握手失败 就不使用MMTLS if err != nil { log.Info(err) return nil } // 握手成功,设置好HOST 和 新的URL shortURL := []byte("/mmtls/") mmInfo.ShortHost = hostName mmInfo.ShortURL = string(append(shortURL, []byte(baseutils.RandomSmallHexString(8))[0:]...)) return mmInfo } // CreateClientEcdhKeys 创建新的ClientEcdhKeys func CreateClientEcdhKeys() *ClientEcdhKeys { // 随机 clientEcdhKeys := &ClientEcdhKeys{} e := ecdh.NewEllipticECDH(elliptic.P256()) priKey1, pubKey1, _ := e.GenerateKey(rand.Reader) priKey2, pubKey2, _ := e.GenerateKey(rand.Reader) clientEcdhKeys.PriKey1 = priKey1 clientEcdhKeys.PriKey2 = priKey2 clientEcdhKeys.PubKeyBuf1 = e.Marshal(pubKey1) clientEcdhKeys.PubKeyBuf2 = e.Marshal(pubKey2) return clientEcdhKeys } // MMHandShakeByShortLink 通过短链接握手 func MMHandShakeByShortLink(mmInfo *MMInfo, hostName string) (*MMInfo, error) { shortURL := []byte("/mmtls/") mmURL := append(shortURL, []byte(baseutils.RandomSmallHexString(8))[0:]...) mmInfo.ShortURL = string(mmURL) // 发送握手请求 - ClientHello clientHelloData := CreateHandShakeClientHelloData(mmInfo) sendData := CreateRecordData(ServerHandShakeType, clientHelloData) retBytes, err := MMHTTPPost(mmInfo, sendData, "handshake") if err != nil { return nil, err } // 解析握手相应数据 retItems, err := ParserMMtlsResponseData(retBytes) if err != nil { return nil, err } // 处理握手信息 clientFinishData, err := DealHandShakePackItems(mmInfo, retItems, clientHelloData) _ = clientFinishData return mmInfo, nil } // ParserMMtlsResponseData 解析mmtls响应数据 func ParserMMtlsResponseData(data []byte) ([]*PackItem, error) { // RecodeHead *RecodeHead retItems := make([]*PackItem, 0) // 总数据大小 totalLength := uint32(len(data)) current := uint32(0) // 解析所有包 for current < totalLength { packItem := &PackItem{} // recordHead if current+5 > totalLength { return retItems, errors.New("ParserMMtlsResponseData err: current+5 >= totalLength") } recordHead := RecordHeadDeSerialize(data[current:]) packItem.RecordHead = data[current : current+5] current = current + 5 // PackData // 判断数据是否有问题 if current+uint32(recordHead.Size) > totalLength { return retItems, errors.New("ParserMMtlsResponseData err: current+uint32(recordHead.Size) >= totalLength") } packItem.PackData = data[current : current+uint32(recordHead.Size)] // current current = current + uint32(recordHead.Size) retItems = append(retItems, packItem) } return retItems, nil } // DealHandShakePackItems 解密packItems func DealHandShakePackItems(mmInfo *MMInfo, packItems []*PackItem, clientHelloReq []byte) ([]byte, error) { retClientFinishData := make([]byte, 0) // 先解析 ServerHello secretKey, err := DealServerHello(mmInfo, packItems[0]) if err != nil { return retClientFinishData, err } // 计算HashRet hashData := make([]byte, 0) hashData = append(hashData, clientHelloReq[0:]...) hashData = append(hashData, packItems[0].PackData...) hashRet := Sha256(hashData) // 密钥扩展 message := []byte("handshake key expansion") message = append(message, hashRet...) aesKeyExpand := HkdfExpand(secretKey, message, 56) gcmAesKey := aesKeyExpand[0x10:0x20] oriNonce := aesKeyExpand[0x2c:] // 解密后面的包 count := len(packItems) for index := 1; index < count; index++ { tmpPackItem := packItems[index] // 解密数据 tmpNonce := GetNonce(oriNonce, uint32(index)) tmpAad := []byte{0x00, 0x00, 0x00, 0x00} tmpAad = append(tmpAad, baseutils.Int32ToBytes(atomic.LoadUint32(&mmInfo.LONGServerSeq))...) tmpAad = append(tmpAad, tmpPackItem.RecordHead...) decodeData, err := AesGcmDecrypt(gcmAesKey, tmpNonce, tmpAad, tmpPackItem.PackData) // 设置解密后的数据 tmpPackItem.PackData = decodeData if err != nil { return retClientFinishData, err } atomic.AddUint32(&mmInfo.LONGServerSeq, 1) //mmInfo.LONGServerSeq++ // 处理CertificateVerifyType tmpType := decodeData[4] if tmpType == CertificateVerifyType { // 校验服务器 flag, err := DealCertificateVerify(clientHelloReq, packItems[0].PackData, decodeData) if err != nil { return retClientFinishData, err } if !flag { return retClientFinishData, errors.New("DealHandShakePackItems err: CertificateVerify failed") } } // 处理NewSessionTicketType if tmpType == NewSessionTicketType { err := DealNewSessionTicket(mmInfo, decodeData) if err != nil { return retClientFinishData, err } } // 处理Server FinishType if tmpType == FinishedType { // 第一步验证ServerFinished数据 tmpHashData := make([]byte, 0) tmpHashData = append(tmpHashData, clientHelloReq[0:]...) tmpHashData = append(tmpHashData, packItems[0].PackData[0:]...) tmpHashData = append(tmpHashData, packItems[1].PackData[0:]...) tmpHashData = append(tmpHashData, packItems[2].PackData[0:]...) tmpHashValue := Sha256(tmpHashData) serverFinished, err := FinishedDeSerialize(tmpPackItem.PackData[4:]) if err != nil { return retClientFinishData, err } bSuccess := VerifyFinishedData(secretKey, tmpHashValue, serverFinished.VerifyData) if !bSuccess { return retClientFinishData, errors.New("DealHandShakePackItems err: Finished verify failed") } // 第二步生成ClientFinished数据,然后加密 hkdfClientFinish := HkdfExpand(secretKey, []byte("client finished"), 32) hmacRet := HmacHash256(hkdfClientFinish, tmpHashValue) aesGcmParam := &AesGcmParam{} aesGcmParam.AesKey = aesKeyExpand[0x00:0x10] aesGcmParam.Nonce = aesKeyExpand[0x20:0x2c] // 创建Finished finished := CreateFinished(hmacRet) // 加密 finishedData := FinishedSerialize(finished) clientSeq := atomic.AddUint32(&mmInfo.LONGClientSeq, 1) - 1 encodeData, err := EncryptedReqData(aesGcmParam, finishedData, ServerHandShakeType, clientSeq) if err != nil { return retClientFinishData, err } retClientFinishData = CreateRecordData(ServerHandShakeType, encodeData) //mmInfo.LONGClientSeq++ //atomic.AddUint32(&mmInfo.LONGClientSeq, 1) break } } // 计算扩展出来的用于后续加密的Key tmpExpandHashData := make([]byte, 0) tmpExpandHashData = append(tmpExpandHashData, clientHelloReq[0:]...) tmpExpandHashData = append(tmpExpandHashData, packItems[0].PackData[0:]...) tmpExpandHashData = append(tmpExpandHashData, packItems[1].PackData[0:]...) tmpExpandHashValue := Sha256(tmpExpandHashData) // PskAccessKey 短连接MMTLS密钥 expandPskAccessData := []byte("PSK_ACCESS") expandPskAccessData = append(expandPskAccessData, tmpExpandHashValue[0:]...) mmInfo.PskAccessKey = HkdfExpand(secretKey, expandPskAccessData, 32) // AppDataKeyExtension 长链接MMTLS密钥 tmpExpandHashData = append(tmpExpandHashData, packItems[2].PackData[0:]...) tmpLongHashValue := Sha256(tmpExpandHashData) expandedSecret := append([]byte("expanded secret"), tmpLongHashValue[0:]...) retExpandSecret := HkdfExpand(secretKey, expandedSecret, 32) appDataKeyData := append([]byte("application data key expansion"), tmpLongHashValue[0:]...) appDataKeyExtension := HkdfExpand(retExpandSecret, appDataKeyData, 56) mmInfo.LongHdkfKey = &HkdfKey56{} mmInfo.LongHdkfKey.EncodeAesKey = appDataKeyExtension[0x00:0x10] mmInfo.LongHdkfKey.DecodeAesKey = appDataKeyExtension[0x10:0x20] mmInfo.LongHdkfKey.EncodeNonce = appDataKeyExtension[0x20:0x2c] mmInfo.LongHdkfKey.DecodeNonce = appDataKeyExtension[0x2c:] // 返回ClientFinishData return retClientFinishData, nil } // DealServerHello 处理ServerHello func DealServerHello(mmInfo *MMInfo, packItem *PackItem) ([]byte, error) { // 解析ServerHello serverHello, err := ServerHelloDeSerialize(packItem.PackData[4:]) if err != nil { return []byte{}, err } // 解析ServerKeyShare serverKeyShareExtension, err := ServerKeyShareExtensionDeSerialize(serverHello.ExtensionList[0].ExtensionData) if err != nil { return []byte{}, err } // 解析ServerPublicKey ecdhTool := ecdh.NewEllipticECDH(elliptic.P256()) serverPubKey, isOk := ecdhTool.Unmarshal(serverKeyShareExtension.PublicValue) if !isOk { return []byte{}, errors.New("DecodePackItems ecdhTool.Unmarshal(serverKeyShareExtension.PublicValue) failed") } // 根据NameGroup 决定使用哪个Privakey ecdhPriKey := mmInfo.ClientEcdhKeys.PriKey1 if serverKeyShareExtension.KeyOfferNameGroup == 2 { ecdhPriKey = mmInfo.ClientEcdhKeys.PriKey2 } // 协商密钥 secretKey, err := ecdhTool.GenerateSharedSecret(ecdhPriKey, serverPubKey) //服务器公钥和本地第一个私钥协商出安全密钥 if err != nil { return []byte{}, err } return Sha256(secretKey), nil } // DealCertificateVerify 处理CertificateVerify数据: 校验服务器-判断是不是微信服务器,请求返回数据有没有被串改 func DealCertificateVerify(clientHelloData []byte, serverHelloData []byte, data []byte) (bool, error) { // 解析数据 totalSize := baseutils.BytesToInt32(data[0:4]) certificateVerify, err := CertificateVerifyDeSerialize(data[4 : 4+totalSize]) if err != nil { return false, err } // 合并请求数据 message := make([]byte, 0) message = append(message, clientHelloData[0:]...) message = append(message, serverHelloData[0:]...) message = Sha256(message) // 校验数据 flag, err := ECDSAVerifyData(message, certificateVerify.Signature) if err != nil { return false, err } return flag, nil } // DealNewSessionTicket 处理NewSessionTicket数据 func DealNewSessionTicket(mmInfo *MMInfo, data []byte) error { // 解析数据 totalSize := baseutils.BytesToInt32(data[0:4]) newSessionTicket, err := NewSessionTicketDeSerialize(data[4 : 4+totalSize]) if err != nil { return err } mmInfo.ShortPskList = newSessionTicket.PskList return nil } // ----------- 上面是握手阶段 ----------- // ----------- 接下来是发送请求 ----------- // EncryptedReqData EncryptedReqData func EncryptedReqData(aesGcmParam *AesGcmParam, data []byte, recordHeadType byte, clientSeq uint32) ([]byte, error) { tmpNonce := GetNonce(aesGcmParam.Nonce, clientSeq) tmpHead := GetRecordDataByLength(recordHeadType, uint16(len(data)+0x10)) tmpAad := []byte{0x00, 0x00, 0x00, 0x00} tmpAad = append(tmpAad, baseutils.Int32ToBytes(clientSeq)...) tmpAad = append(tmpAad, tmpHead[0:]...) encodeData, err := AesGcmEncrypt(aesGcmParam.AesKey, tmpNonce, tmpAad, data) if err != nil { return []byte{}, err } return encodeData, nil } // DecryptedRecvData 解析响应数据包 func DecryptedRecvData(aesGcmParam *AesGcmParam, recvItem *PackItem, serverSeq uint32) ([]byte, error) { tmpNonce := GetNonce(aesGcmParam.Nonce, serverSeq) tmpAad := []byte{0x00, 0x00, 0x00, 0x00} tmpAad = append(tmpAad, baseutils.Int32ToBytes(serverSeq)...) tmpAad = append(tmpAad, recvItem.RecordHead[0:]...) encodeData, err := AesGcmDecrypt(aesGcmParam.AesKey, tmpNonce, tmpAad, recvItem.PackData) if err != nil { return []byte{}, err } return encodeData, nil } // CreateSendPackItems 创建发送的请求项列表 func CreateSendPackItems(mmInfo *MMInfo, httpHandler *HTTPHandler) ([]*PackItem, error) { retItems := make([]*PackItem, 0) // ClientHelloItem clientHelloItem := &PackItem{} clientHello, err := CreateClientHelloData(mmInfo) if err != nil { return nil, err } clientHelloData := ClientHelloSerialize(clientHello) clientHelloItem.RecordHead = GetRecordDataByLength(ClientHandShakeType, uint16(len(clientHelloData))) clientHelloItem.PackData = clientHelloData // EncryptedExtensionsItem encryptedExtensionsItem := &PackItem{} encryptedExtensions := CreateEncryptedExtensions() encryptedExtensionsData := EncryptedExtensionsSerialize(encryptedExtensions) encryptedExtensionsItem.RecordHead = GetRecordDataByLength(ClientHandShakeType, uint16(len(encryptedExtensionsData))) encryptedExtensionsItem.PackData = encryptedExtensionsData // HTTPHandlerItem httpHandlerItem := &PackItem{} httpHandlerData := HTTPHandlerSerialize(httpHandler) httpHandlerItem.RecordHead = GetRecordDataByLength(BodyType, uint16(len(httpHandlerData))) httpHandlerItem.PackData = httpHandlerData // AlertItem alertItem := &PackItem{} alertData := GetAlertData() alertItem.RecordHead = GetRecordDataByLength(AlertType, uint16(len(alertData))) alertItem.PackData = alertData // 返回数据 retItems = append(retItems, clientHelloItem) retItems = append(retItems, encryptedExtensionsItem) retItems = append(retItems, httpHandlerItem) retItems = append(retItems, alertItem) return retItems, nil } // MMHTTPPackData MMPackData func MMHTTPPackData(mmInfo *MMInfo, items []*PackItem) ([]byte, error) { // 密钥扩展 sha256Value := Sha256(items[0].PackData) expandSecretData := []byte("early data key expansion") expandSecretData = append(expandSecretData, sha256Value[0:]...) tmpHkdfValue := HkdfExpand(mmInfo.PskAccessKey, expandSecretData, 28) aesGcmParam := &AesGcmParam{} aesGcmParam.AesKey = tmpHkdfValue[0x00:0x10] aesGcmParam.Nonce = tmpHkdfValue[0x10:0x1c] // 加密EncryptedExtensions encryptData, err := EncryptedReqData(aesGcmParam, items[1].PackData, ClientHandShakeType, 1) if err != nil { return []byte{}, err } partData2 := CreateRecordData(ClientHandShakeType, encryptData) // 加密HTTPHandler httpHandlerEncryptData, err := EncryptedReqData(aesGcmParam, items[2].PackData, BodyType, 2) if err != nil { return []byte{}, err } partData3 := CreateRecordData(BodyType, httpHandlerEncryptData) // 加密Alert alertDataEncryptData, err := EncryptedReqData(aesGcmParam, items[3].PackData, AlertType, 3) if err != nil { return []byte{}, err } partData4 := CreateRecordData(AlertType, alertDataEncryptData) // 返回数据 retData := make([]byte, 0) retData = append(retData, items[0].RecordHead[0:]...) retData = append(retData, items[0].PackData[0:]...) retData = append(retData, partData2[0:]...) retData = append(retData, partData3[0:]...) retData = append(retData, partData4[0:]...) return retData, err } // MMDecodeResponseData 解码响应数据 // func MMDecodeResponseData(mmInfo *MMInfo, sendItems []*PackItem, respData []byte) ([]byte, error) { // retData := make([]byte, 0) // // 解析 对响应数据进行分包 // recvItems, err := ParserMMtlsResponseData(respData) // if err != nil { // return retData, err // } // if len(recvItems) < 4 { // return retData, errors.New("MMDecodeResponseData err: recvItems Length < 4") // } // // 密钥扩展 用于后面的解密 // shaData := make([]byte, 0) // shaData = append(shaData, sendItems[0].PackData[0:]...) // shaData = append(shaData, sendItems[1].PackData[0:]...) // shaData = append(shaData, recvItems[0].PackData[0:]...) // sha256Value := Sha256(shaData) // expandSecretData := []byte("handshake key expansion") // expandSecretData = append(expandSecretData, sha256Value[0:]...) // tmpHkdfValue := HkdfExpand(mmInfo.PskAccessKey, expandSecretData, 28) // aesGcmParam := &AesGcmParam{} // aesGcmParam.AesKey = tmpHkdfValue[0x00:0x10] // aesGcmParam.Nonce = tmpHkdfValue[0x10:0x1c] // // 解密剩下的包 // count := len(recvItems) // for index := 1; index < count; index++ { // // 解密Finished数据包 // decodeData, err := DecryptedRecvData(aesGcmParam, recvItems[index], uint32(index)) // if err != nil { // return retData, err // } // // RecordHeadType // recordHeadType := recvItems[index].RecordHead[0] // // ServerHandShakeType 校验收到的数据是否完整,是否又被串改 // if recordHeadType == ServerHandShakeType { // // 判断数据长度是否正常 // current := 0 // totalLength := int(baseutils.BytesToInt32(decodeData[current : current+4])) // current = current + 4 // if totalLength < 0 { // return retData, errors.New("MMDecodeResponseData err: totalLength < 0") // } // // ReceiveSubType // subType := decodeData[current] // // FinishedType 校验数据是否正常 // if subType == FinishedType { // // 反序列化 // finished, err := FinishedDeSerialize(decodeData[current : current+totalLength]) // if err != nil { // return retData, err // } // bSuccess := VerifyFinishedData(mmInfo.PskAccessKey, sha256Value, finished.VerifyData) // if !bSuccess { // return retData, errors.New("MMDecodeResponseData err: VerifyFinishedData failed") // } // } // } // // 解析响应数据 // if recordHeadType == BodyType { // retData = append(retData, decodeData[0:]...) // } // // 解析AlertType // if recordHeadType == AlertType { // // 关闭连接的数据包(对长链接有用) 固定为0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x01 // } // } // return retData, nil // } // MMDecodeResponseData 解码响应数据 func MMDecodeResponseData(mmInfo *MMInfo, sendItems []*PackItem, respData []byte) ([]byte, error) { var retData []byte defer func() { if r := recover(); r != nil { fmt.Printf("Recovered from panic: %v\n", r) // 这里可以记录日志或者执行其他的恢复操作 return } }() // 解析响应数据 recvItems, err := ParserMMtlsResponseData(respData) if err != nil { return retData, fmt.Errorf("failed to parse response data: %w", err) } if len(recvItems) < 4 { return retData, errors.New("MMDecodeResponseData err: recvItems Length < 4") } if len(sendItems) < 2 || sendItems[0] == nil || sendItems[1] == nil { return retData, errors.New("MMDecodeResponseData err: sendItems Length < 2 or contains nil") } // 密钥扩展 用于后面的解密, 确保 sendItems 和 recvItems 有效 shaData := append(sendItems[0].PackData, sendItems[1].PackData...) shaData = append(shaData, recvItems[0].PackData...) sha256Value := Sha256(shaData) expandSecretData := append([]byte("handshake key expansion"), sha256Value...) tmpHkdfValue := HkdfExpand(mmInfo.PskAccessKey, expandSecretData, 28) aesGcmParam := &AesGcmParam{ AesKey: tmpHkdfValue[:0x10], Nonce: tmpHkdfValue[0x10:0x1c], } // 解密每个包 for index := 1; index < len(recvItems); index++ { decodeData, err := DecryptedRecvData(aesGcmParam, recvItems[index], uint32(index)) if err != nil { return retData, fmt.Errorf("error decrypting packet %d: %w", index, err) } recordHeadType := recvItems[index].RecordHead[0] if recordHeadType == ServerHandShakeType { if len(decodeData) < 4 { return retData, errors.New("MMDecodeResponseData err: decodeData Length < 4") } current := 4 totalLength := int(baseutils.BytesToInt32(decodeData[:current])) if totalLength < 0 || len(decodeData) < totalLength { return retData, errors.New("MMDecodeResponseData err: invalid total length") } subType := decodeData[current] if subType == FinishedType { finished, err := FinishedDeSerialize(decodeData[current : current+totalLength]) if err != nil { return retData, fmt.Errorf("failed to deserialize finished data: %w", err) } bSuccess := VerifyFinishedData(mmInfo.PskAccessKey, sha256Value, finished.VerifyData) if !bSuccess { return retData, errors.New("MMDecodeResponseData err: VerifyFinishedData failed") } } } if recordHeadType == BodyType { retData = append(retData, decodeData...) } if recordHeadType == AlertType { // 如果是告警数据包,可以在此处理 } } return retData, nil } // VerifyFinishedData 校验服务端返回数据是否正确 func VerifyFinishedData(aesKey []byte, shaValue []byte, finishedData []byte) bool { count := len(finishedData) message := []byte("server finished") tmpHkdfValue := HkdfExpand(aesKey, message, count) verifyData := HmacHash256(tmpHkdfValue, shaValue) // 比对结果是否一致 for index := 0; index < count; index++ { if verifyData[index] != finishedData[index] { return false } } return true }