#include <webrtc/RTPSocketHandler.h>

#include <webrtc/MyWebSocketHandler.h>
#include <webrtc/STUNMessage.h>
#include <Utils.h>

#include <https/PlainSocket.h>
#include <https/SafeCallbackable.h>
#include <https/Support.h>
#include <android-base/logging.h>

#include <netdb.h>
#include <netinet/in.h>

#include <cstring>
#include <iostream>
#include <set>

#include <gflags/gflags.h>

DECLARE_string(public_ip);

static socklen_t getSockAddrLen(const sockaddr_storage &addr) {
    switch (addr.ss_family) {
        case AF_INET:
            return sizeof(sockaddr_in);
        case AF_INET6:
            return sizeof(sockaddr_in6);
        default:
            CHECK(!"Should not be here.");
            return 0;
    }
}

RTPSocketHandler::RTPSocketHandler(
        std::shared_ptr<RunLoop> runLoop,
        std::shared_ptr<ServerState> serverState,
        int domain,
        uint16_t port,
        uint32_t trackMask,
        std::shared_ptr<RTPSession> session)
    : mRunLoop(runLoop),
      mServerState(serverState),
      mLocalPort(port),
      mTrackMask(trackMask),
      mSession(session),
      mSendPending(false),
      mDTLSConnected(false) {
    int sock = socket(domain, SOCK_DGRAM, 0);

    makeFdNonblocking(sock);
    mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);

    sockaddr_storage addr;

    if (domain == PF_INET) {
        sockaddr_in addrV4;
        memset(addrV4.sin_zero, 0, sizeof(addrV4.sin_zero));
        addrV4.sin_family = AF_INET;
        addrV4.sin_port = htons(port);
        addrV4.sin_addr.s_addr = INADDR_ANY;
        memcpy(&addr, &addrV4, sizeof(addrV4));
    } else {
        CHECK_EQ(domain, PF_INET6);

        sockaddr_in6 addrV6;
        addrV6.sin6_family = AF_INET6;
        addrV6.sin6_port = htons(port);
        addrV6.sin6_addr = in6addr_any;
        addrV6.sin6_scope_id = 0;
        memcpy(&addr, &addrV6, sizeof(addrV6));
    }

    int res = bind(
            sock,
            reinterpret_cast<const sockaddr *>(&addr),
            getSockAddrLen(addr));

    CHECK(!res);

    auto videoPacketizer =
        (trackMask & TRACK_VIDEO)
            ? mServerState->getVideoPacketizer() : nullptr;

    auto audioPacketizer =
        (trackMask & TRACK_AUDIO)
            ? mServerState->getAudioPacketizer() : nullptr;

    mRTPSender = std::make_shared<RTPSender>(
            mRunLoop,
            this,
            videoPacketizer,
            audioPacketizer);

    if (trackMask & TRACK_VIDEO) {
        mRTPSender->addSource(0xdeadbeef);
        mRTPSender->addSource(0xcafeb0b0);

        mRTPSender->addRetransInfo(0xdeadbeef, 96, 0xcafeb0b0, 97);

        videoPacketizer->addSender(mRTPSender);
    }

    if (trackMask & TRACK_AUDIO) {
        mRTPSender->addSource(0x8badf00d);

        audioPacketizer->addSender(mRTPSender);
    }
}

uint16_t RTPSocketHandler::getLocalPort() const {
    return mLocalPort;
}

std::string RTPSocketHandler::getLocalUFrag() const {
    return mSession->localUFrag();
}

std::string RTPSocketHandler::getLocalIPString() const {
    return FLAGS_public_ip;
}

void RTPSocketHandler::run() {
    mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
}

