"""SigMesh Gateway 配网管理模块."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Callable
from .const import (
DEFAULT_APP_KEY,
DEFAULT_GROUP_ADDRESS_START,
DEFAULT_NETWORK_KEY,
DEFAULT_NETWORK_ID,
MeshSigOp,
PROV_TIMEOUT,
)
from .serial_reader import SerialReader
# Logger for this module, named after the module's import path.
_LOGGER = logging.getLogger(name=__name__)
class ProvState(Enum):
    """States of the provisioning flow.

    Each member's value is its lowercase name.
    """

    # Nothing in progress.
    IDLE = "idle"
    # Scanning for devices.
    SCANNING = "scanning"
    # Provisioning is being started.
    PROV_STARTING = "prov_starting"
    # Provisioning is underway.
    PROV_IN_PROGRESS = "prov_in_progress"
    # Provisioning finished successfully.
    PROV_COMPLETED = "prov_completed"
    # Provisioning failed.
    PROV_FAILED = "prov_failed"
    # Provisioning timed out.
    TIMEOUT = "timeout"
@dataclass
class ProvDevice:
    """Information about a provisioned device.

    Only the MAC address and element count are required; every other field
    defaults to None.
    """

    # Device MAC address.
    mac_address: str
    # Number of elements reported for the device.
    element_count: int
    # Unicast address, if known.
    unicast_address: int | None = None
    # Model ID, if known.
    model_id: int | None = None
    # Time at which the join was recorded, if known.
    joined_at: datetime | None = None
@dataclass
class GroupConfig:
    """One group configuration entry: a model on an element bound to a group address."""

    # Group address the entry belongs to.
    group_address: int
    # Model ID that was grouped.
    model_id: int
    # Element address the model lives on.
    element_address: int
class ProvisioningManager:
    """Provisioning manager for the SigMesh gateway.

    Drives the provisioning flow by sending AT commands over a serial link
    (scan -> start -> completion or timeout), tracks devices reported as
    joined, and keeps a local record of the group configurations it has sent.
    """

    # Fixed prefix shared by every AT+MESH=TX frame built by this class.
    _FRAME_HEADER = b"\xe8\xff\x00\x00\x00\x00\x02\x01"
    # Opcode bytes used for "add to group" frames.
    _OP_GROUP_ADD = b"\x80\x1b"
    # Opcode bytes used for "remove from group" frames.
    # NOTE(review): in the Bluetooth Mesh Profile, 0x801C is Config Model
    # Subscription Delete and 0x801D is Subscription Delete All. Confirm the
    # gateway really expects 0x801D here.
    _OP_GROUP_DELETE = b"\x80\x1d"
    # Fallback upper bound for group allocation when .const does not define
    # DEFAULT_GROUP_ADDRESS_END. 0xFF00-0xFFFF are the Mesh fixed/reserved
    # group addresses, so 0xFEFF is the last freely assignable one.
    _GROUP_ADDRESS_LAST = 0xFEFF

    def __init__(
        self,
        serial_reader: SerialReader,
        network_key: str = DEFAULT_NETWORK_KEY,
        app_key: str = DEFAULT_APP_KEY,
        network_id: str = DEFAULT_NETWORK_ID,
    ) -> None:
        """Initialise the provisioning manager.

        Args:
            serial_reader: Serial link used to send AT commands to the gateway.
            network_key: Network key; stored on the instance (not referenced
                by this class's own methods).
            app_key: Application key; stored on the instance (not referenced
                by this class's own methods).
            network_id: Network identifier; stored on the instance (not
                referenced by this class's own methods).
        """
        self.serial_reader = serial_reader
        self.network_key = network_key
        self.app_key = app_key
        self.network_id = network_id
        self._state = ProvState.IDLE
        # Devices reported as joined, keyed by MAC address.
        self._devices: dict[str, ProvDevice] = {}
        # Group configurations sent by this manager, keyed by group address.
        self._group_configs: dict[int, list[GroupConfig]] = {}
        self._prov_timeout_handle: asyncio.TimerHandle | None = None
        # Strong references to fire-and-forget tasks (the timeout handler).
        # The event loop keeps only weak references to tasks, so an
        # unreferenced task can be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()
        self._scan_result: list[dict] | None = None
        # Callbacks
        self._on_state_change_callback: Callable[[ProvState], None] | None = None
        self._on_device_found_callback: Callable[[ProvDevice], None] | None = None
        self._on_prov_complete_callback: Callable[[ProvDevice], None] | None = None

    @property
    def state(self) -> ProvState:
        """Current provisioning state."""
        return self._state

    @property
    def devices(self) -> dict[str, ProvDevice]:
        """Devices reported as joined, keyed by MAC address (the live internal dict)."""
        return self._devices

    def set_callbacks(
        self,
        on_state_change: Callable[[ProvState], None] | None = None,
        on_device_found: Callable[[ProvDevice], None] | None = None,
        on_prov_complete: Callable[[ProvDevice], None] | None = None,
    ) -> None:
        """Register callbacks; any argument left as None clears that callback."""
        self._on_state_change_callback = on_state_change
        self._on_device_found_callback = on_device_found
        self._on_prov_complete_callback = on_prov_complete

    def _set_state(self, state: ProvState) -> None:
        """Set the state, log it, and notify the state-change callback if set."""
        self._state = state
        _LOGGER.info("配网状态变更:%s", state.value)
        if self._on_state_change_callback:
            self._on_state_change_callback(state)

    def _start_prov_timeout(self) -> None:
        """(Re)arm the provisioning timeout timer.

        After PROV_TIMEOUT seconds the state is set to TIMEOUT and
        stop_provisioning() runs, which then sets the state to IDLE.
        Must be called while the event loop is running.
        """
        self._cancel_prov_timeout()

        async def timeout_handler() -> None:
            _LOGGER.warning("配网超时(%d 秒)", PROV_TIMEOUT)
            self._set_state(ProvState.TIMEOUT)
            await self.stop_provisioning()

        def on_timeout() -> None:
            # Keep a reference until the task finishes so it cannot be
            # garbage-collected mid-run. The task is deliberately NOT cancelled
            # by _cancel_prov_timeout(): it calls stop_provisioning(), which
            # itself calls _cancel_prov_timeout().
            task = asyncio.create_task(timeout_handler())
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

        loop = asyncio.get_running_loop()
        self._prov_timeout_handle = loop.call_later(PROV_TIMEOUT, on_timeout)

    def _cancel_prov_timeout(self) -> None:
        """Cancel the pending provisioning timeout timer, if any."""
        if self._prov_timeout_handle:
            self._prov_timeout_handle.cancel()
            self._prov_timeout_handle = None

    async def start_scanning(self) -> None:
        """Start scanning for devices by sending ``AT+PROV=SCAN``.

        Allowed only from IDLE, PROV_COMPLETED or PROV_FAILED. If a scan is
        already running, the device list and scan result are cleared and the
        method returns. A send failure moves the state to PROV_FAILED.
        """
        if self._state == ProvState.SCANNING:
            # NOTE(review): this is meant to let the user restart a scan, but
            # AT+PROV=SCAN is not re-sent here. Confirm that is intended.
            _LOGGER.info("扫描已在进行中,重置扫描状态...")
            self._devices = {}
            self._scan_result = []
            return
        if self._state not in (ProvState.IDLE, ProvState.PROV_COMPLETED, ProvState.PROV_FAILED):
            _LOGGER.warning("无法开始扫描,当前状态:%s", self._state.value)
            return
        self._set_state(ProvState.SCANNING)
        self._devices = {}
        self._scan_result = []
        # The gateway scans for nearby devices and reports responses over serial.
        try:
            await self.serial_reader.write_command("AT+PROV=SCAN")
            _LOGGER.info("已发送扫描命令,等待设备响应...")
        except Exception as e:
            _LOGGER.error("发送扫描命令失败:%s", e)
            self._set_state(ProvState.PROV_FAILED)

    async def start_provisioning(self, device_address: str) -> None:
        """Start provisioning the given device (only valid while SCANNING).

        Sends ``AT+PROV=START,<device_address>``, arms the timeout, and moves
        to PROV_IN_PROGRESS. Completion is signalled by handle_device_joined().

        Args:
            device_address: Device address (hex string).
        """
        if self._state != ProvState.SCANNING:
            _LOGGER.warning("无法开始配网,当前状态:%s", self._state.value)
            return
        self._set_state(ProvState.PROV_STARTING)
        try:
            cmd = f"AT+PROV=START,{device_address}"
            await self.serial_reader.write_command(cmd)
            self._start_prov_timeout()
            self._set_state(ProvState.PROV_IN_PROGRESS)
        except Exception as e:
            _LOGGER.error("启动配网失败:%s", e)
            self._set_state(ProvState.PROV_FAILED)
            self._cancel_prov_timeout()

    async def stop_provisioning(self) -> None:
        """Cancel the timeout, send ``AT+PROV=STOP`` and return to IDLE.

        A send failure is logged as a warning; the state becomes IDLE regardless.
        """
        self._cancel_prov_timeout()
        try:
            cmd = "AT+PROV=STOP"
            await self.serial_reader.write_command(cmd)
        except Exception as e:
            _LOGGER.warning("停止配网命令失败:%s", e)
        self._set_state(ProvState.IDLE)

    async def bind_app_key(self, device_address: str, element_address: int) -> None:
        """Send ``AT+PROV=BIND,<device_address>,<element_address>``.

        Errors are logged, not raised. The success log is emitted once the
        command has been written; no acknowledgement is awaited here.

        Args:
            device_address: Device address.
            element_address: Element address.
        """
        try:
            cmd = f"AT+PROV=BIND,{device_address},{element_address}"
            await self.serial_reader.write_command(cmd)
            _LOGGER.info("绑定 App Key 完成:%s, 元素:%d", device_address, element_address)
        except Exception as e:
            _LOGGER.error("绑定 App Key 失败:%s", e)

    def _build_group_command(
        self,
        target_address: str,
        opcode: bytes,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool,
    ) -> str:
        """Build the ``AT+MESH=TX,<HEX>`` command for a group add/remove frame.

        Frame layout: header + target (hex, zero-padded to 4 digits) + opcode
        + element, group and model as 2-byte big-endian values + a 2-byte
        trailer (``00 10`` for SIG, ``00 00`` for VENDOR).

        Raises:
            ValueError: if ``target_address`` is not valid hex.
            OverflowError: if an integer field does not fit in 2 bytes.
        """
        trailer = b"\x00\x10" if is_sig else b"\x00\x00"
        frame = (
            self._FRAME_HEADER
            + bytes.fromhex(target_address.zfill(4))
            + opcode
            + element_address.to_bytes(2, "big")
            + group_address.to_bytes(2, "big")
            + model_id.to_bytes(2, "big")
            + trailer
        )
        return f"AT+MESH=TX,{frame.hex().upper()}"

    async def add_to_group(
        self,
        target_address: str,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool = True,
    ) -> None:
        """Add a device model to a group and record the configuration locally.

        Errors (bad input or send failure) are logged, not raised; nothing is
        recorded when an error occurs.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            group_address: Group address.
            model_id: Model ID.
            is_sig: True for SIG-standard grouping, False for VENDOR.
        """
        try:
            cmd = self._build_group_command(
                target_address, self._OP_GROUP_ADD, element_address, group_address, model_id, is_sig
            )
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "添加设备到组:%s, 元素:%d, 组地址:0x%04X, Model:0x%04X",
                target_address,
                element_address,
                group_address,
                model_id,
            )
            # Record the configuration, skipping an identical entry so repeated
            # calls do not accumulate duplicates.
            config = GroupConfig(
                group_address=group_address,
                model_id=model_id,
                element_address=element_address,
            )
            configs = self._group_configs.setdefault(group_address, [])
            if config not in configs:
                configs.append(config)
        except Exception as e:
            _LOGGER.error("添加设备到组失败:%s", e)

    async def remove_from_group(
        self,
        target_address: str,
        element_address: int,
        group_address: int,
        model_id: int,
        is_sig: bool = True,
    ) -> None:
        """Remove a device model from a group and drop its local record.

        Only the entry whose element address AND model ID both match is
        removed; the group key is deleted once its list is empty. Errors are
        logged, not raised.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            group_address: Group address.
            model_id: Model ID.
            is_sig: True for SIG-standard grouping, False for VENDOR.
        """
        try:
            cmd = self._build_group_command(
                target_address, self._OP_GROUP_DELETE, element_address, group_address, model_id, is_sig
            )
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "从组中移除设备:%s, 元素:%d, 组地址:0x%04X, Model:0x%04X",
                target_address,
                element_address,
                group_address,
                model_id,
            )
            configs = self._group_configs.get(group_address)
            if configs is not None:
                # Drop only the exact (element, model) pair. The previous
                # `!= and !=` filter also removed entries sharing just one field.
                remaining = [
                    cfg
                    for cfg in configs
                    if not (cfg.element_address == element_address and cfg.model_id == model_id)
                ]
                if remaining:
                    self._group_configs[group_address] = remaining
                else:
                    del self._group_configs[group_address]
        except Exception as e:
            _LOGGER.error("从组中移除设备失败:%s", e)

    def get_next_group_address(self) -> int:
        """Return the lowest group address not yet recorded by this manager.

        Searches DEFAULT_GROUP_ADDRESS_START up to the end address (inclusive).
        The end address is DEFAULT_GROUP_ADDRESS_END from .const when defined
        there, otherwise 0xFEFF.

        Raises:
            RuntimeError: if every address in the range is in use.
        """
        # DEFAULT_GROUP_ADDRESS_END was referenced but never imported, which
        # raised NameError. Read it from .const when it exists.
        from . import const as _const

        end = getattr(_const, "DEFAULT_GROUP_ADDRESS_END", self._GROUP_ADDRESS_LAST)
        for addr in range(DEFAULT_GROUP_ADDRESS_START, end + 1):
            if addr not in self._group_configs:
                return addr
        raise RuntimeError("组地址已用尽")

    def get_group_config(self, group_address: int) -> list[GroupConfig] | None:
        """Return the recorded configurations for a group address.

        Args:
            group_address: Group address.

        Returns:
            The list of configurations, or None if the group is not recorded.
        """
        return self._group_configs.get(group_address)

    async def send_vendor_command(
        self,
        target_address: str,
        element_address: int,
        opcode: int,
        payload: bytes,
    ) -> None:
        """Send a VENDOR command frame via ``AT+MESH=TX``.

        Frame layout: header + target (hex, zero-padded to 4 digits) + opcode
        (2-byte big-endian) + element (2-byte big-endian) + payload. Errors are
        logged, not raised.

        Args:
            target_address: Target device address (hex string).
            element_address: Element address.
            opcode: Opcode.
            payload: Payload bytes.
        """
        try:
            target_bytes = bytes.fromhex(target_address.zfill(4))
            element_bytes = element_address.to_bytes(2, "big")
            opcode_bytes = opcode.to_bytes(2, "big")
            cmd_frame = self._FRAME_HEADER + target_bytes + opcode_bytes + element_bytes + payload
            cmd = f"AT+MESH=TX,{cmd_frame.hex().upper()}"
            await self.serial_reader.write_command(cmd)
            _LOGGER.info(
                "发送 VENDOR 命令:目标=%s, 元素=%d, Opcode=0x%04X",
                target_address,
                element_address,
                opcode,
            )
        except Exception as e:
            _LOGGER.error("发送 VENDOR 命令失败:%s", e)

    async def handle_device_joined(self, mac_address: str, element_count: int) -> None:
        """Record a joined device and notify callbacks.

        If provisioning is in progress, the timeout is cancelled, the state
        becomes PROV_COMPLETED and the completion callback is invoked.

        Args:
            mac_address: Device MAC address.
            element_count: Number of elements.
        """
        device = ProvDevice(
            mac_address=mac_address,
            element_count=element_count,
            joined_at=datetime.now(),
        )
        self._devices[mac_address] = device
        _LOGGER.info("设备加入:%s, 元素数量:%d", mac_address, element_count)
        if self._on_device_found_callback:
            self._on_device_found_callback(device)
        if self._state == ProvState.PROV_IN_PROGRESS:
            self._cancel_prov_timeout()
            self._set_state(ProvState.PROV_COMPLETED)
            if self._on_prov_complete_callback:
                self._on_prov_complete_callback(device)

    def handle_device_left(self, mac_address: str) -> None:
        """Forget a device; logs only when the device was known.

        Args:
            mac_address: Device MAC address.
        """
        if self._devices.pop(mac_address, None) is not None:
            _LOGGER.info("设备离开:%s", mac_address)

    def get_device(self, mac_address: str) -> ProvDevice | None:
        """Return the recorded device for a MAC address.

        Args:
            mac_address: Device MAC address.

        Returns:
            The device, or None if not found.
        """
        return self._devices.get(mac_address)