// Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "shill/connection_health_checker.h"

#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <time.h>

#include <vector>

#include <base/bind.h>

#include "shill/async_connection.h"
#include "shill/connection.h"
#include "shill/dns_client.h"
#include "shill/dns_client_factory.h"
#include "shill/error.h"
#include "shill/http_url.h"
#include "shill/ip_address.h"
#include "shill/ip_address_store.h"
#include "shill/logging.h"
#include "shill/sockets.h"
#include "shill/socket_info.h"
#include "shill/socket_info_reader.h"

using base::Bind;
using base::Unretained;
using std::string;
using std::vector;

namespace shill {

// static
// Addresses the constructor seeds into the remote IP pool (remote_ips_).
const char *ConnectionHealthChecker::kDefaultRemoteIPPool[] = {
    "74.125.224.47",
    "74.125.224.79",
    "74.125.224.111",
    "74.125.224.143"
};
// static
// Timeout handed to every DNS client created by AddRemoteURL().
const int ConnectionHealthChecker::kDNSTimeoutMilliseconds = 5000;
// static
// Sentinel stored in sock_fd_ while no probe socket is held.
const int ConnectionHealthChecker::kInvalidSocket = -1;
// static
// Failed connection attempts after which a check ends with
// kResultConnectionFailure.
const int ConnectionHealthChecker::kMaxFailedConnectionAttempts = 2;
// static
// Extra VerifySentData() polls made while the transmit queue looks stalled,
// before the sample is counted as a congested queue.
const int ConnectionHealthChecker::kMaxSentDataPollingAttempts = 2;
// static
// Congested-queue samples after which a check ends with
// kResultCongestedTxQueue.
const int ConnectionHealthChecker::kMinCongestedQueueAttempts = 2;
// static
// Successful sends after which a check ends with kResultSuccess.
const int ConnectionHealthChecker::kMinSuccessfulSendAttempts = 1;
// static
// Number of DNS clients started for each URL passed to AddRemoteURL().
const int ConnectionHealthChecker::kNumDNSQueries = 5;
// static
// Default delay between sending probe data and inspecting the socket state
// (initial value of tcp_state_update_wait_milliseconds_).
const int ConnectionHealthChecker::kTCPStateUpdateWaitMilliseconds = 5000;
// static
// The only remote port probed; AddRemoteURL() rejects URLs on other ports.
const uint16 ConnectionHealthChecker::kRemotePort = 80;

// Builds a checker for |connection|. Results are delivered through
// |result_callback|. The addresses in kDefaultRemoteIPPool are added to
// |remote_ips|, which is used as the pool of probe targets.
// NOTE(review): |connection| must be non-NULL here, because
// connection_->interface_name() is called unconditionally below, even though
// Start() tolerates a NULL connection_.
ConnectionHealthChecker::ConnectionHealthChecker(
    ConnectionRefPtr connection,
    EventDispatcher *dispatcher,
    IPAddressStore *remote_ips,
    const base::Callback<void(Result)> &result_callback)
    : connection_(connection),
      dispatcher_(dispatcher),
      remote_ips_(remote_ips),
      result_callback_(result_callback),
      socket_(new Sockets()),
      weak_ptr_factory_(this),
      connection_complete_callback_(
          Bind(&ConnectionHealthChecker::OnConnectionComplete,
               weak_ptr_factory_.GetWeakPtr())),
      tcp_connection_(new AsyncConnection(connection_->interface_name(),
                                          dispatcher_,
                                          socket_.get(),
                                          connection_complete_callback_)),
      report_result_(
          Bind(&ConnectionHealthChecker::ReportResult,
               weak_ptr_factory_.GetWeakPtr())),
      sock_fd_(kInvalidSocket),
      socket_info_reader_(new SocketInfoReader()),
      dns_client_factory_(DNSClientFactory::GetInstance()),
      dns_client_callback_(Bind(&ConnectionHealthChecker::GetDNSResult,
                                weak_ptr_factory_.GetWeakPtr())),
      health_check_in_progress_(false),
      num_connection_failures_(0),
      num_congested_queue_detected_(0),
      num_successful_sends_(0),
      tcp_state_update_wait_milliseconds_(kTCPStateUpdateWaitMilliseconds) {
  // These members are missing from the initializer list. Give them defined
  // values instead of indeterminate ones. They are assigned here rather than
  // in the list so this does not depend on declaration order in the header.
  num_tx_queue_polling_attempts_ = 0;
  old_transmit_queue_value_ = 0;
  health_check_result_ = kResultUnknown;

  for (size_t i = 0; i < arraysize(kDefaultRemoteIPPool); ++i) {
    const char *ip_string = kDefaultRemoteIPPool[i];
    IPAddress ip(IPAddress::kFamilyIPv4);
    if (!ip.SetAddressFromString(ip_string)) {
      // Never put an unparsed address into the probe pool.
      SLOG(Connection, 2) << __func__ << ": Invalid default remote IP: "
                          << ip_string;
      continue;
    }
    remote_ips_->AddUnique(ip);
  }
}

// Tears down any running check. Stop() cancels the pending VerifySentData()
// task, which is bound with Unretained(this) and so must not outlive us. It
// also closes the probe socket.
ConnectionHealthChecker::~ConnectionHealthChecker() {
  this->Stop();
}

bool ConnectionHealthChecker::health_check_in_progress() const {
 return health_check_in_progress_;
}

// Adds |ip| to the pool of probe targets via IPAddressStore::AddUnique().
void ConnectionHealthChecker::AddRemoteIP(IPAddress ip) {
  this->remote_ips_->AddUnique(ip);
}

void ConnectionHealthChecker::AddRemoteURL(const string &url_string) {
  GarbageCollectDNSClients();

  HTTPURL url;
  if (!url.ParseFromString(url_string)) {
    SLOG(Connection, 2) << __func__ << ": Malformed url: " << url_string << ".";
    return;
  }
  if (url.port() != kRemotePort) {
    SLOG(Connection, 2) << __func__ << ": Remote connections only supported "
                        << " to port 80, requested " << url.port() << ".";
    return;
  }
  for (int i = 0; i < kNumDNSQueries; ++i) {
    Error error;
    DNSClient *dns_client =
      dns_client_factory_->CreateDNSClient(IPAddress::kFamilyIPv4,
                                           connection_->interface_name(),
                                           connection_->dns_servers(),
                                           kDNSTimeoutMilliseconds,
                                           dispatcher_,
                                           dns_client_callback_);
    dns_clients_.push_back(dns_client);
    if (!dns_clients_[i]->Start(url.host(), &error)) {
      SLOG(Connection, 2) << __func__ << ": Failed to start DNS client "
                          << "(query #" << i << "): "
                          << error.message();
    }
  }
}

void ConnectionHealthChecker::Start() {
  if (health_check_in_progress_) {
    SLOG(Connection, 2) << __func__ << ": Health Check already in progress.";
    return;
  }
  if (!connection_.get()) {
    SLOG(Connection, 2) << __func__ << ": Connection not ready yet.";
    result_callback_.Run(kResultUnknown);
    return;
  }

  health_check_in_progress_ = true;
  num_connection_failures_ = 0;
  num_congested_queue_detected_ = 0;
  num_successful_sends_ = 0;

  if (remote_ips_->Empty()) {
    // Nothing to try.
    Stop();
    SLOG(Connection, 2) << __func__ << ": Not enough IPs.";
    result_callback_.Run(kResultUnknown);
    return;
  }

  // Initiate the first attempt.
  NextHealthCheckSample();
}

// Aborts the current sample and resets all per-run state. It stops the
// AsyncConnection attempt (if a tcp_connection_ exists), cancels any
// scheduled VerifySentData() and closes the probe socket.
// NOTE(review): an already-posted report_result_ task is not cancelled here.
void ConnectionHealthChecker::Stop() {
  if (tcp_connection_.get())
    tcp_connection_->Stop();
  verify_sent_data_callback_.Cancel();
  ClearSocketDescriptor();

  // Per-run bookkeeping back to its idle state.
  num_tx_queue_polling_attempts_ = 0;
  num_successful_sends_ = 0;
  num_congested_queue_detected_ = 0;
  num_connection_failures_ = 0;
  health_check_in_progress_ = false;
}

// Rebinds the checker to |connection|. Pending DNS clients are dropped and
// the TCP prober is rebuilt for the new interface. A check that was running
// is restarted on the new connection. A NULL |connection| leaves no TCP
// prober; a restarted check then reports kResultUnknown from Start().
void ConnectionHealthChecker::SetConnection(ConnectionRefPtr connection) {
  SLOG(Connection, 3) << __func__;
  // Stop() clears health_check_in_progress_, so remember it first.
  bool restart = health_check_in_progress();
  // Tear down the in-flight attempt (and its socket) while tcp_connection_
  // still refers to the AsyncConnection that owns it.
  Stop();
  dns_clients_.clear();

  connection_ = connection;
  if (connection_.get()) {
    tcp_connection_.reset(new AsyncConnection(connection_->interface_name(),
                                              dispatcher_,
                                              socket_.get(),
                                              connection_complete_callback_));
  } else {
    // Stop() already tolerates a NULL tcp_connection_, and Start() returns
    // before using it when connection_ is NULL.
    tcp_connection_.reset();
  }

  if (restart)
    Start();
}

// Returns a human-readable name for |result|, or "Invalid" for values outside
// the Result enumeration.
const char *ConnectionHealthChecker::ResultToString(
    ConnectionHealthChecker::Result result) {
  const char *name = "Invalid";
  switch (result) {
    case kResultUnknown:
      name = "Unknown";
      break;
    case kResultConnectionFailure:
      name = "ConnectionFailure";
      break;
    case kResultCongestedTxQueue:
      name = "CongestedTxQueue";
      break;
    case kResultSuccess:
      name = "Success";
      break;
    default:
      break;
  }
  return name;
}

// DNS client completion handler (bound as dns_client_callback_). On success
// the resolved |ip| joins the probe pool; on failure the error is logged and
// the result discarded.
void ConnectionHealthChecker::GetDNSResult(const Error &error,
                                           const IPAddress& ip) {
  if (!error.IsSuccess()) {
    // ": " separator matches every other log line in this file.
    SLOG(Connection, 2) << __func__ << ": DNSClient returned failure: "
                        << error.message();
    return;
  }
  remote_ips_->AddUnique(ip);
}

void ConnectionHealthChecker::GarbageCollectDNSClients() {
  ScopedVector<DNSClient> keep;
  ScopedVector<DNSClient> discard;
  for (size_t i = 0; i < dns_clients_.size(); ++i) {
    if (dns_clients_[i]->IsActive())
      keep.push_back(dns_clients_[i]);
    else
      discard.push_back(dns_clients_[i]);
  }
  dns_clients_.weak_clear();
  dns_clients_ = keep.Pass();  // Passes ownership of contents.
  discard.clear();
}

void ConnectionHealthChecker::NextHealthCheckSample() {
  // Finish conditions:
  if (num_connection_failures_ == kMaxFailedConnectionAttempts) {
    health_check_result_ = kResultConnectionFailure;
    dispatcher_->PostTask(report_result_);
    return;
  }
  if (num_congested_queue_detected_ == kMinCongestedQueueAttempts) {
    health_check_result_ = kResultCongestedTxQueue;
    dispatcher_->PostTask(report_result_);
    return;
  }
  if (num_successful_sends_ == kMinSuccessfulSendAttempts) {
    health_check_result_ = kResultSuccess;
    dispatcher_->PostTask(report_result_);
    return;
  }

  // Pick a random IP from the set of IPs.
  // This guards against
  //   (1) Repeated failed attempts for the same IP at start-up everytime.
  //   (2) All users attempting to connect to the same IP.
  IPAddress ip = remote_ips_->GetRandomIP();
  SLOG(Connection, 3) << __func__ << ": Starting connection at "
                      << ip.ToString();
  if (!tcp_connection_->Start(ip, kRemotePort)) {
    SLOG(Connection, 2) << __func__ << ": Connection attempt failed.";
    ++num_connection_failures_;
    NextHealthCheckSample();
  }
}

// AsyncConnection completion handler. On success it adopts |sock_fd| and
// confirms the socket is ESTABLISHED. It records the current transmit-queue
// value, sends one byte, and schedules VerifySentData() after
// tcp_state_update_wait_milliseconds_. Any failure along the way counts as a
// failed connection attempt and moves on to the next sample.
void ConnectionHealthChecker::OnConnectionComplete(bool success, int sock_fd) {
  if (!success) {
    SLOG(Connection, 2) << __func__
                        << ": AsyncConnection connection attempt failed "
                        << "with error: "
                        << tcp_connection_->error();
    ++num_connection_failures_;
    NextHealthCheckSample();
    return;
  }

  SetSocketDescriptor(sock_fd);

  SocketInfo sock_info;
  if (!GetSocketInfo(sock_fd_, &sock_info) ||
      sock_info.connection_state() !=
          SocketInfo::kConnectionStateEstablished) {
    SLOG(Connection, 2) << __func__
                        << ": Connection originally not in established state..";
    // Count this as a failed connection attempt.
    ++num_connection_failures_;
    ClearSocketDescriptor();
    NextHealthCheckSample();
    return;
  }

  // Baseline that VerifySentData() compares against to detect a stalled queue.
  old_transmit_queue_value_ = sock_info.transmit_queue_value();
  num_tx_queue_polling_attempts_ = 0;

  // Send data on the connection and post a delayed task to check successful
  // transfer. Only the byte count matters, but initialize the byte so no
  // indeterminate stack contents are read or transmitted.
  char buf = 0;
  if (socket_->Send(sock_fd_, &buf, sizeof(buf), 0) == -1) {
    SLOG(Connection, 2) << __func__ << ": " << socket_->ErrorString();
    // Count this as a failed connection attempt.
    ++num_connection_failures_;
    ClearSocketDescriptor();
    NextHealthCheckSample();
    return;
  }

  // Unretained is acceptable here: the cancelable wrapper is cancelled by
  // Stop(), which the destructor calls.
  verify_sent_data_callback_.Reset(
      Bind(&ConnectionHealthChecker::VerifySentData, Unretained(this)));
  dispatcher_->PostDelayedTask(verify_sent_data_callback_.callback(),
                               tcp_state_update_wait_milliseconds_);
}

void ConnectionHealthChecker::VerifySentData() {
  SocketInfo sock_info;
  bool sock_info_found = GetSocketInfo(sock_fd_, &sock_info);
  // Acceptable TCP connection states after sending the data:
  // kConnectionStateEstablished: No change in connection state since the send.
  // kConnectionStateCloseWait: The remote host recieved the sent data and
  //    requested connection close.
  if (!sock_info_found ||
      (sock_info.connection_state() !=
           SocketInfo::kConnectionStateEstablished &&
      sock_info.connection_state() !=
           SocketInfo::kConnectionStateCloseWait)) {
    SLOG(Connection, 2) << __func__
                        << ": Connection not in acceptable state after send.";
    if (sock_info_found)
      SLOG(Connection, 3) << "Found socket info but in state: "
                          << sock_info.connection_state();
    ++num_connection_failures_;
  } else if (sock_info.transmit_queue_value() > old_transmit_queue_value_ &&
      sock_info.timer_state() ==
          SocketInfo::kTimerStateRetransmitTimerPending) {
    if (num_tx_queue_polling_attempts_ < kMaxSentDataPollingAttempts) {
      SLOG(Connection, 2) << __func__
                          << ": Polling again.";
      ++num_tx_queue_polling_attempts_;
      verify_sent_data_callback_.Reset(
          Bind(&ConnectionHealthChecker::VerifySentData, Unretained(this)));
      dispatcher_->PostDelayedTask(verify_sent_data_callback_.callback(),
                                   tcp_state_update_wait_milliseconds_);
      return;
    }
    SLOG(Connection, 2) << __func__ << ": Sampled congested Tx-Queue";
    ++num_congested_queue_detected_;
  } else {
    SLOG(Connection, 2) << __func__ << ": Sampled successful send.";
    ++num_successful_sends_;
  }
  ClearSocketDescriptor();
  NextHealthCheckSample();
}

// TODO(pprabhu): Scrub IP address logging.
// Finds the kernel TCP socket entry for |sock_fd|. It matches the socket's
// local IPv4 address and port (from GetSockName) against the entries that
// socket_info_reader_->LoadTcpSocketInfo() returns. On a match the entry is
// copied to |sock_info| and true is returned. On any failure false is
// returned and |sock_info| is left untouched. Only AF_INET sockets are
// supported.
bool ConnectionHealthChecker::GetSocketInfo(int sock_fd,
                                            SocketInfo *sock_info) {
  struct sockaddr_storage addr;
  socklen_t addrlen = sizeof(addr);
  memset(&addr, 0, sizeof(addr));
  if (socket_->GetSockName(sock_fd,
                           reinterpret_cast<struct sockaddr *>(&addr),
                           &addrlen) != 0) {
    SLOG(Connection, 2) << __func__
                        << ": Failed to get address of created socket.";
    return false;
  }
  if (addr.ss_family != AF_INET) {
    SLOG(Connection, 2) << __func__ << ": IPv6 socket address found.";
    return false;
  }
  // An unexpected length is a lookup failure like the ones above, not a
  // reason to abort the whole daemon with a CHECK.
  if (addrlen != sizeof(struct sockaddr_in)) {
    SLOG(Connection, 2) << __func__ << ": Unexpected socket address length "
                        << addrlen << ".";
    return false;
  }

  struct sockaddr_in *addr_in = reinterpret_cast<sockaddr_in *>(&addr);
  uint16 local_port = ntohs(addr_in->sin_port);
  char ipstr[INET_ADDRSTRLEN];
  const char *res = inet_ntop(AF_INET, &addr_in->sin_addr,
                              ipstr, sizeof(ipstr));
  if (res == NULL) {
    SLOG(Connection, 2) << __func__
                        << ": Could not convert IP address to string.";
    return false;
  }

  IPAddress local_ip_address(IPAddress::kFamilyIPv4);
  if (!local_ip_address.SetAddressFromString(ipstr)) {
    SLOG(Connection, 2) << __func__
                        << ": Could not parse local IP address string.";
    return false;
  }
  SLOG(Connection, 3) << "Local IP = " << local_ip_address.ToString()
                      << ":" << local_port;

  vector<SocketInfo> info_list;
  if (!socket_info_reader_->LoadTcpSocketInfo(&info_list)) {
    SLOG(Connection, 2) << __func__ << ": Failed to load TCP socket info.";
    return false;
  }

  for (vector<SocketInfo>::const_iterator info_list_it = info_list.begin();
       info_list_it != info_list.end();
       ++info_list_it) {
    const SocketInfo &cur_sock_info = *info_list_it;

    SLOG(Connection, 4)
        << "Testing against IP = "
        << cur_sock_info.local_ip_address().ToString()
        << ":" << cur_sock_info.local_port()
        << " (addresses equal:"
        << cur_sock_info.local_ip_address().Equals(local_ip_address)
        << ", ports equal:" << (cur_sock_info.local_port() == local_port)
        << ")";

    if (cur_sock_info.local_ip_address().Equals(local_ip_address) &&
        cur_sock_info.local_port() == local_port) {
      SLOG(Connection, 3) << __func__ << ": Found matching TCP socket info.";
      *sock_info = cur_sock_info;
      return true;
    }
  }

  SLOG(Connection, 2) << __func__ << ": No matching TCP socket info.";
  return false;
}

void ConnectionHealthChecker::ReportResult() {
  SLOG(Connection, 2) << __func__ << ": Result: "
                     << ResultToString(health_check_result_);
  Stop();
  result_callback_.Run(health_check_result_);
}

// Makes |sock_fd| the current probe socket, first closing any descriptor
// already held. Passing kInvalidSocket just releases the current one.
void ConnectionHealthChecker::SetSocketDescriptor(int sock_fd) {
  const bool holding_socket = (sock_fd_ != kInvalidSocket);
  if (holding_socket) {
    SLOG(Connection, 4) << "Closing socket";
    socket_->Close(sock_fd_);
  }
  sock_fd_ = sock_fd;
}

// Closes the held probe socket (if any) and marks sock_fd_ as invalid.
void ConnectionHealthChecker::ClearSocketDescriptor() {
  const int no_socket = kInvalidSocket;
  SetSocketDescriptor(no_socket);
}

}  // namespace shill