void RTPSocketHandler::onReceive() {
    std::vector<uint8_t> buffer(kMaxUDPPayloadSize);

    uint8_t *data = buffer.data();

    sockaddr_storage addr;
    socklen_t addrLen = sizeof(addr);

    auto n = mSocket->recvfrom(
            data, buffer.size(), reinterpret_cast<sockaddr *>(&addr), &addrLen);

    STUNMessage msg(data, n);
    if (!msg.isValid()) {
        if (mDTLSConnected) {
            int err = -EINVAL;
            if (mRTPSender) {
                err = onSRTPReceive(data, static_cast<size_t>(n));
            }

            if (err == -EINVAL) {
                LOG(VERBOSE) << "Sending to DTLS instead:";
                // hexdump(data, n);

                onDTLSReceive(data, static_cast<size_t>(n));

                if (mTrackMask & TRACK_DATA) {
                    ssize_t n;

                    do {
                        uint8_t buf[kMaxUDPPayloadSize];
                        n = mDTLS->readApplicationData(buf, sizeof(buf));

                        if (n > 0) {
                            auto err = mSCTPHandler->inject(
                                    buf, static_cast<size_t>(n));

                            if (err) {
                                LOG(WARNING)
                                    << "SCTPHandler::inject returned error "
                                    << err;
                            }
                        }
                    } while (n > 0);
                }
            }
        } else {
            onDTLSReceive(data, static_cast<size_t>(n));
        }

        run();
        return;
    }

    if (msg.type() == 0x0001 /* Binding Request */) {
        STUNMessage response(0x0101 /* Binding Response */, msg.data() + 8);

        if (!matchesSession(msg)) {
            LOG(WARNING) << "Unknown session or no USERNAME.";
            run();
            return;
        }

        const auto &answerPassword = mSession->localPassword();

        // msg.dump(answerPassword);

        if (addr.ss_family == AF_INET) {
            uint8_t attr[8];
            attr[0] = 0x00;

            sockaddr_in addrV4;
            CHECK_EQ(addrLen, sizeof(addrV4));

            memcpy(&addrV4, &addr, addrLen);

            attr[1] = 0x01;  // IPv4

            static constexpr uint32_t kMagicCookie = 0x2112a442;

            uint16_t portHost = ntohs(addrV4.sin_port);
            portHost ^= (kMagicCookie >> 16);

            uint32_t ipHost = ntohl(addrV4.sin_addr.s_addr);
            ipHost ^= kMagicCookie;

            attr[2] = portHost >> 8;
            attr[3] = portHost & 0xff;
            attr[4] = ipHost >> 24;
            attr[5] = (ipHost >> 16) & 0xff;
            attr[6] = (ipHost >> 8) & 0xff;
            attr[7] = ipHost & 0xff;

            response.addAttribute(
                    0x0020 /* XOR-MAPPED-ADDRESS */, attr, sizeof(attr));
        } else {
            uint8_t attr[20];
            attr[0] = 0x00;

            CHECK_EQ(addr.ss_family, AF_INET6);

            sockaddr_in6 addrV6;
            CHECK_EQ(addrLen, sizeof(addrV6));

            memcpy(&addrV6, &addr, addrLen);

            attr[1] = 0x02;  // IPv6

            static constexpr uint32_t kMagicCookie = 0x2112a442;

            uint16_t portHost = ntohs(addrV6.sin6_port);
            portHost ^= (kMagicCookie >> 16);

            attr[2] = portHost >> 8;
            attr[3] = portHost & 0xff;

            uint8_t ipHost[16];

            std::string out;

            for (size_t i = 0; i < 16; ++i) {
                ipHost[i] = addrV6.sin6_addr.s6_addr[15 - i];

                if (!out.empty()) {
                    out += ":";
                }
                out += StringPrintf("%02x", ipHost[i]);

                ipHost[i] ^= response.data()[4 + i];
            }

            // LOG(INFO) << "IP6 = " << out;

            for (size_t i = 0; i < 16; ++i) {
                attr[4 + i] = ipHost[15 - i];
            }

            response.addAttribute(
                    0x0020 /* XOR-MAPPED-ADDRESS */, attr, sizeof(attr));
        }

        response.addMessageIntegrityAttribute(answerPassword);
        response.addFingerprint();

        // response.dump(answerPassword);

        auto res =
            mSocket->sendto(
                    response.data(),
                    response.size(),
                    reinterpret_cast<const sockaddr *>(&addr),
                    addrLen);

        CHECK_GT(res, 0);
        CHECK_EQ(static_cast<size_t>(res), response.size());

        if (!mSession->isActive()) {
            mSession->setRemoteAddress(addr);

            mSession->setIsActive();

            mSession->schedulePing(
                    mRunLoop,
                    makeSafeCallback(
                        this, &RTPSocketHandler::pingRemote, mSession),
                    std::chrono::seconds(0));
        }

    } else {
        // msg.dump();

        if (msg.type() == 0x0101 && !mDTLS) {
            mDTLS = std::make_shared<DTLS>(
                    shared_from_this(),
                    DTLS::Mode::ACCEPT,
                    mSession->localCertificate(),
                    mSession->localKey(),
                    mSession->remoteFingerprint(),
                    (mTrackMask != TRACK_DATA) /* useSRTP */);

            mDTLS->connect(mSession->remoteAddress());
        }
    }

    run();
}

bool RTPSocketHandler::matchesSession(const STUNMessage &msg) const {
    const void *attrData;
    size_t attrSize;
    if (!msg.findAttribute(0x0006 /* USERNAME */, &attrData, &attrSize)) {
        return false;
    }

    std::string uFragPair(static_cast<const char *>(attrData), attrSize);
    auto colonPos = uFragPair.find(':');

    if (colonPos == std::string::npos) {
        return false;
    }

    std::string localUFrag(uFragPair, 0, colonPos);
    std::string remoteUFrag(uFragPair, colonPos + 1);

    if (mSession->localUFrag() != localUFrag
            || mSession->remoteUFrag() != remoteUFrag) {

        LOG(WARNING)
            << "Unable to find session localUFrag='"
            << localUFrag
            << "', remoteUFrag='"
            << remoteUFrag
            << "'";

        return false;
    }

    return true;
}

