diff --git a/README.md b/README.md index ad36939..dc69388 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Windows route-table boundary: - read-only best-route lookup for a relay destination IP - selected source address, next hop, interface index/LUID, prefix, and metric - interface index/LUID lookup from Windows network adapter GUIDs +- scoped IP interface metric overrides with restore-on-drop behavior - scoped host-route pinning for the relay IP on the pre-TAP interface - non-Windows builds return a clear unsupported-platform error diff --git a/crates/lanparty-client-route/src/lib.rs b/crates/lanparty-client-route/src/lib.rs index 807bbd8..718a3af 100644 --- a/crates/lanparty-client-route/src/lib.rs +++ b/crates/lanparty-client-route/src/lib.rs @@ -44,6 +44,65 @@ impl NetworkInterfaceIdentity { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IpInterfaceFamily { + Ipv4, + Ipv6, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InterfaceMetricSnapshot { + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, + automatic_metric: bool, + metric: u32, + disable_default_routes: bool, +} + +impl InterfaceMetricSnapshot { + #[cfg_attr(not(windows), allow(dead_code))] + const fn new( + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, + automatic_metric: bool, + metric: u32, + disable_default_routes: bool, + ) -> Self { + Self { + identity, + family, + automatic_metric, + metric, + disable_default_routes, + } + } + + #[must_use] + pub const fn identity(self) -> NetworkInterfaceIdentity { + self.identity + } + + #[must_use] + pub const fn family(self) -> IpInterfaceFamily { + self.family + } + + #[must_use] + pub const fn automatic_metric(self) -> bool { + self.automatic_metric + } + + #[must_use] + pub const fn metric(self) -> u32 { + self.metric + } + + #[must_use] + pub const fn disable_default_routes(self) -> bool { + self.disable_default_routes + } +} + impl RouteSnapshot { #[cfg_attr(not(windows), allow(dead_code))] #[allow(clippy::too_many_arguments)] @@ -115,6 +174,8 @@ mod windows; #[cfg(windows)] pub use windows::{PinnedRelayRoute, best_route_to, interface_identity_from_guid, pin_relay_route}; +#[cfg(windows)] +pub use windows::{ScopedInterfaceMetric, interface_metric, set_scoped_interface_metric}; #[cfg(not(windows))] pub fn best_route_to(_destination: IpAddr) -> Result { @@ -137,6 +198,29 @@ pub fn interface_identity_from_guid(_interface_guid: &str) -> Result Result { + bail!("Windows interface metric lookup is only available on Windows"); +} + +#[cfg(not(windows))] +pub fn set_scoped_interface_metric( + _identity: NetworkInterfaceIdentity, + _family: IpInterfaceFamily, + _metric: u32, +) -> Result { + bail!("Windows interface metric updates are only available on Windows"); +} + #[cfg(test)] mod tests { use super::*; @@ -172,6 +256,19 @@ mod tests { assert_eq!(identity.luid(), 34); } + #[test] + fn exposes_interface_metric_snapshot_fields() { + let identity = NetworkInterfaceIdentity::new(12, 34); + let snapshot = + InterfaceMetricSnapshot::new(identity, IpInterfaceFamily::Ipv4, true, 25, false); + + assert_eq!(snapshot.identity(), identity); + assert_eq!(snapshot.family(), IpInterfaceFamily::Ipv4); + assert!(snapshot.automatic_metric()); + assert_eq!(snapshot.metric(), 25); + assert!(!snapshot.disable_default_routes()); + } + #[cfg(not(windows))] #[test] fn rejects_route_inspection_on_non_windows() { @@ -201,6 +298,15 @@ mod tests { assert!(interface_identity_from_guid("{00112233-4455-6677-8899-AABBCCDDEEFF}").is_err()); } + #[cfg(not(windows))] + #[test] + fn rejects_interface_metric_operations_on_non_windows() { + let identity = NetworkInterfaceIdentity::new(12, 34); + + assert!(interface_metric(identity, IpInterfaceFamily::Ipv4).is_err()); + assert!(set_scoped_interface_metric(identity, IpInterfaceFamily::Ipv4, 500).is_err()); + } + fn ip(value: &str) -> IpAddr { value.parse().unwrap() } diff --git a/crates/lanparty-client-route/src/windows.rs b/crates/lanparty-client-route/src/windows.rs index e761232..ea5a026 100644 --- a/crates/lanparty-client-route/src/windows.rs +++ b/crates/lanparty-client-route/src/windows.rs @@ -10,8 +10,9 @@ use windows_sys::Win32::{ NetworkManagement::{ IpHelper::{ ConvertInterfaceGuidToLuid, ConvertInterfaceLuidToIndex, CreateIpForwardEntry2, - DeleteIpForwardEntry2, GetBestRoute2, IP_ADDRESS_PREFIX, InitializeIpForwardEntry, - MIB_IPFORWARD_ROW2, + DeleteIpForwardEntry2, GetBestRoute2, GetIpInterfaceEntry, IP_ADDRESS_PREFIX, + InitializeIpForwardEntry, InitializeIpInterfaceEntry, MIB_IPFORWARD_ROW2, + MIB_IPINTERFACE_ROW, SetIpInterfaceEntry, }, Ndis::NET_LUID_LH, }, @@ -22,6 +23,7 @@ use windows_sys::Win32::{ }; use windows_sys::core::GUID; +use crate::{InterfaceMetricSnapshot, IpInterfaceFamily}; use crate::{NetworkInterfaceIdentity, RouteSnapshot}; pub fn interface_identity_from_guid(interface_guid: &str) -> Result { @@ -49,6 +51,64 @@ pub fn interface_identity_from_guid(interface_guid: &str) -> Result Result { + let row = get_interface_row(identity, family)?; + + Ok(metric_snapshot(identity, family, row)) +} + +pub fn set_scoped_interface_metric( + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, + metric: u32, +) -> Result { + let previous = interface_metric(identity, family)?; + let mut row = get_interface_row(identity, family)?; + row.UseAutomaticMetric = false; + row.Metric = metric; + set_interface_row(&mut row) + .with_context(|| format!("failed to set {family:?} interface metric to {metric}"))?; + + Ok(ScopedInterfaceMetric { + previous, + active: true, + }) +} + +pub struct ScopedInterfaceMetric { + previous: InterfaceMetricSnapshot, + active: bool, +} + +impl ScopedInterfaceMetric { + #[must_use] + pub const fn previous(&self) -> InterfaceMetricSnapshot { + self.previous + } +} + +impl fmt::Debug for ScopedInterfaceMetric { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ScopedInterfaceMetric") + .field("previous", &self.previous) + .field("active", &self.active) + .finish() + } +} + +impl Drop for ScopedInterfaceMetric { + fn drop(&mut self) { + if !self.active { + return; + } + + let _ = restore_interface_metric(self.previous); + } +} + pub fn best_route_to(destination: IpAddr) -> Result { let destination_sockaddr = sockaddr_from_ip(destination); let mut route = MIB_IPFORWARD_ROW2::default(); @@ -86,6 +146,82 @@ pub fn best_route_to(destination: IpAddr) -> Result { )) } +fn get_interface_row( + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, +) -> Result { + let mut row = interface_row_key(identity, family); + let status = unsafe { + // SAFETY: row is initialized with the family and interface identity Windows needs to + // retrieve the IP interface entry. + GetIpInterfaceEntry(&mut row) + }; + windows_status(status).with_context(|| { + format!( + "failed to read {family:?} interface row for index {} LUID {}", + identity.index(), + identity.luid() + ) + })?; + + Ok(row) +} + +fn set_interface_row(row: &mut MIB_IPINTERFACE_ROW) -> Result<()> { + let status = unsafe { + // SAFETY: row was obtained from GetIpInterfaceEntry and only mutable configuration fields + // are changed before calling SetIpInterfaceEntry. + SetIpInterfaceEntry(row) + }; + windows_status(status).context("failed to update IP interface row") +} + +fn restore_interface_metric(snapshot: InterfaceMetricSnapshot) -> Result<()> { + let mut row = get_interface_row(snapshot.identity(), snapshot.family())?; + row.UseAutomaticMetric = snapshot.automatic_metric(); + row.Metric = snapshot.metric(); + set_interface_row(&mut row) +} + +fn interface_row_key( + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, +) -> MIB_IPINTERFACE_ROW { + let mut row = MIB_IPINTERFACE_ROW::default(); + unsafe { + // SAFETY: row points to valid writable storage for Windows to initialize. + InitializeIpInterfaceEntry(&mut row); + } + row.Family = address_family(family); + row.InterfaceLuid = NET_LUID_LH { + Value: identity.luid(), + }; + row.InterfaceIndex = identity.index(); + + row +} + +fn metric_snapshot( + identity: NetworkInterfaceIdentity, + family: IpInterfaceFamily, + row: MIB_IPINTERFACE_ROW, +) -> InterfaceMetricSnapshot { + InterfaceMetricSnapshot::new( + identity, + family, + row.UseAutomaticMetric, + row.Metric, + row.DisableDefaultRoutes, + ) +} + +const fn address_family(family: IpInterfaceFamily) -> u16 { + match family { + IpInterfaceFamily::Ipv4 => AF_INET, + IpInterfaceFamily::Ipv6 => AF_INET6, + } +} + pub fn pin_relay_route(route: &RouteSnapshot) -> Result { let mut pinned = pinned_route_row(route); let status = unsafe { @@ -354,6 +490,32 @@ mod tests { assert!(parse_interface_guid("{00112233-4455-6677-8899-AABBCCDDEEGG}").is_err()); } + #[test] + fn builds_interface_row_keys() { + let identity = NetworkInterfaceIdentity::new(12, 34); + let row = interface_row_key(identity, IpInterfaceFamily::Ipv4); + + assert_eq!(row.Family, AF_INET); + assert_eq!(row.InterfaceIndex, 12); + assert_eq!(luid_value(row.InterfaceLuid), 34); + } + + #[test] + fn builds_metric_snapshots_from_rows() { + let identity = NetworkInterfaceIdentity::new(12, 34); + let mut row = interface_row_key(identity, IpInterfaceFamily::Ipv6); + row.UseAutomaticMetric = true; + row.Metric = 500; + row.DisableDefaultRoutes = true; + let snapshot = metric_snapshot(identity, IpInterfaceFamily::Ipv6, row); + + assert_eq!(snapshot.identity(), identity); + assert_eq!(snapshot.family(), IpInterfaceFamily::Ipv6); + assert!(snapshot.automatic_metric()); + assert_eq!(snapshot.metric(), 500); + assert!(snapshot.disable_default_routes()); + } + #[test] fn builds_ipv6_on_link_host_route_row() { let route = RouteSnapshot::new(