Paul Stewart | c2350ee | 2011-10-19 12:28:40 -0700 | [diff] [blame] | 1 | // Copyright (c) 2011 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #include "shill/dns_client.h" |
| 6 | |
| 7 | #include <arpa/inet.h> |
| 8 | #include <netdb.h> |
| 9 | #include <netinet/in.h> |
| 10 | #include <sys/socket.h> |
| 11 | |
| 12 | #include <map> |
| 13 | #include <set> |
| 14 | #include <string> |
| 15 | #include <tr1/memory> |
| 16 | #include <vector> |
| 17 | |
| 18 | #include <base/stl_util-inl.h> |
| 19 | |
| 20 | #include <shill/shill_ares.h> |
| 21 | #include <shill/shill_time.h> |
| 22 | |
| 23 | using std::map; |
| 24 | using std::set; |
| 25 | using std::string; |
| 26 | using std::vector; |
| 27 | |
| 28 | namespace shill { |
| 29 | |
| 30 | const int DNSClient::kDefaultTimeoutMS = 2000; |
| 31 | const char DNSClient::kErrorNoData[] = "The query response contains no answers"; |
| 32 | const char DNSClient::kErrorFormErr[] = "The server says the query is bad"; |
| 33 | const char DNSClient::kErrorServerFail[] = "The server says it had a failure"; |
| 34 | const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found"; |
| 35 | const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation"; |
| 36 | const char DNSClient::kErrorRefused[] = "The server replied, refused the query"; |
| 37 | const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query"; |
| 38 | const char DNSClient::kErrorNetRefused[] = "The network connection was refused"; |
| 39 | const char DNSClient::kErrorTimedOut[] = "The network connection was timed out"; |
| 40 | const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error"; |
| 41 | |
| 42 | // Private to the implementation of resolver so callers don't include ares.h |
| 43 | struct DNSClientState { |
| 44 | ares_channel channel; |
| 45 | map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > read_handlers; |
| 46 | map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > write_handlers; |
| 47 | struct timeval start_time_; |
| 48 | }; |
| 49 | |
| 50 | DNSClient::DNSClient(IPAddress::Family family, |
| 51 | const string &interface_name, |
| 52 | const vector<string> &dns_servers, |
| 53 | int timeout_ms, |
| 54 | EventDispatcher *dispatcher, |
| 55 | Callback1<bool>::Type *callback) |
| 56 | : address_(IPAddress(family)), |
| 57 | interface_name_(interface_name), |
| 58 | dns_servers_(dns_servers), |
| 59 | dispatcher_(dispatcher), |
| 60 | callback_(callback), |
| 61 | timeout_ms_(timeout_ms), |
| 62 | running_(false), |
| 63 | resolver_state_(NULL), |
| 64 | read_callback_(NewCallback(this, &DNSClient::HandleDNSRead)), |
| 65 | write_callback_(NewCallback(this, &DNSClient::HandleDNSWrite)), |
| 66 | task_factory_(this), |
| 67 | ares_(Ares::GetInstance()), |
| 68 | time_(Time::GetInstance()) {} |
| 69 | |
| 70 | DNSClient::~DNSClient() { |
| 71 | Stop(); |
| 72 | } |
| 73 | |
| 74 | bool DNSClient::Start(const string &hostname) { |
| 75 | if (running_) { |
| 76 | LOG(ERROR) << "Only one DNS request is allowed at a time"; |
| 77 | return false; |
| 78 | } |
| 79 | |
| 80 | if (!resolver_state_.get()) { |
| 81 | struct ares_options options; |
| 82 | memset(&options, 0, sizeof(options)); |
| 83 | |
| 84 | vector<struct in_addr> server_addresses; |
| 85 | for (vector<string>::iterator it = dns_servers_.begin(); |
| 86 | it != dns_servers_.end(); |
| 87 | ++it) { |
| 88 | struct in_addr addr; |
| 89 | if (inet_aton(it->c_str(), &addr) != 0) { |
| 90 | server_addresses.push_back(addr); |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | if (server_addresses.empty()) { |
| 95 | LOG(ERROR) << "No valid DNS server addresses"; |
| 96 | return false; |
| 97 | } |
| 98 | |
| 99 | options.servers = server_addresses.data(); |
| 100 | options.nservers = server_addresses.size(); |
| 101 | options.timeout = timeout_ms_; |
| 102 | |
| 103 | resolver_state_.reset(new DNSClientState); |
| 104 | int status = ares_->InitOptions(&resolver_state_->channel, |
| 105 | &options, |
| 106 | ARES_OPT_SERVERS | ARES_OPT_TIMEOUTMS); |
| 107 | if (status != ARES_SUCCESS) { |
| 108 | LOG(ERROR) << "ARES initialization returns error code: " << status; |
| 109 | resolver_state_.reset(); |
| 110 | return false; |
| 111 | } |
| 112 | |
| 113 | ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str()); |
| 114 | } |
| 115 | |
| 116 | running_ = true; |
| 117 | time_->GetTimeOfDay(&resolver_state_->start_time_, NULL); |
| 118 | error_.clear(); |
| 119 | ares_->GetHostByName(resolver_state_->channel, hostname.c_str(), |
| 120 | address_.family(), ReceiveDNSReplyCB, this); |
| 121 | |
| 122 | if (!RefreshHandles()) { |
| 123 | LOG(ERROR) << "Impossibly short timeout."; |
| 124 | Stop(); |
| 125 | return false; |
| 126 | } |
| 127 | |
| 128 | return true; |
| 129 | } |
| 130 | |
| 131 | void DNSClient::Stop() { |
| 132 | if (!resolver_state_.get()) { |
| 133 | return; |
| 134 | } |
| 135 | |
| 136 | running_ = false; |
| 137 | task_factory_.RevokeAll(); |
| 138 | ares_->Destroy(resolver_state_->channel); |
| 139 | resolver_state_.reset(); |
| 140 | } |
| 141 | |
| 142 | void DNSClient::HandleDNSRead(int fd) { |
| 143 | ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD); |
| 144 | RefreshHandles(); |
| 145 | } |
| 146 | |
| 147 | void DNSClient::HandleDNSWrite(int fd) { |
| 148 | ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd); |
| 149 | RefreshHandles(); |
| 150 | } |
| 151 | |
| 152 | void DNSClient::HandleTimeout() { |
| 153 | ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); |
| 154 | if (!RefreshHandles()) { |
| 155 | // If we have timed out, ARES might still have sockets open. |
| 156 | // Force them closed by doing an explicit shutdown. This is |
| 157 | // different from HandleDNSRead and HandleDNSWrite where any |
| 158 | // change in our running_ state would be as a result of ARES |
| 159 | // itself and therefore properly synchronized with it: if a |
| 160 | // search completes during the course of ares_->ProcessFd(), |
| 161 | // the ARES fds and other state is guaranteed to have cleaned |
| 162 | // up and ready for a new request. Since this timeout is |
| 163 | // genererated outside of the library it is best to completely |
| 164 | // shutdown ARES and start with fresh state for a new request. |
| 165 | Stop(); |
| 166 | } |
| 167 | } |
| 168 | |
// Completion handler for a request.  On a usable answer, stores the first
// returned address in address_ and runs the callback with true.  Otherwise
// records a description of the failure in error_ and runs the callback with
// false.  Events that arrive while no request is running are ignored.
void DNSClient::ReceiveDNSReply(int status, struct hostent *hostent) {
  // ARES may call back while it is being shut down; drop those events.
  if (!running_) {
    return;
  }
  running_ = false;

  // The answer is only usable if ARES reports success and the hostent holds
  // at least one address of the family and length we asked for.
  const bool have_usable_address =
      status == ARES_SUCCESS &&
      hostent != NULL &&
      hostent->h_addrtype == address_.family() &&
      hostent->h_length == IPAddress::GetAddressLength(address_.family()) &&
      hostent->h_addr_list != NULL &&
      hostent->h_addr_list[0] != NULL;

  if (have_usable_address) {
    unsigned char *raw_address =
        reinterpret_cast<unsigned char *>(hostent->h_addr_list[0]);
    address_ = IPAddress(address_.family(),
                         ByteString(raw_address, hostent->h_length));
    callback_->Run(true);
    return;
  }

  // Translate the ARES status into one of our error strings.
  switch (status) {
    case ARES_ENODATA:
      error_ = kErrorNoData;
      break;
    case ARES_EFORMERR:
      error_ = kErrorFormErr;
      break;
    case ARES_ESERVFAIL:
      error_ = kErrorServerFail;
      break;
    case ARES_ENOTFOUND:
      error_ = kErrorNotFound;
      break;
    case ARES_ENOTIMP:
      error_ = kErrorNotImp;
      break;
    case ARES_EREFUSED:
      error_ = kErrorRefused;
      break;
    // All of these are reported as a locally malformed query.
    case ARES_EBADQUERY:
    case ARES_EBADNAME:
    case ARES_EBADFAMILY:
    case ARES_EBADRESP:
      error_ = kErrorBadQuery;
      break;
    case ARES_ECONNREFUSED:
      error_ = kErrorNetRefused;
      break;
    case ARES_ETIMEOUT:
      error_ = kErrorTimedOut;
      break;
    default:
      // Reached for unrecognized codes, and also for ARES_SUCCESS paired
      // with a hostent that failed the validity checks above.
      error_ = kErrorUnknown;
      if (status != ARES_SUCCESS) {
        LOG(ERROR) << "ARES returned unhandled error status " << status;
      } else {
        LOG(ERROR) << "ARES returned success but hostent was invalid!";
      }
      break;
  }
  callback_->Run(false);
}
| 230 | |
| 231 | void DNSClient::ReceiveDNSReplyCB(void *arg, int status, |
| 232 | int /*timeouts*/, |
| 233 | struct hostent *hostent) { |
| 234 | DNSClient *res = static_cast<DNSClient *>(arg); |
| 235 | res->ReceiveDNSReply(status, hostent); |
| 236 | } |
| 237 | |
| 238 | bool DNSClient::RefreshHandles() { |
| 239 | map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_read = |
| 240 | resolver_state_->read_handlers; |
| 241 | map< ares_socket_t, std::tr1::shared_ptr<IOHandler> > old_write = |
| 242 | resolver_state_->write_handlers; |
| 243 | |
| 244 | resolver_state_->read_handlers.clear(); |
| 245 | resolver_state_->write_handlers.clear(); |
| 246 | |
| 247 | ares_socket_t sockets[ARES_GETSOCK_MAXNUM]; |
| 248 | int action_bits = ares_->GetSock(resolver_state_->channel, sockets, |
| 249 | ARES_GETSOCK_MAXNUM); |
| 250 | |
| 251 | for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) { |
| 252 | if (ARES_GETSOCK_READABLE(action_bits, i)) { |
| 253 | if (ContainsKey(old_read, sockets[i])) { |
| 254 | resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]]; |
| 255 | } else { |
| 256 | resolver_state_->read_handlers[sockets[i]] = |
| 257 | std::tr1::shared_ptr<IOHandler> ( |
| 258 | dispatcher_->CreateReadyHandler(sockets[i], |
| 259 | IOHandler::kModeInput, |
| 260 | read_callback_.get())); |
| 261 | } |
| 262 | } |
| 263 | if (ARES_GETSOCK_WRITABLE(action_bits, i)) { |
| 264 | if (ContainsKey(old_write, sockets[i])) { |
| 265 | resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]]; |
| 266 | } else { |
| 267 | resolver_state_->write_handlers[sockets[i]] = |
| 268 | std::tr1::shared_ptr<IOHandler> ( |
| 269 | dispatcher_->CreateReadyHandler(sockets[i], |
| 270 | IOHandler::kModeOutput, |
| 271 | write_callback_.get())); |
| 272 | } |
| 273 | } |
| 274 | } |
| 275 | |
| 276 | if (!running_) { |
| 277 | // We are here just to clean up socket and timer handles, and the |
| 278 | // ARES state was cleaned up during the last call to ares_process_fd(). |
| 279 | task_factory_.RevokeAll(); |
| 280 | return false; |
| 281 | } |
| 282 | |
| 283 | // Schedule timer event for the earlier of our timeout or one requested by |
| 284 | // the resolver library. |
| 285 | struct timeval now, elapsed_time, timeout_tv; |
| 286 | time_->GetTimeOfDay(&now, NULL); |
| 287 | timersub(&now, &resolver_state_->start_time_, &elapsed_time); |
| 288 | timeout_tv.tv_sec = timeout_ms_ / 1000; |
| 289 | timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000; |
| 290 | if (timercmp(&elapsed_time, &timeout_tv, >=)) { |
| 291 | // There are 3 cases of interest: |
| 292 | // - If we got here from Start(), we will have the side-effect of |
| 293 | // both invoking the callback and returning False in Start(). |
| 294 | // Start() will call Stop() which will shut down ARES. |
| 295 | // - If we got here from the tail of an IO event (racing with the |
| 296 | // timer, we can't call Stop() since that will blow away the |
| 297 | // IOHandler we are running in, however we will soon be called |
| 298 | // again by the timeout proc so we can clean up the ARES state |
| 299 | // then. |
| 300 | // - If we got here from a timeout handler, it will safely call |
| 301 | // Stop() when we return false. |
| 302 | error_ = kErrorTimedOut; |
| 303 | callback_->Run(false); |
| 304 | running_ = false; |
| 305 | return false; |
| 306 | } else { |
| 307 | struct timeval max, ret_tv; |
| 308 | timersub(&timeout_tv, &elapsed_time, &max); |
| 309 | struct timeval *tv = ares_->Timeout(resolver_state_->channel, |
| 310 | &max, &ret_tv); |
| 311 | task_factory_.RevokeAll(); |
| 312 | dispatcher_->PostDelayedTask( |
| 313 | task_factory_.NewRunnableMethod(&DNSClient::HandleTimeout), |
| 314 | tv->tv_sec * 1000 + tv->tv_usec / 1000); |
| 315 | } |
| 316 | |
| 317 | return true; |
| 318 | } |
| 319 | |
| 320 | } // namespace shill |