void RTPSocketHandler::pingRemote(std::shared_ptr<RTPSession> session) {
    std::vector<uint8_t> transactionID { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 };

    STUNMessage msg(
            0x0001 /* Binding Request */,
            transactionID.data());

    std::string uFragPair =
            session->remoteUFrag() + ":" + session->localUFrag();

    msg.addAttribute(
            0x0006 /* USERNAME */,
            uFragPair.c_str(),
            uFragPair.size());

    uint64_t tieBreaker = 0xdeadbeefcafeb0b0;  // XXX
    msg.addAttribute(
            0x802a /* ICE-CONTROLLING */,
            &tieBreaker,
            sizeof(tieBreaker));

    uint32_t priority = 0xdeadbeef;
    msg.addAttribute(
            0x0024 /* PRIORITY */, &priority, sizeof(priority));

    // We're the controlling agent and including the "USE-CANDIDATE" attribute
    // below nominates this candidate.
    msg.addAttribute(0x0025 /* USE_CANDIDATE */);

    msg.addMessageIntegrityAttribute(session->remotePassword());
    msg.addFingerprint();

    queueDatagram(session->remoteAddress(), msg.data(), msg.size());

    session->schedulePing(
            mRunLoop,
            makeSafeCallback(this, &RTPSocketHandler::pingRemote, session),
            std::chrono::seconds(1));
}

RTPSocketHandler::Datagram::Datagram(
        const sockaddr_storage &addr, const void *data, size_t size)
    : mData(size),
      mAddr(addr) {
    memcpy(mData.data(), data, size);
}

const void *RTPSocketHandler::Datagram::data() const {
    return mData.data();
}

size_t RTPSocketHandler::Datagram::size() const {
    return mData.size();
}

const sockaddr_storage &RTPSocketHandler::Datagram::remoteAddress() const {
    return mAddr;
}

void RTPSocketHandler::queueDatagram(
        const sockaddr_storage &addr, const void *data, size_t size) {
    auto datagram = std::make_shared<Datagram>(addr, data, size);

    CHECK_LE(size, RTPSocketHandler::kMaxUDPPayloadSize);

    mRunLoop->post(
            makeSafeCallback<RTPSocketHandler>(
                this,
                [datagram](RTPSocketHandler *me) {
                    me->mOutQueue.push_back(datagram);

                    if (!me->mSendPending) {
                        me->scheduleDrainOutQueue();
                    }
                }));
}

void RTPSocketHandler::scheduleDrainOutQueue() {
    CHECK(!mSendPending);

    mSendPending = true;
    mSocket->postSend(
            makeSafeCallback(
                this, &RTPSocketHandler::drainOutQueue));
}

void RTPSocketHandler::drainOutQueue() {
    mSendPending = false;

    CHECK(!mOutQueue.empty());

    do {
        auto datagram = mOutQueue.front();

        ssize_t n;
        do {
            const sockaddr_storage &remoteAddr = datagram->remoteAddress();

            n = mSocket->sendto(
                    datagram->data(),
                    datagram->size(),
                    reinterpret_cast<const sockaddr *>(&remoteAddr),
                    getSockAddrLen(remoteAddr));
        } while (n < 0 && errno == EINTR);

        if (n < 0) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                break;
            }

            CHECK(!"Should not be here");
        }

        mOutQueue.pop_front();

    } while (!mOutQueue.empty());

    if (!mOutQueue.empty()) {
        scheduleDrainOutQueue();
    }
}

void RTPSocketHandler::onDTLSReceive(const uint8_t *data, size_t size) {
    if (mDTLS) {
        mDTLS->inject(data, size);
    }
}

void RTPSocketHandler::notifyDTLSConnected() {
    LOG(INFO) << "TDLS says that it's now connected.";

    mDTLSConnected = true;

    if (mTrackMask & TRACK_DATA) {
        mSCTPHandler = std::make_shared<SCTPHandler>(mRunLoop, mDTLS);
        mSCTPHandler->run();
    }

    mRTPSender->run();
}

int RTPSocketHandler::onSRTPReceive(uint8_t *data, size_t size) {
    if (size < 2) {
        return -EINVAL;
    }

    auto version = data[0] >> 6;
    if (version != 2) {
        return -EINVAL;
    }

    auto outSize = mDTLS->unprotect(data, size, false /* isRTP */);

    auto err = mRTPSender->injectRTCP(data, outSize);
    if (err) {
        LOG(WARNING) << "RTPSender::injectRTCP returned " << err;
    }

    return err;
}

void RTPSocketHandler::queueRTCPDatagram(const void *data, size_t size) {
    if (!mDTLSConnected) {
        return;
    }

    std::vector<uint8_t> copy(size + SRTP_MAX_TRAILER_LEN);
    memcpy(copy.data(), data, size);

    auto outSize = mDTLS->protect(copy.data(), size, false /* isRTP */);
    CHECK_LE(outSize, copy.size());

    queueDatagram(mSession->remoteAddress(), copy.data(), outSize);
}

void RTPSocketHandler::queueRTPDatagram(const void *data, size_t size) {
    if (!mDTLSConnected) {
        return;
    }

    std::vector<uint8_t> copy(size + SRTP_MAX_TRAILER_LEN);
    memcpy(copy.data(), data, size);

    auto outSize = mDTLS->protect(copy.data(), size, true /* isRTP */);
    CHECK_LE(outSize, copy.size());

    queueDatagram(mSession->remoteAddress(), copy.data(), outSize);
}
