Skip to content

Commit

Permalink
vsock_proxy: Introduce DnsResolutionInfo type
Browse files Browse the repository at this point in the history
Replace DnsResolveResult with DnsResolutionInfo. The new type includes
utility methods and provides a better interface for DNS resolution
information, encapsulating resolved IP address, TTL value, and last
resolution time.

Signed-off-by: Erdem Meydanli <meydanli@amazon.com>
  • Loading branch information
meerd committed Apr 11, 2024
1 parent 4707473 commit 15e5567
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 53 deletions.
68 changes: 52 additions & 16 deletions vsock_proxy/src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,53 @@
// Copyright 2019-2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#![deny(warnings)]

/// Contains code for Proxy, a library used for translating vsock traffic to
/// TCP traffic
///
use std::net::IpAddr;

use chrono::{DateTime, Duration, Utc};
use hickory_resolver::config::*;
use hickory_resolver::Resolver;
use idna::domain_to_ascii;

use crate::{DnsResolveResult, IpAddrType, VsockProxyResult};
use crate::{IpAddrType, VsockProxyResult};

/// `DnsResolutionInfo` represents DNS resolution information, including the resolved
/// IP address, TTL value and last resolution time.
#[derive(Copy, Clone, Debug)]
pub struct DnsResolutionInfo {
/// The IP address that the hostname was resolved to.
ip_addr: IpAddr,
/// The configured duration after which the DNS resolution should be refreshed.
ttl: Duration,
/// The timestamp representing the last time the DNS resolution was performed.
last_dns_resolution_time: DateTime<Utc>,
}

impl DnsResolutionInfo {
pub fn is_expired(&self) -> bool {
Utc::now() > self.last_dns_resolution_time + self.ttl
}

fn new(new_ip_addr: IpAddr, new_ttl: Duration) -> Self {
DnsResolutionInfo {
ip_addr: new_ip_addr,
ttl: new_ttl,
last_dns_resolution_time: Utc::now(),
}
}

pub fn ip_addr(&self) -> IpAddr {
self.ip_addr
}

pub fn ttl(&self) -> Duration {
self.ttl
}
}

