diff --git a/homeassistant/components/group/entity.py b/homeassistant/components/group/entity.py index a8fd902798428..489226742ae16 100644 --- a/homeassistant/components/group/entity.py +++ b/homeassistant/components/group/entity.py @@ -8,7 +8,7 @@ import logging from typing import Any -from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON +from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON from homeassistant.core import ( CALLBACK_TYPE, Event, @@ -24,7 +24,7 @@ from homeassistant.helpers.event import async_track_state_change_event from .const import ATTR_AUTO, ATTR_ORDER, DOMAIN, GROUP_ORDER, REG_KEY -from .registry import GroupIntegrationRegistry +from .registry import GroupIntegrationRegistry, SingleStateType ENTITY_ID_FORMAT = DOMAIN + ".{}" @@ -133,6 +133,7 @@ class Group(Entity): _attr_should_poll = False tracking: tuple[str, ...] trackable: tuple[str, ...] + single_state_type_key: SingleStateType | None def __init__( self, @@ -153,7 +154,7 @@ def __init__( self._attr_name = name self._state: str | None = None self._attr_icon = icon - self._set_tracked(entity_ids) + self._entity_ids = entity_ids self._on_off: dict[str, bool] = {} self._assumed: dict[str, bool] = {} self._on_states: set[str] = set() @@ -287,6 +288,7 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None: if not entity_ids: self.tracking = () self.trackable = () + self.single_state_type_key = None return registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] @@ -294,16 +296,42 @@ def _set_tracked(self, entity_ids: Collection[str] | None) -> None: tracking: list[str] = [] trackable: list[str] = [] + single_state_type_set: set[SingleStateType] = set() for ent_id in entity_ids: ent_id_lower = ent_id.lower() domain = split_entity_id(ent_id_lower)[0] tracking.append(ent_id_lower) if domain not in excluded_domains: trackable.append(ent_id_lower) + if domain in registry.state_group_mapping: + single_state_type_set.add(registry.state_group_mapping[domain]) + elif domain == DOMAIN: + # If a group contains another group we check if that group + # has a specific single state type + if ent_id in registry.state_group_mapping: + single_state_type_set.add(registry.state_group_mapping[ent_id]) + else: + single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF)) + + if len(single_state_type_set) == 1: + self.single_state_type_key = next(iter(single_state_type_set)) + # To support groups with nested groups we store the state type + # per group entity_id if there is a single state type + registry.state_group_mapping[self.entity_id] = self.single_state_type_key + else: + self.single_state_type_key = None + self.async_on_remove(self._async_deregister) self.trackable = tuple(trackable) self.tracking = tuple(tracking) + @callback + def _async_deregister(self) -> None: + """Deregister group entity from the registry.""" + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + if self.entity_id in registry.state_group_mapping: + registry.state_group_mapping.pop(self.entity_id) + @callback def _async_start(self, _: HomeAssistant | None = None) -> None: """Start tracking members and write state.""" @@ -342,6 +370,7 @@ def async_update_group_state(self) -> None: async def async_added_to_hass(self) -> None: """Handle addition to Home Assistant.""" + self._set_tracked(self._entity_ids) self.async_on_remove(start.async_at_start(self.hass, self._async_start)) async def async_will_remove_from_hass(self) -> None: @@ -430,12 +459,14 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None: # have the same on state we use this state # and its hass.data[REG_KEY].on_off_mapping to off if num_on_states == 1: - on_state = list(self._on_states)[0] + on_state = next(iter(self._on_states)) # If we do not have an on state for any domains # we use None (which will be STATE_UNKNOWN) elif num_on_states == 0: self._state = None return + if self.single_state_type_key: + on_state = self.single_state_type_key.on_state # If the entity domains have more than one # on state, we use STATE_ON/STATE_OFF else: @@ -443,9 +474,10 @@ def _async_update_group_state(self, tr_state: State | None = None) -> None: group_is_on = self.mode(self._on_off.values()) if group_is_on: self._state = on_state + elif self.single_state_type_key: + self._state = self.single_state_type_key.off_state else: - registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] - self._state = registry.on_off_mapping[on_state] + self._state = STATE_OFF def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]: diff --git a/homeassistant/components/group/registry.py b/homeassistant/components/group/registry.py index 9ddf7c0b40923..4ce89a4c7256b 100644 --- a/homeassistant/components/group/registry.py +++ b/homeassistant/components/group/registry.py @@ -1,8 +1,12 @@ -"""Provide the functionality to group entities.""" +"""Provide the functionality to group entities. + +Legacy group support will not be extended for new domains. +""" from __future__ import annotations -from typing import TYPE_CHECKING, Protocol +from dataclasses import dataclass +from typing import Protocol from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant, callback @@ -12,9 +16,6 @@ from .const import DOMAIN, REG_KEY -if TYPE_CHECKING: - from .entity import Group - async def async_setup(hass: HomeAssistant) -> None: """Set up the Group integration registry of integration platforms.""" @@ -43,6 +44,14 @@ def _process_group_platform( platform.async_describe_on_off_states(hass, registry) +@dataclass(frozen=True, slots=True) +class SingleStateType: + """Dataclass to store a single state type.""" + + on_state: str + off_state: str + + class GroupIntegrationRegistry: """Class to hold a registry of integrations.""" @@ -53,8 +62,7 @@ def __init__(self, hass: HomeAssistant) -> None: self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON} self.on_states_by_domain: dict[str, set[str]] = {} self.exclude_domains: set[str] = set() - self.state_group_mapping: dict[str, tuple[str, str]] = {} - self.group_entities: set[Group] = set() + self.state_group_mapping: dict[str, SingleStateType] = {} @callback def exclude_domain(self, domain: str) -> None: @@ -65,12 +73,16 @@ def exclude_domain(self, domain: str) -> None: def on_off_states( self, domain: str, on_states: set[str], default_on_state: str, off_state: str ) -> None: - """Register on and off states for the current domain.""" + """Register on and off states for the current domain. + + Legacy group support will not be extended for new domains. + """ for on_state in on_states: if on_state not in self.on_off_mapping: self.on_off_mapping[on_state] = off_state - if len(on_states) == 1 and off_state not in self.off_on_mapping: + if off_state not in self.off_on_mapping: self.off_on_mapping[off_state] = default_on_state + self.state_group_mapping[domain] = SingleStateType(default_on_state, off_state) self.on_states_by_domain[domain] = on_states diff --git a/tests/components/group/test_init.py b/tests/components/group/test_init.py index d3f2747933edd..9dbd1fe1f6e31 100644 --- a/tests/components/group/test_init.py +++ b/tests/components/group/test_init.py @@ -10,6 +10,7 @@ import pytest from homeassistant.components import group +from homeassistant.components.group.registry import GroupIntegrationRegistry from homeassistant.const import ( ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, @@ -33,7 +34,116 @@ from . import common -from tests.common import MockConfigEntry, assert_setup_component +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + assert_setup_component, + mock_integration, + mock_platform, +) + + +async def help_test_mixed_entity_platforms_on_off_state_test( + hass: HomeAssistant, + on_off_states1: tuple[set[str], str, str], + on_off_states2: tuple[set[str], str, str], + entity_and_state1_state_2: tuple[str, str | None, str | None], + group_state1: str, + group_state2: str, + grouped_groups: bool = False, +) -> None: + """Help test on_off_states on mixed entity platforms.""" + + class MockGroupPlatform1(MockPlatform): + """Mock a group platform module for test1 integration.""" + + def async_describe_on_off_states( + self, hass: HomeAssistant, registry: GroupIntegrationRegistry + ) -> None: + """Describe group on off states.""" + registry.on_off_states("test1", *on_off_states1) + + class MockGroupPlatform2(MockPlatform): + """Mock a group platform module for test2 integration.""" + + def async_describe_on_off_states( + self, hass: HomeAssistant, registry: GroupIntegrationRegistry + ) -> None: + """Describe group on off states.""" + registry.on_off_states("test2", *on_off_states2) + + mock_integration(hass, MockModule(domain="test1")) + mock_platform(hass, "test1.group", MockGroupPlatform1()) + assert await async_setup_component(hass, "test1", {"test1": {}}) + + mock_integration(hass, MockModule(domain="test2")) + mock_platform(hass, "test2.group", MockGroupPlatform2()) + assert await async_setup_component(hass, "test2", {"test2": {}}) + + if grouped_groups: + assert await async_setup_component( + hass, + "group", + { + "group": { + "test1": { + "entities": [ + item[0] + for item in entity_and_state1_state_2 + if item[0].startswith("test1.") + ] + }, + "test2": { + "entities": [ + item[0] + for item in entity_and_state1_state_2 + if item[0].startswith("test2.") + ] + }, + "test": {"entities": ["group.test1", "group.test2"]}, + } + }, + ) + else: + assert await async_setup_component( + hass, + "group", + { + "group": { + "test": { + "entities": [item[0] for item in entity_and_state1_state_2] + }, + } + }, + ) + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get("group.test") + assert state is not None + + # Set first state + for entity_id, state1, _ in entity_and_state1_state_2: + hass.states.async_set(entity_id, state1) + + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get("group.test") + assert state is not None + assert state.state == group_state1 + + # Set second state + for entity_id, _, state2 in entity_and_state1_state_2: + hass.states.async_set(entity_id, state2) + + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get("group.test") + assert state is not None + assert state.state == group_state2 async def test_setup_group_with_mixed_groupable_states(hass: HomeAssistant) -> None: @@ -1560,6 +1670,7 @@ async def test_group_that_references_a_group_of_covers(hass: HomeAssistant) -> N for entity_id in entity_ids: hass.states.async_set(entity_id, "closed") await hass.async_block_till_done() + assert await async_setup_component(hass, "cover", {}) assert await async_setup_component( hass, @@ -1643,6 +1754,7 @@ async def test_group_that_references_two_types_of_groups(hass: HomeAssistant) -> hass.states.async_set(entity_id, "home") await hass.async_block_till_done() + assert await async_setup_component(hass, "cover", {}) assert await async_setup_component(hass, "device_tracker", {}) assert await async_setup_component( hass, @@ -1884,3 +1996,216 @@ async def test_unhide_members_on_remove( # Check the group members are unhidden assert entity_registry.async_get(f"{group_type}.one").hidden_by == hidden_by assert entity_registry.async_get(f"{group_type}.three").hidden_by == hidden_by + + +@pytest.mark.parametrize("grouped_groups", [False, True]) +@pytest.mark.parametrize( + ("on_off_states1", "on_off_states2"), + [ + ( + ( + { + "on_beer", + "on_milk", + }, + "on_beer", # default ON state test1 + "off_water", # default OFF state test1 + ), + ( + { + "on_beer", + "on_milk", + }, + "on_milk", # default ON state test2 + "off_wine", # default OFF state test2 + ), + ), + ], +) +@pytest.mark.parametrize( + ("entity_and_state1_state_2", "group_state1", "group_state2"), + [ + # All OFF states, no change, so group stays OFF + ( + [ + ("test1.ent1", "off_water", "off_water"), + ("test1.ent2", "off_water", "off_water"), + ("test2.ent1", "off_wine", "off_wine"), + ("test2.ent2", "off_wine", "off_wine"), + ], + STATE_OFF, + STATE_OFF, + ), + # All entities have state on_milk, but the state groups + # are different so the group status defaults to ON / OFF + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_milk"), + ("test2.ent1", "off_wine", "on_milk"), + ("test2.ent2", "off_wine", "on_milk"), + ], + STATE_OFF, + STATE_ON, + ), + # Only test1 entities in group, all at ON state + # group returns the default ON state `on_beer` + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_beer"), + ], + "off_water", + "on_beer", + ), + # Only test1 entities in group, all at ON state + # group returns the default ON state `on_beer` + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_milk"), + ], + "off_water", + "on_beer", + ), + # Only test2 entities in group, all at ON state + # group returns the default ON state `on_milk` + ( + [ + ("test2.ent1", "off_wine", "on_milk"), + ("test2.ent2", "off_wine", "on_milk"), + ], + "off_wine", + "on_milk", + ), + ], +) +async def test_entity_platforms_with_multiple_on_states_no_state_match( + hass: HomeAssistant, + on_off_states1: tuple[set[str], str, str], + on_off_states2: tuple[set[str], str, str], + entity_and_state1_state_2: tuple[str, str | None, str | None], + group_state1: str, + group_state2: str, + grouped_groups: bool, +) -> None: + """Test custom entity platforms with multiple ON states without state match. + + The test group 1 an 2 non matching (default_state_on, state_off) pairs. + """ + await help_test_mixed_entity_platforms_on_off_state_test( + hass, + on_off_states1, + on_off_states2, + entity_and_state1_state_2, + group_state1, + group_state2, + grouped_groups, + ) + + +@pytest.mark.parametrize("grouped_groups", [False, True]) +@pytest.mark.parametrize( + ("on_off_states1", "on_off_states2"), + [ + ( + ( + { + "on_beer", + "on_milk", + }, + "on_beer", # default ON state test1 + "off_water", # default OFF state test1 + ), + ( + { + "on_beer", + "on_wine", + }, + "on_beer", # default ON state test2 + "off_water", # default OFF state test2 + ), + ), + ], +) +@pytest.mark.parametrize( + ("entity_and_state1_state_2", "group_state1", "group_state2"), + [ + # All OFF states, no change, so group stays OFF + ( + [ + ("test1.ent1", "off_water", "off_water"), + ("test1.ent2", "off_water", "off_water"), + ("test2.ent1", "off_water", "off_water"), + ("test2.ent2", "off_water", "off_water"), + ], + "off_water", + "off_water", + ), + # All entities have ON state `on_milk` + # but the group state will default to on_beer + # which is the default ON state for both integrations. + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_milk"), + ("test2.ent1", "off_water", "on_milk"), + ("test2.ent2", "off_water", "on_milk"), + ], + "off_water", + "on_beer", + ), + # Only test1 entities in group, all at ON state + # group returns the default ON state `on_beer` + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_beer"), + ], + "off_water", + "on_beer", + ), + # Only test1 entities in group, all at ON state + # group returns the default ON state `on_beer` + ( + [ + ("test1.ent1", "off_water", "on_milk"), + ("test1.ent2", "off_water", "on_milk"), + ], + "off_water", + "on_beer", + ), + # Only test2 entities in group, all at ON state + # group returns the default ON state `on_milk` + ( + [ + ("test2.ent1", "off_water", "on_wine"), + ("test2.ent2", "off_water", "on_wine"), + ], + "off_water", + "on_beer", + ), + ], +) +async def test_entity_platforms_with_multiple_on_states_with_state_match( + hass: HomeAssistant, + on_off_states1: tuple[set[str], str, str], + on_off_states2: tuple[set[str], str, str], + entity_and_state1_state_2: tuple[str, str | None, str | None], + group_state1: str, + group_state2: str, + grouped_groups: bool, +) -> None: + """Test custom entity platforms with multiple ON states with a state match. + + The integrations test1 and test2 have matching (default_state_on, state_off) pairs. + """ + await help_test_mixed_entity_platforms_on_off_state_test( + hass, + on_off_states1, + on_off_states2, + entity_and_state1_state_2, + group_state1, + group_state2, + grouped_groups, + )