impress_sig_mesh_hacs/custom_components/sigmesh_gateway/provisioning.py
impressionyang d21e7f1b3f feat: 添加配网和分组管理功能
新增功能:
- 配网管理模块 (provisioning.py): 支持设备扫描、配网、超时处理
- 配网配置步骤: UI 配置流程增加配网参数配置(Network Key, App Key 等)
- 分组管理:支持 SIG 分组和 VENDOR 分组的加入/删除操作
- HA 服务调用:7 个配网和分组相关的服务

文件变更:
- const.py: 添加配网相关常量(CONF_NETWORK_KEY, PROV_TIMEOUT 等)
- config_flow.py: 增加 prov_config 配置步骤和 OptionsFlow 菜单
- provisioning.py: 新建配网管理器(ProvisioningManager 类)
- coordinator.py: 集成配网管理器,添加配网状态管理方法
- services.py: 新建服务定义和注册
- services.yaml: HA 服务定义文件
- __init__.py: 集成服务注册和卸载
- PRD.md: 更新服务调用接口和配置参数文档

配网功能说明:
- 首次使用需配置 Network Key, App Key, Network ID, IV Index
- 配网超时时间:180 秒
- 组地址范围:0xC000 - 0xCFFF
- 支持 SIG 标准分组和 VENDOR 自定义分组
2026-04-16 12:05:13 +08:00

473 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""SigMesh Gateway 配网管理模块."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Callable
from .const import (
DEFAULT_APP_KEY,
DEFAULT_GROUP_ADDRESS_START,
DEFAULT_NETWORK_KEY,
DEFAULT_NETWORK_ID,
MeshSigOp,
PROV_TIMEOUT,
)
from .serial_reader import SerialReader
# Module-level logger, named after this module (standard logging pattern).
_LOGGER = logging.getLogger(__name__)
class ProvState(Enum):
    """Lifecycle states of a provisioning session.

    Declaration order (also the iteration order):
    IDLE, SCANNING, PROV_STARTING, PROV_IN_PROGRESS, PROV_COMPLETED,
    PROV_FAILED, TIMEOUT.
    """

    # Nothing in progress.
    IDLE = "idle"
    # Waiting for devices to be reported.
    SCANNING = "scanning"
    # The provisioning start command is being sent.
    PROV_STARTING = "prov_starting"
    # Start command sent; waiting for the device to join.
    PROV_IN_PROGRESS = "prov_in_progress"
    # A device joined while provisioning was in progress.
    PROV_COMPLETED = "prov_completed"
    # Sending the provisioning start command failed.
    PROV_FAILED = "prov_failed"
    # No device joined before the provisioning timer expired.
    TIMEOUT = "timeout"
@dataclass()
class ProvDevice:
    """A device reported to the provisioning manager.

    Only ``mac_address`` and ``element_count`` are required; the remaining
    attributes default to ``None`` until known.
    """

    # Device MAC address; used as the key in the manager's device table.
    mac_address: str
    # Number of elements the device reported when it joined.
    element_count: int
    # Unicast address, if assigned.
    unicast_address: int | None = None
    # Model identifier, if known.
    model_id: int | None = None
    # Local timestamp of when the join was handled, if any.
    joined_at: datetime | None = None
@dataclass()
class GroupConfig:
    """One recorded group subscription: which element/model joined which group."""

    # Group address the model was subscribed to.
    group_address: int
    # Model identifier that was subscribed.
    model_id: int
    # Element address hosting the model.
    element_address: int