/// Resolve a DNS name (IDNA format) into multiple IP addresses (v4 or v6)
pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<Vec<DnsResolveResult>> {
pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<Vec<DnsResolutionInfo>> {
// IDNA parsing
let addr = domain_to_ascii(addr).map_err(|_| "Could not parse domain name")?;

Expand All @@ -21,17 +56,17 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<Vec<Dns
let resolver = Resolver::new(ResolverConfig::default(), ResolverOpts::default())
.map_err(|_| "Error while initializing DNS resolver!")?;

let rresults: Vec<DnsResolveResult> = resolver
let rresults: Vec<DnsResolutionInfo> = resolver
.lookup_ip(addr)
.map_err(|_| "DNS lookup failed!")?
.as_lookup()
.records()
.iter()
.filter_map(|record| {
if let Some(rdata) = record.data() {
if let Some(ip) = rdata.ip_addr() {
let ttl = record.ttl();
return Some(DnsResolveResult { ip, ttl });
if let Some(ip_addr) = rdata.ip_addr() {
let ttl = Duration::seconds(record.ttl() as i64);
return Some(DnsResolutionInfo::new(ip_addr, ttl));
}
}
None
Expand All @@ -48,8 +83,9 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<Vec<Dns
}

//Partition the resolution results into groups that use IPv4 or IPv6 addresses.
let (rresults_with_ipv4, rresults_with_ipv6): (Vec<_>, Vec<_>) =
rresults.into_iter().partition(|result| result.ip.is_ipv4());
let (rresults_with_ipv4, rresults_with_ipv6): (Vec<_>, Vec<_>) = rresults
.into_iter()
.partition(|result| result.ip_addr().is_ipv4());

if IpAddrType::IPAddrV4Only == ip_addr_type && !rresults_with_ipv4.is_empty() {
Ok(rresults_with_ipv4)
Expand All @@ -61,7 +97,7 @@ pub fn resolve(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<Vec<Dns
}

/// Resolve a DNS name (IDNA format) into a single address with a TTL value
pub fn resolve_single(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<DnsResolveResult> {
pub fn resolve_single(addr: &str, ip_addr_type: IpAddrType) -> VsockProxyResult<DnsResolutionInfo> {
let rresults = resolve(addr, ip_addr_type)?;
// Return the first resolved IP address and its TTL value.
rresults
Expand Down Expand Up @@ -127,14 +163,14 @@ mod tests {
fn test_resolve_ipv4_only() {
let domain = unsafe { IPV4_ONLY_TEST_DOMAIN };
let rresults = resolve(domain, IpAddrType::IPAddrV4Only).unwrap();
assert!(rresults.iter().all(|item| item.ip.is_ipv4()));
assert!(rresults.iter().all(|item| item.ip_addr().is_ipv4()));
}

#[test]
fn test_resolve_ipv6_only() {
let domain = unsafe { IPV6_ONLY_TEST_DOMAIN };
let rresults = resolve(domain, IpAddrType::IPAddrV6Only).unwrap();
assert!(rresults.iter().all(|item| item.ip.is_ipv6()));
assert!(rresults.iter().all(|item| item.ip_addr().is_ipv6()));
}

#[test]
Expand All @@ -148,7 +184,7 @@ mod tests {
fn test_resolve_single_address() {
let domain = unsafe { IPV4_ONLY_TEST_DOMAIN };
let rresult = resolve_single(domain, IpAddrType::IPAddrMixed).unwrap();
assert!(rresult.ip.is_ipv4());
assert!(rresult.ttl != 0);
assert!(rresult.ip_addr().is_ipv4());
assert!(rresult.ttl != Duration::seconds(0));
}
}
10 changes: 0 additions & 10 deletions vsock_proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
pub mod dns;
pub mod proxy;

use std::net::IpAddr;

#[derive(Copy, Clone, PartialEq)]
pub enum IpAddrType {
/// Only allows IP4 addresses
Expand All @@ -16,13 +14,5 @@ pub enum IpAddrType {
IPAddrMixed,
}

#[derive(Copy, Clone, Debug)]
pub struct DnsResolveResult {
///Resolved address
pub ip: IpAddr,
///DNS TTL value
pub ttl: u32,
}

/// The most common result type provided by VsockProxy operations.
pub type VsockProxyResult<T> = Result<T, String>;
51 changes: 24 additions & 27 deletions vsock_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

/// Contains code for Proxy, a library used for translating vsock traffic to
/// TCP traffic
use chrono::{DateTime, Duration, Utc};
use log::{info, warn};
use nix::sys::select::{select, FdSet};
use nix::sys::socket::SockType;
Expand All @@ -16,6 +15,7 @@ use threadpool::ThreadPool;
use vsock::{VsockAddr, VsockListener};
use yaml_rust::YamlLoader;

use crate::dns::DnsResolutionInfo;
use crate::{dns, IpAddrType, VsockProxyResult};

const BUFF_SIZE: usize = 8192;
Expand Down Expand Up @@ -43,7 +43,7 @@ pub fn check_allowlist(

// Obtain the remote server's IP address.
let dns_result = dns::resolve_single(remote_host, ip_addr_type)?;
let remote_addr = dns_result.ip;
let remote_addr = dns_result.ip_addr();

for raw_service in services {
let addr = raw_service["address"].as_str().ok_or("No address field")?;
Expand All @@ -69,7 +69,7 @@ pub fn check_allowlist(
let remote_addr_matched = rresults
.into_iter()
.flatten()
.find(|rresult| rresult.ip == remote_addr)
.find(|rresult| rresult.ip_addr() == remote_addr)
.map(|_| remote_addr);

match remote_addr_matched {
Expand All @@ -89,10 +89,8 @@ pub fn check_allowlist(
pub struct Proxy {
local_port: u32,
remote_host: String,
remote_addr: Option<IpAddr>,
remote_port: u16,
dns_resolve_date: Option<DateTime<Utc>>,
dns_refresh_interval: Option<Duration>,
dns_resolution_info: Option<DnsResolutionInfo>,
pool: ThreadPool,
sock_type: SockType,
ip_addr_type: IpAddrType,
Expand All @@ -108,17 +106,13 @@ impl Proxy {
) -> VsockProxyResult<Self> {
let pool = ThreadPool::new(num_workers);
let sock_type = SockType::Stream;
let remote_addr: Option<IpAddr> = None;
let dns_resolve_date: Option<DateTime<Utc>> = None;
let dns_refresh_interval: Option<Duration> = None;
let dns_resolution_info: Option<DnsResolutionInfo> = None;

Ok(Proxy {
local_port,
remote_host,
remote_addr,
remote_port,
dns_resolve_date,
dns_refresh_interval,
dns_resolution_info,
pool,
sock_type,
ip_addr_type,
Expand All @@ -145,28 +139,31 @@ impl Proxy {
.map_err(|_| "Could not accept connection")?;
info!("Accepted connection on {:?}", client_addr);

let needs_resolve =
|d: DateTime<Utc>, i: Duration| (Utc::now() - d + Duration::seconds(2)) > i;
let dns_needs_resolution = self
.dns_resolution_info
.map_or(true, |info| info.is_expired());

if self.dns_resolve_date.is_none()
|| needs_resolve(
self.dns_resolve_date.unwrap(),
self.dns_refresh_interval.unwrap(),
)
{
let remote_addr = if dns_needs_resolution {
info!("Resolving hostname: {}.", self.remote_host);
let result = dns::resolve_single(&self.remote_host, self.ip_addr_type)?;
self.dns_resolve_date = Some(Utc::now());
self.dns_refresh_interval = Some(Duration::seconds(result.ttl as i64));
self.remote_addr = Some(result.ip);

let dns_resolution = dns::resolve_single(&self.remote_host, self.ip_addr_type)?;

info!(
"Using IP \"{:?}\" for the given server \"{}\". (TTL: {} secs)",
result.ip, self.remote_host, result.ttl
dns_resolution.ip_addr(),
self.remote_host,
dns_resolution.ttl().num_seconds()
);
}

let sockaddr = SocketAddr::new(self.remote_addr.unwrap(), self.remote_port);
self.dns_resolution_info = Some(dns_resolution);
dns_resolution.ip_addr()
} else {
self.dns_resolution_info
.ok_or("DNS resolution failed!")?
.ip_addr()
};

let sockaddr = SocketAddr::new(remote_addr, self.remote_port);
let sock_type = self.sock_type;
self.pool.execute(move || {
let mut server = match sock_type {
Expand Down

0 comments on commit 15e5567

Please sign in to comment.