diff --git a/.github/prompts/fix-lint.md b/.github/prompts/fix-lint.md new file mode 100644 index 0000000..4ea0636 --- /dev/null +++ b/.github/prompts/fix-lint.md @@ -0,0 +1 @@ +make lint を実行して、検出されたエラーを修正して diff --git a/Makefile b/Makefile index 89a9ae3..318a6f4 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ lint: uv run ruff check stackchan_server example_apps uv run ty check stackchan_server example_apps + +lint-fix: + uv run ruff check --fix stackchan_server example_apps + uv run ty check stackchan_server example_apps diff --git a/docs/protocols.md b/docs/protocols.md index 8994325..fd6573a 100644 --- a/docs/protocols.md +++ b/docs/protocols.md @@ -18,6 +18,8 @@ enum class MessageKind : uint8_t { WakeWordEvt = 4, // クライアント→サーバ(wake word 検知通知) StateEvt = 5, // クライアント→サーバ(現在状態通知) SpeakDoneEvt = 6, // クライアント→サーバ(発話完了通知) + ServoCmd = 7, // サーバ→クライアント(サーボ動作シーケンス指示) + ServoDoneEvt = 8, // クライアント→サーバ(サーボ動作シーケンス完了通知) }; enum class MessageType : uint8_t { @@ -92,6 +94,26 @@ struct __attribute__((packed)) WsHeader { - payload: 1 byte(`1=done`) - 役割: TTS再生が完了したことを通知する。`Idle` 遷移とは独立に扱える。 +### Downlink: kind = ServoCmd (7) + +- 方向: サーバー -> クライアント +- メッセージ種別: `DATA` のみ使用 +- payload: 1 つの「サーボ動作シーケンス」をまとめて送る + - `` + - 続いて `command_count` 個のコマンド + - `Sleep (op=0)`: `` + - `MoveX (op=1)`: `` + - `MoveY (op=2)`: `` +- ファームウェアは受信後すぐにキューへ積み、`loop()` 内で非同期に順次実行する。 +- 新しい `ServoCmd` を受信した場合、現在のシーケンスは置き換える。 + +### Uplink: kind = ServoDoneEvt (8) + +- 方向: クライアント -> サーバー +- メッセージ種別: `DATA` のみ使用 +- payload: 1 byte(`1=done`) +- 役割: 直前に受信した `ServoCmd` のシーケンス全体が完了したことを通知する。 + ### kind の拡張例 - AudioPcm (1): 現行の PCM16LE アップリンク @@ -100,6 +122,8 @@ struct __attribute__((packed)) WsHeader { - WakeWordEvt (4): wake word 検知通知 - StateEvt (5): 現在状態通知 - SpeakDoneEvt (6): 発話完了通知 +- ServoCmd (7): サーボ動作シーケンス指示 +- ServoDoneEvt (8): サーボ動作シーケンス完了通知 ### 簡易バイト例(AudioPcm / DATA) diff --git a/example_apps/echo.py b/example_apps/echo.py index 64a84f2..7218540 100644 --- a/example_apps/echo.py +++ b/example_apps/echo.py @@ -5,7 +5,9 @@ from logging import getLogger from stackchan_server.app import StackChanApp -from stackchan_server.speech_recognition import WhisperCppSpeechToText, WhisperServerSpeechToText +from stackchan_server.speech_recognition import ( + WhisperCppSpeechToText, +) from stackchan_server.speech_synthesis import VoiceVoxSpeechSynthesizer from stackchan_server.ws_proxy import EmptyTranscriptError, WsProxy @@ -17,14 +19,12 @@ ) def _create_app() -> StackChanApp: - whisper_server_url = os.getenv("STACKCHAN_WHISPER_SERVER_URL") - whisper_server_port = os.getenv("STACKCHAN_WHISPER_SERVER_PORT") whisper_model = os.getenv("STACKCHAN_WHISPER_MODEL") - if whisper_server_url or whisper_server_port: - return StackChanApp( - speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url), - speech_synthesizer=VoiceVoxSpeechSynthesizer(), - ) + # if os.getenv("STACKCHAN_WHISPER_SERVER_URL") or os.getenv("STACKCHAN_WHISPER_SERVER_PORT"): + # return StackChanApp( + # speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url), + # speech_synthesizer=VoiceVoxSpeechSynthesizer(), + # ) if whisper_model: return StackChanApp( speech_recognizer=WhisperCppSpeechToText( diff --git a/example_apps/echo_with_move.py b/example_apps/echo_with_move.py new file mode 100644 index 0000000..104b91e --- /dev/null +++ b/example_apps/echo_with_move.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import logging +import os +from logging import getLogger + +from stackchan_server.app import StackChanApp +from stackchan_server.speech_recognition import ( + WhisperCppSpeechToText, +) +from stackchan_server.speech_synthesis import VoiceVoxSpeechSynthesizer +from stackchan_server.ws_proxy import ( + EmptyTranscriptError, + ServoMoveType, + ServoWaitType, + WsProxy, +) + +logger = getLogger(__name__) +logging.basicConfig( + level=os.getenv("STACKCHAN_LOG_LEVEL", "INFO"), + format="%(asctime)s.%(msecs)03d %(levelname)s:%(name)s:%(message)s", + datefmt="%H:%M:%S", +) + +def _create_app() -> StackChanApp: + whisper_model = os.getenv("STACKCHAN_WHISPER_MODEL") + # if os.getenv("STACKCHAN_WHISPER_SERVER_URL") or os.getenv("STACKCHAN_WHISPER_SERVER_PORT"): + # return StackChanApp( + # speech_recognizer=WhisperServerSpeechToText(server_url=whisper_server_url), + # speech_synthesizer=VoiceVoxSpeechSynthesizer(), + # ) + if whisper_model: + return StackChanApp( + speech_recognizer=WhisperCppSpeechToText( + model_path=whisper_model, + ), + speech_synthesizer=VoiceVoxSpeechSynthesizer(), + ) + return StackChanApp() + + +app = _create_app() + + +@app.setup +async def setup(proxy: WsProxy): + logger.info("WebSocket connected") + await proxy.move_servo([(ServoMoveType.MOVE_Y, 90, 100)]) + + +@app.talk_session +async def talk_session(proxy: WsProxy): + while True: + try: + await proxy.move_servo([(ServoMoveType.MOVE_Y, 80, 100)]) + + text = await proxy.listen() + + await proxy.move_servo([ + (ServoMoveType.MOVE_Y, 100, 100), + (ServoWaitType.SLEEP, 200), + (ServoMoveType.MOVE_Y, 90, 100), + (ServoWaitType.SLEEP, 200), + (ServoMoveType.MOVE_Y, 100, 100), + (ServoWaitType.SLEEP, 200), + (ServoMoveType.MOVE_Y, 90, 100), + ]) + + except EmptyTranscriptError: + await proxy.move_servo([(ServoMoveType.MOVE_Y, 90, 100)]) + return + logger.info("Heard: %s", text) + await proxy.speak(text) + + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run("example_apps.echo:app.fastapi", host="0.0.0.0", port=8000, reload=True) diff --git a/firmware/include/protocols.hpp b/firmware/include/protocols.hpp index df44eca..b39838d 100644 --- a/firmware/include/protocols.hpp +++ b/firmware/include/protocols.hpp @@ -19,6 +19,8 @@ enum class MessageKind : uint8_t WakeWordEvt = 4, // wake word event (client -> server) StateEvt = 5, // current state event (client -> server) SpeakDoneEvt = 6, // speaking completed event (client -> server) + ServoCmd = 7, // servo command sequence (server -> client) + ServoDoneEvt = 8, // servo sequence completed event (client -> server) }; enum class MessageType : uint8_t @@ -46,3 +48,14 @@ enum class RemoteState : uint8_t Thinking = 2, Speaking = 3, }; + +// payload for kind=ServoCmd, messageType=DATA +// +// command op=Sleep: +// command op=MoveX/Y: +enum class ServoCommandOp : uint8_t +{ + Sleep = 0, + MoveX = 1, + MoveY = 2, +}; diff --git a/firmware/include/servo.hpp b/firmware/include/servo.hpp new file mode 100644 index 0000000..eff200f --- /dev/null +++ b/firmware/include/servo.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "protocols.hpp" + +class BodyServo +{ +public: + BodyServo() = default; + + void init(); + void loop(); + void resetSequence(); + + bool enqueueSequence(const uint8_t *payload, size_t payload_len); + bool isBusy() const; + void setCompletionCallback(std::function cb); + +private: + struct AxisMotion + { + Servo servo; + int16_t current_degree = 90; + int16_t start_degree = 90; + int16_t target_degree = 90; + uint32_t move_start_ms = 0; + uint32_t move_duration_ms = 0; + uint32_t last_update_ms = 0; + bool moving = false; + }; + + struct Step + { + ServoCommandOp op; + int8_t angle = 0; + int16_t duration_ms = 0; + }; + + bool ensureAttached(); + void updateAxis(AxisMotion &axis, uint32_t now); + void startMove(AxisMotion &axis, int8_t degree, int16_t duration_ms); + void startCurrentStep(uint32_t now); + void advanceStep(); + void completeSequence(); + + AxisMotion axis_x_{}; + AxisMotion axis_y_{}; + bool attached_ = false; + + std::vector steps_{}; + size_t current_step_index_ = 0; + bool sequence_active_ = false; + bool step_started_ = false; + uint32_t sleep_deadline_ms_ = 0; + std::function on_complete_{}; +}; diff --git a/firmware/src/idf_component.yml b/firmware/src/idf_component.yml new file mode 100644 index 0000000..d752765 --- /dev/null +++ b/firmware/src/idf_component.yml @@ -0,0 +1,2 @@ +dependencies: + idf: '>=5.1' diff --git a/firmware/src/main.cpp b/firmware/src/main.cpp index 44d7741..34fb063 100644 --- a/firmware/src/main.cpp +++ b/firmware/src/main.cpp @@ -14,6 +14,7 @@ #include "../include/listening.hpp" #include "../include/wake_up_word.hpp" #include "../include/display.hpp" +#include "../include/servo.hpp" //////////////////// 設定 //////////////////// const char *WIFI_SSID = WIFI_SSID_H; @@ -31,6 +32,7 @@ static Speaking speaking(stateMachine); static Listening listening(wsClient, stateMachine, SAMPLE_RATE); static WakeUpWord wakeUpWord(stateMachine, SAMPLE_RATE); static Display display(stateMachine); +static BodyServo servo; // Protocol types are defined in include/protocols.hpp namespace @@ -116,6 +118,15 @@ void notifySpeakDone() } } +void notifyServoDone() +{ + const uint8_t payload = 1; // done + if (!sendUplinkPacket(MessageKind::ServoDoneEvt, MessageType::DATA, &payload, sizeof(payload))) + { + log_w("Failed to send ServoDoneEvt"); + } +} + bool applyRemoteStateCommand(const uint8_t *body, size_t bodyLen) { if (body == nullptr || bodyLen < 1) @@ -144,6 +155,16 @@ bool applyRemoteStateCommand(const uint8_t *body, size_t bodyLen) return false; } } + +bool applyServoCommand(const uint8_t *body, size_t bodyLen) +{ + if (!servo.enqueueSequence(body, bodyLen)) + { + log_w("Failed to apply servo command"); + return false; + } + return true; +} } // namespace void connectWiFi() @@ -221,6 +242,16 @@ void handleWsEvent(WStype_t type, uint8_t *payload, size_t length) log_w("StateCmd unsupported msgType=%u", static_cast(rx.messageType)); } break; + case MessageKind::ServoCmd: + if (static_cast(rx.messageType) == MessageType::DATA) + { + applyServoCommand(body, rx_payload_len); + } + else + { + log_w("ServoCmd unsupported msgType=%u", static_cast(rx.messageType)); + } + break; default: // M5.Display.printf("WS bin kind=%u len=%d\n", (unsigned)rx.kind, (int)length); break; @@ -249,6 +280,10 @@ void setup() speaking.setSpeakFinishedCallback([]() { notifySpeakDone(); }); + servo.init(); + servo.setCompletionCallback([]() { + notifyServoDone(); + }); wakeUpWord.init(); wakeUpWord.setWakeWordDetectedCallback([]() { notifyWakeWordDetected(); @@ -304,6 +339,7 @@ void loop() M5.update(); wsClient.loop(); handleCommunicationTimeout(); + servo.loop(); StateMachine::State current = stateMachine.getState(); switch (current) diff --git a/firmware/src/servo.cpp b/firmware/src/servo.cpp new file mode 100644 index 0000000..ffe5782 --- /dev/null +++ b/firmware/src/servo.cpp @@ -0,0 +1,318 @@ +#include "servo.hpp" + +#include + +#include +#include +#include + +namespace +{ +constexpr int kServoXPin = 6; +constexpr int kServoYPin = 7; +constexpr int kServoPulseMinUs = 500; +constexpr int kServoPulseMaxUs = 2400; +constexpr int kServoFrequencyHz = 50; +constexpr uint32_t kEasingDivisionMs = 10; + +int16_t clampDegree(int16_t degree) +{ + return std::clamp(degree, 0, 180); +} + +uint32_t clampDuration(int16_t duration_ms) +{ + return duration_ms <= 0 ? 0U : static_cast(duration_ms); +} + +int16_t readInt16Le(const uint8_t *src) +{ + int16_t value = 0; + memcpy(&value, src, sizeof(value)); + return value; +} +} // namespace + +void BodyServo::init() +{ + if (!ensureAttached()) + { + log_w("Failed to attach servos"); + return; + } + + axis_x_.servo.write(axis_x_.current_degree); + axis_y_.servo.write(axis_y_.current_degree); +} + +void BodyServo::loop() +{ + if (!attached_) + { + return; + } + + uint32_t now = millis(); + updateAxis(axis_x_, now); + updateAxis(axis_y_, now); + + if (!sequence_active_ || current_step_index_ >= steps_.size()) + { + return; + } + + if (!step_started_) + { + startCurrentStep(now); + } + + if (!sequence_active_ || current_step_index_ >= steps_.size()) + { + return; + } + + const Step &step = steps_[current_step_index_]; + bool finished = false; + switch (step.op) + { + case ServoCommandOp::Sleep: + finished = static_cast(now - sleep_deadline_ms_) >= 0; + break; + case ServoCommandOp::MoveX: + finished = !axis_x_.moving; + break; + case ServoCommandOp::MoveY: + finished = !axis_y_.moving; + break; + default: + log_w("Unknown servo step op=%u", static_cast(step.op)); + finished = true; + break; + } + + if (finished) + { + advanceStep(); + } +} + +void BodyServo::resetSequence() +{ + steps_.clear(); + current_step_index_ = 0; + sequence_active_ = false; + step_started_ = false; + sleep_deadline_ms_ = 0; + axis_x_.moving = false; + axis_y_.moving = false; +} + +bool BodyServo::enqueueSequence(const uint8_t *payload, size_t payload_len) +{ + if (!ensureAttached()) + { + return false; + } + if (payload == nullptr || payload_len < 1) + { + log_w("ServoCmd payload too short: %u", static_cast(payload_len)); + return false; + } + + const uint8_t command_count = payload[0]; + size_t offset = 1; + std::vector parsed_steps; + parsed_steps.reserve(command_count); + + for (uint8_t i = 0; i < command_count; ++i) + { + if (offset >= payload_len) + { + log_w("ServoCmd truncated at command=%u", static_cast(i)); + return false; + } + + const ServoCommandOp op = static_cast(payload[offset++]); + Step step{}; + step.op = op; + + switch (op) + { + case ServoCommandOp::Sleep: + if (offset + sizeof(int16_t) > payload_len) + { + log_w("ServoCmd sleep truncated at command=%u", static_cast(i)); + return false; + } + step.duration_ms = readInt16Le(payload + offset); + offset += sizeof(int16_t); + break; + case ServoCommandOp::MoveX: + case ServoCommandOp::MoveY: + if (offset + sizeof(int8_t) + sizeof(int16_t) > payload_len) + { + log_w("ServoCmd move truncated at command=%u", static_cast(i)); + return false; + } + step.angle = static_cast(payload[offset]); + offset += sizeof(int8_t); + step.duration_ms = readInt16Le(payload + offset); + offset += sizeof(int16_t); + break; + default: + log_w("ServoCmd unknown op=%u", static_cast(op)); + return false; + } + + parsed_steps.push_back(step); + } + + if (offset != payload_len) + { + log_w("ServoCmd payload has %u trailing bytes", static_cast(payload_len - offset)); + return false; + } + + resetSequence(); + steps_ = std::move(parsed_steps); + + if (steps_.empty()) + { + completeSequence(); + return true; + } + + current_step_index_ = 0; + sequence_active_ = true; + step_started_ = false; + log_i("Accepted servo sequence commands=%u", static_cast(command_count)); + return true; +} + +bool BodyServo::isBusy() const +{ + return sequence_active_ || axis_x_.moving || axis_y_.moving; +} + +void BodyServo::setCompletionCallback(std::function cb) +{ + on_complete_ = std::move(cb); +} + +bool BodyServo::ensureAttached() +{ + if (attached_) + { + return true; + } + + axis_x_.servo.setPeriodHertz(kServoFrequencyHz); + axis_y_.servo.setPeriodHertz(kServoFrequencyHz); + + const bool x_ok = axis_x_.servo.attach(kServoXPin, kServoPulseMinUs, kServoPulseMaxUs) > 0; + const bool y_ok = axis_y_.servo.attach(kServoYPin, kServoPulseMinUs, kServoPulseMaxUs) > 0; + attached_ = x_ok && y_ok; + return attached_; +} + +void BodyServo::updateAxis(AxisMotion &axis, uint32_t now) +{ + if (!axis.moving) + { + return; + } + + const uint32_t elapsed = now - axis.move_start_ms; + if ((now - axis.last_update_ms) < kEasingDivisionMs && elapsed < axis.move_duration_ms) + { + return; + } + + if (elapsed >= axis.move_duration_ms) + { + axis.current_degree = axis.target_degree; + axis.servo.write(axis.current_degree); + axis.moving = false; + axis.last_update_ms = now; + return; + } + + const float progress = static_cast(elapsed) / static_cast(axis.move_duration_ms); + axis.current_degree = axis.start_degree + static_cast((axis.target_degree - axis.start_degree) * progress); + axis.servo.write(axis.current_degree); + axis.last_update_ms = now; +} + +void BodyServo::startMove(AxisMotion &axis, int8_t degree, int16_t duration_ms) +{ + axis.target_degree = clampDegree(degree); + axis.start_degree = axis.current_degree; + axis.move_start_ms = millis(); + axis.last_update_ms = axis.move_start_ms; + axis.move_duration_ms = clampDuration(duration_ms); + + if (axis.move_duration_ms == 0 || axis.start_degree == axis.target_degree) + { + axis.current_degree = axis.target_degree; + axis.servo.write(axis.current_degree); + axis.moving = false; + return; + } + + axis.moving = true; +} + +void BodyServo::startCurrentStep(uint32_t now) +{ + if (!sequence_active_ || current_step_index_ >= steps_.size()) + { + return; + } + + const Step &step = steps_[current_step_index_]; + step_started_ = true; + switch (step.op) + { + case ServoCommandOp::Sleep: + sleep_deadline_ms_ = now + clampDuration(step.duration_ms); + break; + case ServoCommandOp::MoveX: + startMove(axis_x_, step.angle, step.duration_ms); + break; + case ServoCommandOp::MoveY: + startMove(axis_y_, step.angle, step.duration_ms); + break; + default: + advanceStep(); + break; + } +} + +void BodyServo::advanceStep() +{ + if (!sequence_active_) + { + return; + } + + ++current_step_index_; + step_started_ = false; + if (current_step_index_ >= steps_.size()) + { + completeSequence(); + } +} + +void BodyServo::completeSequence() +{ + steps_.clear(); + current_step_index_ = 0; + sequence_active_ = false; + step_started_ = false; + sleep_deadline_ms_ = 0; + log_i("Servo sequence completed"); + if (on_complete_) + { + on_complete_(); + } +} diff --git a/platformio.ini b/platformio.ini index cf49f67..c97916c 100644 --- a/platformio.ini +++ b/platformio.ini @@ -60,6 +60,7 @@ lib_deps = adafruit/Adafruit NeoPixel@^1.15.2 Links2004/WebSockets@^2.7.2 ESP32Async/AsyncTCP@^3.4.10 + madhephaestus/ESP32Servo@^3.1.3 lib_ldf_mode = deep diff --git a/stackchan_server/ws_proxy.py b/stackchan_server/ws_proxy.py index 6550485..b92ef41 100644 --- a/stackchan_server/ws_proxy.py +++ b/stackchan_server/ws_proxy.py @@ -3,11 +3,12 @@ import asyncio import os import struct +from collections import deque from contextlib import suppress -from enum import IntEnum +from enum import IntEnum, StrEnum from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Literal, Optional, Sequence, TypeAlias, cast from fastapi import WebSocket, WebSocketDisconnect @@ -25,8 +26,12 @@ _WS_HEADER_SIZE = struct.calcsize(_WS_HEADER_FMT) _DOWN_WAV_CHUNK = 4096 # bytes per WebSocket frame for synthesized audio (raw PCM) -_DOWN_SEGMENT_MILLIS = 2000 # duration of a single START-DATA-END segment in milliseconds -_DOWN_SEGMENT_STAGGER_MILLIS = _DOWN_SEGMENT_MILLIS // 2 # half interval for the second segment start +_DOWN_SEGMENT_MILLIS = ( + 2000 # duration of a single START-DATA-END segment in milliseconds +) +_DOWN_SEGMENT_STAGGER_MILLIS = ( + _DOWN_SEGMENT_MILLIS // 2 +) # half interval for the second segment start _LISTEN_AUDIO_TIMEOUT_SECONDS = 10.0 _DEBUG_RECORDING_ENABLED = os.getenv("DEBUG_RECODING") == "1" @@ -45,6 +50,8 @@ class _WsKind(IntEnum): WAKEWORD_EVT = 4 STATE_EVT = 5 SPEAK_DONE_EVT = 6 + SERVO_CMD = 7 + SERVO_DONE_EVT = 8 class _WsMsgType(IntEnum): @@ -52,6 +59,87 @@ class _WsMsgType(IntEnum): DATA = 2 END = 3 + +class _ServoOp(IntEnum): + SLEEP = 0 + MOVE_X = 1 + MOVE_Y = 2 + + +class ServoMoveType(StrEnum): + MOVE_X = "move_x" + MOVE_Y = "move_y" + + +class ServoWaitType(StrEnum): + SLEEP = "sleep" + + +ServoMoveCommand: TypeAlias = tuple[ + Literal["move_x", "move_y"] | ServoMoveType, int, int +] +ServoSleepCommand: TypeAlias = tuple[Literal["sleep"] | ServoWaitType, int] +ServoCommand: TypeAlias = ServoMoveCommand | ServoSleepCommand + + +def _ensure_range(value: int, *, minimum: int, maximum: int, label: str) -> int: + if not minimum <= value <= maximum: + raise ValueError(f"{label} must be between {minimum} and {maximum}: {value}") + return value + + +def _encode_servo_commands(commands: Sequence[ServoCommand]) -> bytes: + normalized = list(commands) + _ensure_range(len(normalized), minimum=0, maximum=255, label="servo command count") + + payload = bytearray() + payload.append(len(normalized)) + + for index, command in enumerate(normalized): + if len(command) == 2: + sleep_command = cast(ServoSleepCommand, command) + name, raw_duration_ms = sleep_command + name = str(name) + if name != "sleep": + raise ValueError( + f"unsupported servo command at index {index}: {name}" + ) + duration_ms = _ensure_range( + int(raw_duration_ms), + minimum=-32768, + maximum=32767, + label="sleep duration", + ) + payload.append(_ServoOp.SLEEP) + payload.extend(struct.pack(" bool: @@ -145,6 +236,37 @@ async def send_state_command(self, state_id: int | FirmwareState) -> None: async def reset_state(self) -> None: await self.send_state_command(FirmwareState.IDLE) + async def move_servo(self, commands: Sequence[ServoCommand]) -> None: + payload = _encode_servo_commands(commands) + previous_counter = self._servo_sent_counter + target_counter = previous_counter + 1 + self._servo_sent_counter = target_counter + self._pending_servo_wait_targets.append(target_counter) + try: + await self._send_packet(_WsKind.SERVO_CMD, _WsMsgType.DATA, payload) + except Exception: + if ( + self._pending_servo_wait_targets + and self._pending_servo_wait_targets[-1] == target_counter + ): + self._pending_servo_wait_targets.pop() + self._servo_sent_counter = previous_counter + raise + + async def wait_servo_complete(self, timeout_seconds: float | None = 120.0) -> None: + target_counter = ( + self._pending_servo_wait_targets.popleft() + if self._pending_servo_wait_targets + else self._servo_done_counter + 1 + ) + await self._wait_for_counter( + current=lambda: self._servo_done_counter, + min_counter=target_counter, + timeout_seconds=timeout_seconds, + is_closed=lambda: self._closed, + label="servo completed event", + ) + async def start(self) -> None: if self._receiving_task is None: self._receiving_task = asyncio.create_task(self._receive_loop()) @@ -184,7 +306,9 @@ async def _receive_loop(self) -> None: continue if msg_type == _WsMsgType.DATA: - if not await self._listener.handle_data(self.ws, payload_bytes, payload): + if not await self._listener.handle_data( + self.ws, payload_bytes, payload + ): break continue @@ -213,6 +337,10 @@ async def _receive_loop(self) -> None: self._handle_speak_done_event(msg_type, payload) continue + if kind == _WsKind.SERVO_DONE_EVT: + self._handle_servo_done_event(msg_type, payload) + continue + await self.ws.close(code=1003, reason="unsupported kind") break except WebSocketDisconnect: @@ -248,12 +376,25 @@ def _handle_speak_done_event(self, msg_type: int, payload: bytes) -> None: return self._speaker.handle_speak_done_event() + def _handle_servo_done_event(self, msg_type: int, payload: bytes) -> None: + if msg_type != _WsMsgType.DATA: + return + if len(payload) < 1: + return + self._servo_done_counter += 1 + logger.info("Received servo done event") + async def _send_state_command(self, state_id: int | FirmwareState) -> None: payload = struct.pack(" None: hdr = struct.pack( _WS_HEADER_FMT, - _WsKind.STATE_CMD.value, - _WsMsgType.DATA.value, + int(kind), + int(msg_type), 0, self._down_seq, len(payload), @@ -261,10 +402,38 @@ async def _send_state_command(self, state_id: int | FirmwareState) -> None: await self.ws.send_bytes(hdr + payload) self._down_seq += 1 + async def _wait_for_counter( + self, + *, + current, + min_counter: int, + timeout_seconds: float | None, + is_closed, + label: str, + ) -> None: + loop = asyncio.get_running_loop() + deadline = (loop.time() + timeout_seconds) if timeout_seconds else None + while True: + if current() >= min_counter: + return + if is_closed(): + raise WebSocketDisconnect() + if deadline and loop.time() >= deadline: + raise TimeoutError(f"Timed out waiting for {label}") + await asyncio.sleep(0.05) + def _next_down_seq(self) -> int: seq = self._down_seq self._down_seq += 1 return seq -__all__ = ["WsProxy", "FirmwareState", "TimeoutError", "EmptyTranscriptError"] +__all__ = [ + "WsProxy", + "FirmwareState", + "TimeoutError", + "EmptyTranscriptError", + "ServoCommand", + "ServoMoveType", + "ServoWaitType", +]