class ProvisioningManager:
    """Provisioning and group-subscription manager for the SigMesh gateway.

    Every command is written as a text line through
    ``SerialReader.write_command``.  Raw mesh frames are hex-encoded and sent
    as ``AT+MESH=TX,<HEX>``.  Outcomes are fed back in asynchronously through
    ``handle_device_joined`` / ``handle_device_left`` (presumably called by
    the component that parses gateway output -- confirm against caller).

    The async command methods log failures instead of raising them.
    """

    # Fixed prefix of every raw frame sent via AT+MESH=TX.
    # NOTE(review): the byte meanings are gateway-specific; confirm against
    # the gateway protocol document before changing.
    _TX_FRAME_HEADER = b"\xe8\xff\x00\x00\x00\x00\x02\x01"
    # Opcodes used to add / delete a group subscription.
    _OP_GROUP_ADD = b"\x80\x1b"
    _OP_GROUP_DELETE = b"\x80\x1d"
    # Trailer appended after the model ID; it differs for SIG and vendor models.
    _SIG_MODEL_TRAILER = b"\x00\x10"
    _VENDOR_MODEL_TRAILER = b"\x00\x00"
    # Last usable group address if const.py does not provide
    # DEFAULT_GROUP_ADDRESS_END (the integration documents 0xC000-0xCFFF).
    _FALLBACK_GROUP_ADDRESS_END = 0xCFFF

    def __init__(
        self,
        serial_reader: SerialReader,
        network_key: str = DEFAULT_NETWORK_KEY,
        app_key: str = DEFAULT_APP_KEY,
        network_id: str = DEFAULT_NETWORK_ID,
    ) -> None:
        """Initialise the manager.

        Args:
            serial_reader: Transport used to send commands to the gateway.
            network_key: Network key; stored on the instance only.
            app_key: App key; stored on the instance only.
            network_id: Network ID; stored on the instance only.
        """
        self.serial_reader = serial_reader
        self.network_key = network_key
        self.app_key = app_key
        self.network_id = network_id
        self._state = ProvState.IDLE
        self._devices: dict[str, ProvDevice] = {}
        self._group_configs: dict[int, list[GroupConfig]] = {}
        self._prov_timeout_handle: asyncio.TimerHandle | None = None
        self._scan_result: list[dict] | None = None
        # Strong references to fire-and-forget tasks; asyncio only keeps weak
        # references, so an unreferenced task may be collected mid-run.
        self._background_tasks: set[asyncio.Task] = set()
        # Optional observer callbacks (see set_callbacks).
        self._on_state_change_callback: Callable[[ProvState], None] | None = None
        self._on_device_found_callback: Callable[[ProvDevice], None] | None = None
        self._on_prov_complete_callback: Callable[[ProvDevice], None] | None = None

    @property
    def state(self) -> ProvState:
        """Current provisioning state."""
        return self._state

    @property
    def devices(self) -> dict[str, ProvDevice]:
        """Devices recorded by handle_device_joined, keyed by MAC address.

        The table is cleared by start_scanning.  The internal dict itself is
        returned, not a copy.
        """
        return self._devices

    def set_callbacks(
        self,
        on_state_change: Callable[[ProvState], None] | None = None,
        on_device_found: Callable[[ProvDevice], None] | None = None,
        on_prov_complete: Callable[[ProvDevice], None] | None = None,
    ) -> None:
        """Register observer callbacks, replacing all three previous ones.

        Args:
            on_state_change: Called with the new state on every transition.
            on_device_found: Called with each device passed to handle_device_joined.
            on_prov_complete: Called when a join completes an in-progress provisioning.
        """
        self._on_state_change_callback = on_state_change
        self._on_device_found_callback = on_device_found
        self._on_prov_complete_callback = on_prov_complete

    def _set_state(self, state: ProvState) -> None:
        """Switch to ``state``, log it and notify the state-change callback."""
        self._state = state
        _LOGGER.info("配网状态变更:%s", state.value)
        if self._on_state_change_callback:
            self._on_state_change_callback(state)

    def _start_prov_timeout(self) -> None:
        """(Re)arm the provisioning timer for PROV_TIMEOUT seconds.

        Must be called from code running inside the event loop.
        """
        self._cancel_prov_timeout()
        loop = asyncio.get_running_loop()
        self._prov_timeout_handle = loop.call_later(PROV_TIMEOUT, self._on_prov_timeout)

    def _on_prov_timeout(self) -> None:
        """Timer callback: run the async timeout handler as a tracked task."""
        self._prov_timeout_handle = None
        task = asyncio.create_task(self._handle_prov_timeout())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _handle_prov_timeout(self) -> None:
        """Mark the provisioning attempt as timed out and stop it."""
        # A device may have joined, or provisioning may have been stopped,
        # between the timer firing and this task running; keep that outcome.
        if self._state not in (ProvState.PROV_STARTING, ProvState.PROV_IN_PROGRESS):
            return
        _LOGGER.warning("配网超时(%d 秒)", PROV_TIMEOUT)
        self._set_state(ProvState.TIMEOUT)
        await self.stop_provisioning()

    def _cancel_prov_timeout(self) -> None:
        """Cancel the provisioning timer if it is armed."""
        if self._prov_timeout_handle:
            self._prov_timeout_handle.cancel()
            self._prov_timeout_handle = None

    async def start_scanning(self) -> None:
        """Enter SCANNING and reset the device table.

        No command is sent; devices are expected to be reported by the
        gateway.  Ignored (with a warning) unless the state is IDLE,
        PROV_COMPLETED or PROV_FAILED.
        """
        if self._state not in (ProvState.IDLE, ProvState.PROV_COMPLETED, ProvState.PROV_FAILED):
            _LOGGER.warning("无法开始扫描,当前状态:%s", self._state.value)
            return
        self._set_state(ProvState.SCANNING)
        self._devices = {}
        self._scan_result = []
        _LOGGER.info("开始扫描设备,等待设备上报...")

    async def start_provisioning(self, device_address: str) -> None:
        """Send the provisioning start command for one device and arm the timer.

        Only allowed while SCANNING.  On a write failure the state becomes
        PROV_FAILED; otherwise PROV_IN_PROGRESS until the device joins
        (handle_device_joined) or the timer expires.

        Args:
            device_address: Device address as a hex string.
        """
        if self._state != ProvState.SCANNING:
            _LOGGER.warning("无法开始配网,当前状态:%s", self._state.value)
            return
        self._set_state(ProvState.PROV_STARTING)
        try:
            # Command format: AT+PROV=START,<address>
            await self.serial_reader.write_command(f"AT+PROV=START,{device_address}")
            self._start_prov_timeout()
            self._set_state(ProvState.PROV_IN_PROGRESS)
        except Exception as e:
            _LOGGER.error("启动配网失败:%s", e)
            self._set_state(ProvState.PROV_FAILED)
            self._cancel_prov_timeout()

    async def stop_provisioning(self) -> None:
        """Cancel the timer, send AT+PROV=STOP (best effort) and return to IDLE."""
        self._cancel_prov_timeout()
        try:
            await self.serial_reader.write_command("AT+PROV=STOP")
        except Exception as e:
            _LOGGER.warning("停止配网命令失败:%s", e)
        self._set_state(ProvState.IDLE)

    async def bind_app_key(self, device_address: str, element_address: int) -> None:
        """Send the App Key bind command for one element.

        The success log is written once the command has been sent; no
        acknowledgement is awaited here.

        Args:
            device_address: Device address.
            element_address: Element address.
        """
        try:
            # Command format: AT+PROV=BIND,<device_address>,<element_address>
            cmd = f"AT+PROV=BIND,{device_address},{element_address}"
            await self.serial_reader.write_command(cmd)
            _LOGGER.info("绑定 App Key 完成:%s, 元素:%d", device_address, element_address)
        except Exception as e:
            _LOGGER.error("绑定 App Key 失败:%s", e)

    @staticmethod
    def _encode_address(address: str) -> bytes:
        """Encode a hex unicast address ("0102", "102" or "0x0102") as 2 big-endian bytes.

        Raises:
            ValueError: If ``address`` is not hexadecimal (including empty).
            OverflowError: If the value does not fit in 16 bits.
        """
        return int(address, 16).to_bytes(2, "big")

    def _build_tx_command(self, target_address: str, *parts: bytes) -> str:
        """Build ``AT+MESH=TX,<HEX>`` for header + target address + ``parts``."""
        frame = b"".join((self._TX_FRAME_HEADER, self._encode_address(target_address), *parts))
        return f"AT+MESH=TX,{frame.hex().upper()}"

    def _build_group_command(
        self,
        opcode: bytes,
        target_address: str,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool,
    ) -> str:
        """Build a group add/delete frame:
        header <target> <opcode> <element> <group> <model_id> <trailer>.
        """
        trailer = self._SIG_MODEL_TRAILER if is_sig else self._VENDOR_MODEL_TRAILER
        return self._build_tx_command(
            target_address,
            opcode,
            element_address.to_bytes(2, "big"),
            group_address.to_bytes(2, "big"),
            model_id.to_bytes(2, "big"),
            trailer,
        )

    async def add_to_group(
        self,
        target_address: str,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool = True,
    ) -> None:
        """Subscribe a model on a device element to a group and record it locally.

        The local record is only added after the command was written, and an
        identical record is never stored twice.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            group_address: Group address.
            model_id: Model ID.
            is_sig: True for a SIG model, False for a vendor model.
        """
        try:
            cmd = self._build_group_command(
                self._OP_GROUP_ADD, target_address, element_address, group_address, model_id, is_sig
            )
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "添加设备到组:%s, 元素:%d, 组地址:0x%04X, Model:0x%04X",
                target_address,
                element_address,
                group_address,
                model_id,
            )
            record = GroupConfig(
                group_address=group_address,
                model_id=model_id,
                element_address=element_address,
            )
            configs = self._group_configs.setdefault(group_address, [])
            if record not in configs:
                configs.append(record)
        except Exception as e:
            _LOGGER.error("添加设备到组失败:%s", e)

    async def remove_from_group(
        self,
        target_address: str,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool = True,
    ) -> None:
        """Unsubscribe a model on a device element from a group and drop the record.

        Only records matching both ``element_address`` and ``model_id`` are
        removed; the group entry is deleted once it has no records left.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            group_address: Group address.
            model_id: Model ID.
            is_sig: True for a SIG model, False for a vendor model.
        """
        try:
            cmd = self._build_group_command(
                self._OP_GROUP_DELETE, target_address, element_address, group_address, model_id, is_sig
            )
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "从组中移除设备:%s, 元素:%d, 组地址:0x%04X, Model:0x%04X",
                target_address,
                element_address,
                group_address,
                model_id,
            )
            configs = self._group_configs.get(group_address)
            if configs is not None:
                remaining = [
                    cfg
                    for cfg in configs
                    if not (cfg.element_address == element_address and cfg.model_id == model_id)
                ]
                if remaining:
                    self._group_configs[group_address] = remaining
                else:
                    del self._group_configs[group_address]
        except Exception as e:
            _LOGGER.error("从组中移除设备失败:%s", e)

    def get_next_group_address(self) -> int:
        """Return the lowest group address with no local records.

        Raises:
            RuntimeError: If every address in the range already has records.
        """
        # The module-level import list does not include
        # DEFAULT_GROUP_ADDRESS_END, so resolve it here rather than failing
        # with NameError.
        try:
            from .const import DEFAULT_GROUP_ADDRESS_END as group_end
        except ImportError:
            group_end = self._FALLBACK_GROUP_ADDRESS_END
        used_addresses = self._group_configs.keys()
        for addr in range(DEFAULT_GROUP_ADDRESS_START, group_end + 1):
            if addr not in used_addresses:
                return addr
        raise RuntimeError("组地址已用尽")

    def get_group_config(self, group_address: int) -> list[GroupConfig] | None:
        """Return the recorded subscriptions for a group.

        Args:
            group_address: Group address.

        Returns:
            The recorded list, or None if the group has no records.
        """
        return self._group_configs.get(group_address)

    async def send_vendor_command(
        self,
        target_address: str,
        element_address: int,
        opcode: int,
        payload: bytes,
    ) -> None:
        """Send a raw frame: header <target> <opcode:2> <element:2> <payload>.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            opcode: Opcode (must fit in 2 bytes).
            payload: Data payload appended verbatim.
        """
        try:
            cmd = self._build_tx_command(
                target_address,
                opcode.to_bytes(2, "big"),
                element_address.to_bytes(2, "big"),
                payload,
            )
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "发送 VENDOR 命令:目标=%s, 元素=%d, Opcode=0x%04X",
                target_address,
                element_address,
                opcode,
            )
        except Exception as e:
            _LOGGER.error("发送 VENDOR 命令失败:%s", e)

    async def handle_device_joined(self, mac_address: str, element_count: int) -> None:
        """Record a joined device and notify callbacks.

        If provisioning was in progress, the timer is cancelled and the state
        moves to PROV_COMPLETED.

        Args:
            mac_address: Device MAC address.
            element_count: Number of elements.
        """
        device = ProvDevice(
            mac_address=mac_address,
            element_count=element_count,
            joined_at=datetime.now(),
        )
        self._devices[mac_address] = device
        _LOGGER.info("设备加入:%s, 元素数量:%d", mac_address, element_count)
        if self._on_device_found_callback:
            self._on_device_found_callback(device)
        if self._state == ProvState.PROV_IN_PROGRESS:
            self._cancel_prov_timeout()
            self._set_state(ProvState.PROV_COMPLETED)
            if self._on_prov_complete_callback:
                self._on_prov_complete_callback(device)

    def handle_device_left(self, mac_address: str) -> None:
        """Forget a device that left; unknown MAC addresses are ignored.

        Args:
            mac_address: Device MAC address.
        """
        if self._devices.pop(mac_address, None) is not None:
            _LOGGER.info("设备离开:%s", mac_address)

    def get_device(self, mac_address: str) -> ProvDevice | None:
        """Return the recorded device for a MAC address, or None if unknown.

        Args:
            mac_address: Device MAC address.
        """
        return self._devices.get(mac_address)