diff --git a/proto/workload.proto b/proto/workload.proto index 3496fd634..1daee71b5 100644 --- a/proto/workload.proto +++ b/proto/workload.proto @@ -74,6 +74,52 @@ message Service { // Waypoint is the waypoint proxy for this service. When set, all incoming requests must go // through the waypoint. GatewayAddress waypoint = 7; + + // Load balancing policy for selecting endpoints. + // Note: this applies only to connecting directly to the workload; when waypoints are used, the waypoint's load_balancing + // configuration is used. + LoadBalancing load_balancing = 8; +} + +message LoadBalancing { + enum Scope { + UNSPECIFIED_SCOPE = 0; + // Prefer traffic in the same region. + REGION = 1; + // Prefer traffic in the same zone. + ZONE = 2; + // Prefer traffic in the same subzone. + SUBZONE = 3; + // Prefer traffic on the same node. + NODE = 4; + // Prefer traffic in the same cluster. + CLUSTER = 5; + // Prefer traffic in the same network. + NETWORK = 6; + } + enum Mode { + UNSPECIFIED_MODE = 0; + + // In STRICT mode, only endpoints that meet all of the routing preferences will be considered. + // This can be used, for instance, to keep traffic ONLY within the same cluster/node/region. + // This should be used with caution, as it can result in all traffic being dropped if there are no matching endpoints, + // even if there are endpoints outside of the preferences. + STRICT = 1; + // In FAILOVER mode, endpoint selection will prefer endpoints that match all preferences, but fail over to groups of endpoints + // that match fewer (or, eventually, no) preferences. + // For instance, with `[NETWORK, REGION, ZONE]`, we will send to: + // 1. Endpoints matching `[NETWORK, REGION, ZONE]` + // 2. Endpoints matching `[NETWORK, REGION]` + // 3. Endpoints matching `[NETWORK]` + // 4. Any endpoints + FAILOVER = 2; + } + + // routing_preference defines what scopes we want to keep traffic within. 
+ // The `mode` determines how these routing preferences are handled + repeated Scope routing_preference = 1; + // mode defines how we should handle the routing preferences. + Mode mode = 2; } // Workload represents a workload - an endpoint (or collection behind a hostname). @@ -116,6 +162,7 @@ message Workload { // a workload that backs a Kubernetes service will typically have only endpoints. A // workload that backs a headless Kubernetes service, however, will have both // addresses as well as a hostname used for direct access to the headless endpoint. + // TODO: support this field string hostname = 21; // Network represents the network this workload is on. This may be elided for the default network. @@ -178,10 +225,19 @@ message Workload { // The cluster ID that the workload instance belongs to string cluster_id = 18; + // The Locality defines information about where a workload is geographically deployed + Locality locality = 24; + // Reservations for deleted fields. reserved 15; } +message Locality { + string region = 1; + string zone = 2; + string subzone = 3; +} + enum WorkloadStatus { // Workload is healthy and ready to serve traffic. 
HEALTHY = 0; diff --git a/src/admin.rs b/src/admin.rs index dcc0aa4a5..d3936374c 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -447,6 +447,8 @@ mod tests { use crate::xds::istio::security::StringMatch as XdsStringMatch; use crate::xds::istio::workload::gateway_address::Destination as XdsDestination; use crate::xds::istio::workload::GatewayAddress as XdsGatewayAddress; + use crate::xds::istio::workload::LoadBalancing as XdsLoadBalancing; + use crate::xds::istio::workload::Locality as XdsLocality; use crate::xds::istio::workload::NetworkAddress as XdsNetworkAddress; use crate::xds::istio::workload::Port as XdsPort; use crate::xds::istio::workload::PortList as XdsPortList; @@ -637,6 +639,11 @@ mod tests { }], }, )]), + locality: Some(XdsLocality { + region: "region".to_string(), + zone: "zone".to_string(), + subzone: "subezone".to_string(), + }), // ..Default::default() // intentionally don't default. we want all fields populated }; @@ -654,7 +661,10 @@ mod tests { }], subject_alt_names: vec!["SAN1".to_string(), "SAN2".to_string()], waypoint: None, - // ..Default::default() // intentionally don't default. we want all fields populated + load_balancing: Some(XdsLoadBalancing { + routing_preference: vec![1, 2], + mode: 1, + }), // ..Default::default() // intentionally don't default. 
we want all fields populated }; let auth = XdsAuthorization { diff --git a/src/proxy.rs b/src/proxy.rs index c85685cb6..26c25efa0 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -661,9 +661,7 @@ mod tests { let w = mock_default_gateway_workload(); let s = mock_default_gateway_service(); let mut state = state::ProxyState::default(); - if let Err(err) = state.workloads.insert(w) { - panic!("received error inserting workload: {}", err); - } + state.workloads.insert(w); state.services.insert(s); let state = state::DemandProxyState::new( Arc::new(RwLock::new(state)), @@ -741,6 +739,7 @@ mod tests { authorization_policies: Vec::new(), native_tunnel: false, application_tunnel: None, + locality: Default::default(), } } @@ -769,6 +768,7 @@ mod tests { authorization_policies: Vec::new(), native_tunnel: false, application_tunnel: None, + locality: Default::default(), } } @@ -806,6 +806,7 @@ mod tests { endpoints, subject_alt_names: vec![], waypoint: None, + load_balancer: None, } } diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index b1560499d..a05844dc3 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -785,6 +785,7 @@ mod tests { .collect(), subject_alt_names: vec![format!("{name}.default.svc.cluster.local")], waypoint: waypoint.service_attached(), + load_balancer: None, } }); @@ -816,7 +817,7 @@ mod tests { state.services.insert(svc); } for wl in workloads { - state.workloads.insert(wl)?; + state.workloads.insert(wl); } Ok(DemandProxyState::new( diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index 32d816af1..57fe47564 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -404,7 +404,7 @@ impl OutboundConnection { let waypoint_us = self .pi .state - .fetch_upstream(&self.pi.cfg.network, waypoint_vip) + .fetch_upstream(&self.pi.cfg.network, &source_workload, waypoint_vip) .await .ok_or(proxy::Error::UnknownWaypoint( "unable to determine waypoint upstream".to_string(), @@ -414,7 +414,7 @@ impl OutboundConnection { let 
waypoint_ip = self .pi .state - .load_balance( + .pick_workload_destination( &waypoint_workload, &source_workload, self.pi.metrics.clone(), @@ -447,7 +447,7 @@ impl OutboundConnection { let us = self .pi .state - .fetch_upstream(&source_workload.network, target) + .fetch_upstream(&source_workload.network, &source_workload, target) .await; if us.is_none() { // For case no upstream found, passthrough it @@ -469,7 +469,7 @@ impl OutboundConnection { let workload_ip = self .pi .state - .load_balance( + .pick_workload_destination( &mutable_us.workload, &source_workload, self.pi.metrics.clone(), @@ -491,7 +491,7 @@ impl OutboundConnection { match self .pi .state - .fetch_waypoint(&mutable_us.workload, workload_ip) + .fetch_waypoint(&mutable_us.workload, &source_workload, workload_ip) .await { Ok(None) => {} // workload doesn't have a waypoint; this is fine @@ -500,7 +500,7 @@ impl OutboundConnection { let waypoint_ip = self .pi .state - .load_balance( + .pick_workload_destination( &waypoint_workload, &source_workload, self.pi.metrics.clone(), diff --git a/src/state.rs b/src/state.rs index f81a2dc41..22fcd5a67 100644 --- a/src/state.rs +++ b/src/state.rs @@ -16,7 +16,7 @@ use crate::identity::SecretManager; use crate::proxy; use crate::proxy::{Error, OnDemandDnsLabels}; use crate::state::policy::PolicyStore; -use crate::state::service::ServiceStore; +use crate::state::service::{Endpoint, LoadBalancerMode, LoadBalancerScopes, ServiceStore}; use crate::state::service::{Service, ServiceDescription}; use crate::state::workload::{ address::Address, gatewayaddress::Destination, network_addr, NamespacedHostname, @@ -206,7 +206,12 @@ impl ProxyState { } } - pub fn find_upstream(&self, network: &str, addr: SocketAddr) -> Option { + pub fn find_upstream( + &self, + network: &str, + source_workload: &Workload, + addr: SocketAddr, + ) -> Option { if let Some(svc) = self.services.get_by_vip(&network_addr(network, addr.ip())) { let Some(target_port) = svc.ports.get(&addr.port()) else { 
debug!( @@ -218,7 +223,7 @@ impl ProxyState { }; // Randomly pick an upstream // TODO: do this more efficiently, and not just randomly - let Some((_, ep)) = svc.endpoints.iter().choose(&mut rand::thread_rng()) else { + let Some(ep) = self.load_balance(source_workload, &svc) else { debug!("VIP {} has no healthy endpoints", addr); return None; }; @@ -250,6 +255,62 @@ impl ProxyState { } None } + + fn load_balance<'a>(&self, src: &Workload, svc: &'a Service) -> Option<&'a Endpoint> { + match svc.load_balancer { + None => svc.endpoints.values().choose(&mut rand::thread_rng()), + Some(ref lb) => { + let ranks = svc + .endpoints + .iter() + .filter_map(|(_, ep)| { + let Some(wl) = self.workloads.find_uid(&ep.workload_uid) else { + debug!("failed to fetch workload for {}", ep.workload_uid); + return None; + }; + // Load balancer will define N targets we want to match + // Consider [network, region, zone] + // Rank = 3 means we match all of them + // Rank = 2 means network and region match + // Rank = 0 means none match + let mut rank = 0; + for target in &lb.routing_preferences { + let matches = match target { + LoadBalancerScopes::Region => { + src.locality.region == wl.locality.region + } + LoadBalancerScopes::Zone => src.locality.zone == wl.locality.zone, + LoadBalancerScopes::Subzone => { + src.locality.subzone == wl.locality.subzone + } + LoadBalancerScopes::Node => src.node == wl.node, + LoadBalancerScopes::Cluster => src.cluster_id == wl.cluster_id, + LoadBalancerScopes::Network => src.network == wl.network, + }; + if matches { + rank += 1; + } else { + break; + } + } + // Doesn't match all, and required to. 
Do not select this endpoint + if lb.mode == LoadBalancerMode::Strict + && rank != lb.routing_preferences.len() + { + return None; + } + Some((rank, ep)) + }) + .collect::>(); + let max = *ranks.iter().map(|(rank, _ep)| rank).max()?; + ranks + .into_iter() + .filter(|(rank, _ep)| *rank == max) + .map(|(_, ep)| ep) + .choose(&mut rand::thread_rng()) + } + } + } } /// Wrapper around [ProxyState] that provides additional methods for requesting information @@ -355,7 +416,7 @@ impl DemandProxyState { } // this should only be called once per request (for the workload itself and potentially its waypoint) - pub async fn load_balance( + pub async fn pick_workload_destination( &self, dst_workload: &Workload, src_workload: &Workload, @@ -551,14 +612,23 @@ impl DemandProxyState { self.state.read().unwrap().workloads.find_uid(uid) } - pub async fn fetch_upstream(&self, network: &str, addr: SocketAddr) -> Option { + pub async fn fetch_upstream( + &self, + network: &str, + source_workload: &Workload, + addr: SocketAddr, + ) -> Option { self.fetch_address(&network_addr(network, addr.ip())).await; - self.state.read().unwrap().find_upstream(network, addr) + self.state + .read() + .unwrap() + .find_upstream(network, source_workload, addr) } pub async fn fetch_waypoint( &self, wl: &Workload, + source_workload: &Workload, workload_ip: IpAddr, ) -> Result, WaypointError> { let Some(gw_address) = &wl.waypoint else { @@ -576,7 +646,7 @@ impl DemandProxyState { }; let wp_socket_addr = SocketAddr::new(wp_nw_addr.address, gw_address.hbone_mtls_port); match self - .fetch_upstream(&wp_nw_addr.network, wp_socket_addr) + .fetch_upstream(&wp_nw_addr.network, source_workload, wp_socket_addr) .await { Some(mut upstream) => { @@ -730,18 +800,20 @@ impl ProxyStateManager { #[cfg(test)] mod tests { + use crate::state::service::LoadBalancer; + use crate::state::workload::Locality; use std::{net::Ipv4Addr, net::SocketAddrV4, time::Duration}; use super::*; use crate::test_helpers; + use 
crate::test_helpers::TEST_SERVICE_NAMESPACE; #[tokio::test] async fn lookup_address() { let mut state = ProxyState::default(); state .workloads - .insert(test_helpers::test_default_workload()) - .unwrap(); + .insert(test_helpers::test_default_workload()); state.services.insert(test_helpers::mock_default_service()); let mock_proxy_state = DemandProxyState::new( @@ -815,7 +887,7 @@ mod tests { workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))], ..test_helpers::test_default_workload() }; - state.workloads.insert(wl).unwrap(); + state.workloads.insert(wl); let mock_proxy_state = DemandProxyState::new( Arc::new(RwLock::new(state)), @@ -874,4 +946,200 @@ mod tests { assert!(!mock_proxy_state.assert_rbac(&ctx).await); } } + + #[tokio::test] + async fn test_load_balance() { + let mut state = ProxyState::default(); + let wl_no_locality = Workload { + uid: "cluster1//v1/Pod/default/wl_no_locality".to_string(), + name: "wl_no_locality".to_string(), + namespace: "default".to_string(), + trust_domain: "cluster.local".to_string(), + service_account: "default".to_string(), + workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))], + ..test_helpers::test_default_workload() + }; + let wl_match = Workload { + uid: "cluster1//v1/Pod/default/wl_match".to_string(), + name: "wl_match".to_string(), + namespace: "default".to_string(), + trust_domain: "cluster.local".to_string(), + service_account: "default".to_string(), + workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))], + network: "network".to_string(), + locality: Locality { + region: "reg".to_string(), + zone: "zone".to_string(), + subzone: "".to_string(), + }, + ..test_helpers::test_default_workload() + }; + let wl_almost = Workload { + uid: "cluster1//v1/Pod/default/wl_almost".to_string(), + name: "wl_almost".to_string(), + namespace: "default".to_string(), + trust_domain: "cluster.local".to_string(), + service_account: "default".to_string(), + workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 
3))], + network: "network".to_string(), + locality: Locality { + region: "reg".to_string(), + zone: "not-zone".to_string(), + subzone: "".to_string(), + }, + ..test_helpers::test_default_workload() + }; + let _ep_almost = Workload { + uid: "cluster1//v1/Pod/default/ep_almost".to_string(), + name: "wl_almost".to_string(), + namespace: "default".to_string(), + trust_domain: "cluster.local".to_string(), + service_account: "default".to_string(), + workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 4))], + network: "network".to_string(), + locality: Locality { + region: "reg".to_string(), + zone: "other-not-zone".to_string(), + subzone: "".to_string(), + }, + ..test_helpers::test_default_workload() + }; + let _ep_no_match = Workload { + uid: "cluster1//v1/Pod/default/ep_no_match".to_string(), + name: "wl_almost".to_string(), + namespace: "default".to_string(), + trust_domain: "cluster.local".to_string(), + service_account: "default".to_string(), + workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 5))], + network: "not-network".to_string(), + locality: Locality { + region: "not-reg".to_string(), + zone: "unmatched-zone".to_string(), + subzone: "".to_string(), + }, + ..test_helpers::test_default_workload() + }; + let endpoints = HashMap::from([ + ( + "cluster1//v1/Pod/default/ep_almost".to_string(), + Endpoint { + workload_uid: "cluster1//v1/Pod/default/ep_almost".to_string(), + service: NamespacedHostname { + namespace: TEST_SERVICE_NAMESPACE.to_string(), + hostname: "example.com".to_string(), + }, + address: Some(NetworkAddress { + address: "192.168.0.4".parse().unwrap(), + network: "".to_string(), + }), + port: HashMap::from([(80u16, 80u16)]), + }, + ), + ( + "cluster1//v1/Pod/default/ep_no_match".to_string(), + Endpoint { + workload_uid: "cluster1//v1/Pod/default/ep_almost".to_string(), + service: NamespacedHostname { + namespace: TEST_SERVICE_NAMESPACE.to_string(), + hostname: "example.com".to_string(), + }, + address: Some(NetworkAddress { + address: 
"192.168.0.5".parse().unwrap(), + network: "".to_string(), + }), + port: HashMap::from([(80u16, 80u16)]), + }, + ), + ( + "cluster1//v1/Pod/default/wl_match".to_string(), + Endpoint { + workload_uid: "cluster1//v1/Pod/default/wl_match".to_string(), + service: NamespacedHostname { + namespace: TEST_SERVICE_NAMESPACE.to_string(), + hostname: "example.com".to_string(), + }, + address: Some(NetworkAddress { + address: "192.168.0.2".parse().unwrap(), + network: "".to_string(), + }), + port: HashMap::from([(80u16, 80u16)]), + }, + ), + ]); + let strict_svc = Service { + endpoints: endpoints.clone(), + load_balancer: Some(LoadBalancer { + mode: LoadBalancerMode::Strict, + routing_preferences: vec![ + LoadBalancerScopes::Network, + LoadBalancerScopes::Region, + LoadBalancerScopes::Zone, + ], + }), + ..test_helpers::mock_default_service() + }; + let failover_svc = Service { + endpoints, + load_balancer: Some(LoadBalancer { + mode: LoadBalancerMode::Failover, + routing_preferences: vec![ + LoadBalancerScopes::Network, + LoadBalancerScopes::Region, + LoadBalancerScopes::Zone, + ], + }), + ..test_helpers::mock_default_service() + }; + state.workloads.insert(wl_no_locality.clone()); + state.workloads.insert(wl_match.clone()); + state.workloads.insert(wl_almost.clone()); + state.services.insert(strict_svc.clone()); + state.services.insert(failover_svc.clone()); + + let assert_endpoint = |src: &Workload, svc: &Service, ips: Vec<&str>, desc: &str| { + let got = state + .load_balance(src, svc) + .and_then(|ep| ep.address.clone()) + .map(|addr| addr.address.to_string()); + if ips.is_empty() { + assert!(got.is_none(), "{}", desc); + } else { + let want: Vec = ips.iter().map(ToString::to_string).collect(); + assert!(want.contains(&got.unwrap()), "{}", desc); + } + }; + + assert_endpoint( + &wl_no_locality, + &strict_svc, + vec![], + "strict no match should not select", + ); + assert_endpoint( + &wl_almost, + &strict_svc, + vec![], + "strict no match should not select", + ); + 
assert_endpoint(&wl_match, &strict_svc, vec!["192.168.0.2"], "strict match"); + + assert_endpoint( + &wl_no_locality, + &failover_svc, + vec!["192.168.0.2", "192.168.0.4", "192.168.0.5"], + "failover no match can select any endpoint", + ); + assert_endpoint( + &wl_almost, + &failover_svc, + vec!["192.168.0.2", "192.168.0.4"], + "failover almost match can select any close matches", + ); + assert_endpoint( + &wl_match, + &failover_svc, + vec!["192.168.0.2"], + "failover full match selects closest match", + ); + } } diff --git a/src/state/service.rs b/src/state/service.rs index 0d312271b..f2e3d537f 100644 --- a/src/state/service.rs +++ b/src/state/service.rs @@ -26,6 +26,7 @@ use crate::state::workload::{ WorkloadError, }; use crate::xds; +use crate::xds::istio::workload::load_balancing::Scope as XdsScope; use crate::xds::istio::workload::PortList; #[derive(Debug, Eq, PartialEq, Clone, serde::Serialize, serde::Deserialize)] @@ -43,6 +44,57 @@ pub struct Service { #[serde(default)] pub subject_alt_names: Vec, pub waypoint: Option, + + pub load_balancer: Option, +} + +#[derive(Debug, Eq, PartialEq, Clone, serde::Serialize, serde::Deserialize)] +pub enum LoadBalancerMode { + Strict, + Failover, +} + +impl From for LoadBalancerMode { + fn from(value: xds::istio::workload::load_balancing::Mode) -> Self { + match value { + xds::istio::workload::load_balancing::Mode::Strict => LoadBalancerMode::Strict, + xds::istio::workload::load_balancing::Mode::UnspecifiedMode => { + LoadBalancerMode::Failover + } + xds::istio::workload::load_balancing::Mode::Failover => LoadBalancerMode::Failover, + } + } +} + +#[derive(Debug, Eq, PartialEq, Clone, serde::Serialize, serde::Deserialize)] +pub enum LoadBalancerScopes { + Region, + Zone, + Subzone, + Node, + Cluster, + Network, +} + +impl TryFrom for LoadBalancerScopes { + type Error = WorkloadError; + fn try_from(value: XdsScope) -> Result { + match value { + XdsScope::Region => Ok(LoadBalancerScopes::Region), + XdsScope::Zone => 
Ok(LoadBalancerScopes::Zone), + XdsScope::Subzone => Ok(LoadBalancerScopes::Subzone), + XdsScope::Node => Ok(LoadBalancerScopes::Node), + XdsScope::Cluster => Ok(LoadBalancerScopes::Cluster), + XdsScope::Network => Ok(LoadBalancerScopes::Network), + _ => Err(WorkloadError::EnumParse("invalid target".to_string())), + } + } +} + +#[derive(Debug, Eq, PartialEq, Clone, serde::Serialize, serde::Deserialize)] +pub struct LoadBalancer { + pub routing_preferences: Vec, + pub mode: LoadBalancerMode, } impl Service { @@ -116,6 +168,22 @@ impl TryFrom<&XdsService> for Service { Some(w) => Some(GatewayAddress::try_from(w)?), None => None, }; + let lb = if let Some(lb) = &s.load_balancing { + Some(LoadBalancer { + routing_preferences: lb + .routing_preference + .iter() + .map(|r| { + xds::istio::workload::load_balancing::Scope::try_from(*r) + .map_err(WorkloadError::DecodeError) + .and_then(|r| r.try_into()) + }) + .collect::, WorkloadError>>()?, + mode: xds::istio::workload::load_balancing::Mode::try_from(lb.mode)?.into(), + }) + } else { + None + }; let svc = Service { name: s.name.to_string(), namespace: s.namespace.to_string(), @@ -128,6 +196,7 @@ impl TryFrom<&XdsService> for Service { endpoints: Default::default(), // Will be populated once inserted into the store. 
subject_alt_names: s.subject_alt_names.clone(), waypoint, + load_balancer: lb, }; Ok(svc) } diff --git a/src/state/workload.rs b/src/state/workload.rs index 034d8a45b..ff9ede71e 100644 --- a/src/state/workload.rs +++ b/src/state/workload.rs @@ -63,6 +63,23 @@ pub enum HealthStatus { Unhealthy, } +#[derive(Default, Debug, Hash, Eq, PartialEq, Clone, serde::Serialize, serde::Deserialize)] +pub struct Locality { + pub region: String, + pub zone: String, + pub subzone: String, +} + +impl From for Locality { + fn from(value: xds::istio::workload::Locality) -> Self { + Locality { + region: value.region, + zone: value.zone, + subzone: value.subzone, + } + } +} + impl From for HealthStatus { fn from(value: xds::istio::workload::WorkloadStatus) -> Self { match value { @@ -184,6 +201,9 @@ pub struct Workload { #[serde(default)] pub cluster_id: String, + + #[serde(default)] + pub locality: Locality, } fn is_default(t: &T) -> bool { @@ -395,6 +415,8 @@ impl TryFrom<&XdsWorkload> for Workload { authorization_policies: resource.authorization_policies, + locality: resource.locality.map(Locality::from).unwrap_or_default(), + cluster_id: { let result = resource.cluster_id; if result.is_empty() { @@ -562,7 +584,7 @@ pub struct WorkloadStore { } impl WorkloadStore { - pub fn insert(&mut self, w: Workload) -> anyhow::Result<()> { + pub fn insert(&mut self, w: Workload) { // First, remove the entry entirely to make sure things are cleaned up properly. 
self.remove(w.uid.as_str()); @@ -575,7 +597,6 @@ impl WorkloadStore { self.by_hostname.insert(w.hostname.clone(), w.clone()); } self.by_uid.insert(w.uid.clone(), w.clone()); - Ok(()) } pub fn remove(&mut self, uid: &str) -> Option { @@ -855,6 +876,7 @@ mod tests { }], subject_alt_names: vec![], waypoint: None, + load_balancing: None, }, ) .unwrap(); @@ -886,6 +908,7 @@ mod tests { }], subject_alt_names: vec![], waypoint: None, + load_balancing: None, }, ) .unwrap(); @@ -940,6 +963,7 @@ mod tests { }], subject_alt_names: vec![], waypoint: None, + load_balancing: None, }, ) .unwrap(); @@ -1171,12 +1195,19 @@ mod tests { let mut found: HashSet = HashSet::new(); // VIP has randomness. We will try to fetch the VIP 1k times and assert the we got the expected results // at least once, and no unexpected results + let wl: Workload = (&XdsWorkload { + name: "some name".to_string(), + ..Default::default() + }) + .try_into() + .unwrap(); for _ in 0..1000 { - if let Some(us) = state - .state - .read() - .unwrap() - .find_upstream("", "127.0.1.1:80".parse().unwrap()) + if let Some(us) = + state + .state + .read() + .unwrap() + .find_upstream("", &wl, "127.0.1.1:80".parse().unwrap()) { let n = &us.workload.name; // borrow name instead of cloning found.insert(n.to_owned()); // insert an owned copy of the borrowed n @@ -1219,12 +1250,12 @@ mod tests { .find_address(&network_addr("", "127.0.0.1".parse().unwrap())); // Make sure we get a valid workload assert!(wl.is_some()); - assert_eq!(wl.unwrap().service_account, "default"); - let us = demand - .state - .read() - .unwrap() - .find_upstream("", "127.10.0.1:80".parse().unwrap()); + assert_eq!(wl.as_ref().unwrap().service_account, "default"); + let us = demand.state.read().unwrap().find_upstream( + "", + wl.as_ref().unwrap(), + "127.10.0.1:80".parse().unwrap(), + ); // Make sure we get a valid VIP assert!(us.is_some()); assert_eq!(us.clone().unwrap().port, 8080); @@ -1234,11 +1265,11 @@ mod tests { ); // test that we can have a 
service in another network than workloads it selects - let us = demand - .state - .read() - .unwrap() - .find_upstream("remote", "127.10.0.2:80".parse().unwrap()); + let us = demand.state.read().unwrap().find_upstream( + "remote", + wl.as_ref().unwrap(), + "127.10.0.2:80".parse().unwrap(), + ); // Make sure we get a valid VIP assert!(us.is_some()); assert_eq!(us.unwrap().port, 8080); diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 156968c62..86a21515b 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -187,6 +187,7 @@ pub fn mock_default_service() -> Service { endpoints, subject_alt_names: vec![], waypoint: None, + load_balancer: None, } } @@ -215,6 +216,7 @@ pub fn test_default_workload() -> Workload { authorization_policies: Vec::new(), native_tunnel: false, application_tunnel: None, + locality: Default::default(), } } @@ -291,6 +293,7 @@ fn test_custom_svc( )]), subject_alt_names: vec!["spiffe://cluster.local/ns/default/sa/default".to_string()], waypoint: None, + load_balancer: None, }) } diff --git a/src/test_helpers/linux.rs b/src/test_helpers/linux.rs index 0c5449bdf..69c06cd06 100644 --- a/src/test_helpers/linux.rs +++ b/src/test_helpers/linux.rs @@ -212,6 +212,7 @@ impl<'a> TestServiceBuilder<'a> { endpoints: Default::default(), // populated later when workloads are added subject_alt_names: vec![], waypoint: None, + load_balancer: None, }, manager, } diff --git a/src/xds.rs b/src/xds.rs index 76b4b57dd..694dc8466 100644 --- a/src/xds.rs +++ b/src/xds.rs @@ -140,7 +140,7 @@ impl ProxyStateUpdateMutator { self.cert_fetcher.prefetch_cert(&workload); // Lock and upstate the stores. 
- state.workloads.insert(workload)?; + state.workloads.insert(workload); while let Some(endpoint) = endpoints.pop() { state.services.insert_endpoint(endpoint); } @@ -396,7 +396,7 @@ impl LocalClient { let num_policies = r.policies.len(); for wl in r.workloads { trace!("inserting local workload {}", &wl.workload.uid); - state.workloads.insert(wl.workload.clone())?; + state.workloads.insert(wl.workload.clone()); self.cert_fetcher.prefetch_cert(&wl.workload); let services: HashMap = wl