-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
165 lines (133 loc) · 6.16 KB
/
app.py
File metadata and controls
165 lines (133 loc) · 6.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations
import asyncio
from logging import getLogger
from typing import Awaitable, Callable, Optional
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from .speech_recognition import create_speech_recognizer
from .speech_synthesis import create_speech_synthesizer
from .types import SpeechRecognizer, SpeechSynthesizer
from .ws_proxy import WsProxy
# Module-level logger, named after this module.
logger = getLogger(__name__)
class StackChanInfo(BaseModel):
    """Response payload describing one connected StackChan device."""

    # Client IP address; this is the key under which the connection is registered.
    ip: str
    # Lower-cased name of the connection's current state.
    state: str
class SpeakRequest(BaseModel):
    """Request body for the ``/speak`` endpoint."""

    # Text handed to the device connection's ``speak`` call.
    text: str
class StackChanApp:
    """FastAPI application wrapper that tracks StackChan WebSocket connections.

    Each WebSocket connection accepted on ``/ws/stackchan`` is wrapped in a
    ``WsProxy`` and stored in a registry keyed by the client's IP address.
    The HTTP endpoints under ``/v1/stackchan`` look connections up in that
    registry.  User code hooks in through the :meth:`setup` and
    :meth:`talk_session` decorators.
    """

    def __init__(
        self,
        speech_recognizer: SpeechRecognizer | None = None,
        speech_synthesizer: SpeechSynthesizer | None = None,
    ) -> None:
        """Create the FastAPI app and register its routes.

        Args:
            speech_recognizer: Recognizer passed to every ``WsProxy``; when
                ``None`` one is built with ``create_speech_recognizer()``.
            speech_synthesizer: Synthesizer passed to every ``WsProxy``; when
                ``None`` one is built with ``create_speech_synthesizer()``.
        """
        # Explicit ``is None`` checks so that a caller-supplied object which
        # happens to be falsy is not silently replaced by the default.
        if speech_recognizer is None:
            speech_recognizer = create_speech_recognizer()
        if speech_synthesizer is None:
            speech_synthesizer = create_speech_synthesizer()
        self.speech_recognizer = speech_recognizer
        self.speech_synthesizer = speech_synthesizer

        self.fastapi = FastAPI(title="StackChan WebSocket Server")
        self._setup_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None
        self._talk_session_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None
        # Registry of connections keyed by client IP; guarded by the lock below.
        self._proxies: dict[str, WsProxy] = {}
        self._proxies_lock = asyncio.Lock()

        @self.fastapi.get("/health")
        async def _health() -> dict[str, str]:
            return {"status": "ok"}

        @self.fastapi.websocket("/ws/stackchan")
        async def _ws_audio(websocket: WebSocket):
            await self._handle_ws(websocket)

        @self.fastapi.get("/v1/stackchan", response_model=list[StackChanInfo])
        async def _list_stackchans():
            return await self._list_stackchan_infos()

        @self.fastapi.get("/v1/stackchan/{stackchan_ip}", response_model=StackChanInfo)
        async def _get_stackchan(stackchan_ip: str):
            proxy = await self._require_proxy(stackchan_ip)
            return StackChanInfo(ip=stackchan_ip, state=proxy.current_state.name.lower())

        @self.fastapi.post("/v1/stackchan/{stackchan_ip}/wakeword", status_code=204)
        async def _trigger_wakeword(stackchan_ip: str):
            proxy = await self._require_proxy(stackchan_ip)
            proxy.trigger_wakeword()

        @self.fastapi.post("/v1/stackchan/{stackchan_ip}/speak", status_code=204)
        async def _speak(stackchan_ip: str, body: SpeakRequest):
            proxy = await self._require_proxy(stackchan_ip)
            await proxy.speak(body.text)

    def setup(self, fn: Callable[["WsProxy"], Awaitable[None]]):
        """Register ``fn`` as the per-connection setup hook (usable as a decorator).

        The hook is awaited once per connection, before the talk-session loop.

        Returns:
            ``fn`` unchanged.
        """
        self._setup_fn = fn
        return fn

    def talk_session(self, fn: Callable[["WsProxy"], Awaitable[None]]):
        """Register ``fn`` as the talk-session handler (usable as a decorator).

        The handler is awaited each time ``proxy.wait_for_talk_session()``
        returns for a connection.

        Returns:
            ``fn`` unchanged.
        """
        self._talk_session_fn = fn
        return fn

    async def _handle_ws(self, websocket: WebSocket) -> None:
        """Accept a WebSocket, wrap it in a ``WsProxy`` and serve it until it ends."""
        await websocket.accept()
        client_ip = websocket.client.host if websocket.client else "unknown"
        proxy = WsProxy(
            websocket,
            speech_recognizer=self.speech_recognizer,
            speech_synthesizer=self.speech_synthesizer,
        )
        await self._serve_proxy(client_ip, proxy)

    async def _serve_proxy(self, client_ip: str, proxy: WsProxy) -> None:
        """Register ``proxy``, run its lifecycle, and always clean up afterwards.

        A ``WebSocketDisconnect`` raised anywhere in the lifecycle is treated
        as a normal end of the connection; any other exception propagates to
        the caller after cleanup.
        """
        existing = await self._register_proxy(client_ip, proxy)
        try:
            # Start-up runs inside the try so that a failure here still closes
            # and unregisters the proxy that was just put into the registry.
            await proxy.start()
            if existing is not None and existing is not proxy:
                logger.info("Replacing existing connection from %s", client_ip)
                await existing.close()
            if self._setup_fn:
                await self._setup_fn(proxy)
            await self._run_talk_loop(proxy)
        except WebSocketDisconnect:
            pass
        finally:
            try:
                await proxy.close()
            finally:
                # Unregister even if close() raised, so no stale entry remains.
                await self._unregister_proxy(client_ip, proxy)

    async def _run_talk_loop(self, proxy: WsProxy) -> None:
        """Run talk sessions on ``proxy`` until it closes or its receive task ends."""
        while not proxy.closed:
            if not self._talk_session_fn:
                # No handler registered: idle briefly and poll again.
                await asyncio.sleep(0.05)
            else:
                await proxy.wait_for_talk_session()
                disconnected = False
                try:
                    await self._talk_session_fn(proxy)
                except WebSocketDisconnect:
                    # Flag it so the ``finally`` below skips reset_state.
                    disconnected = True
                    raise
                except Exception:
                    # A failing handler must not take the connection down.
                    logger.exception("talk_session failed")
                finally:
                    if not disconnected and not proxy.closed:
                        try:
                            await proxy.reset_state()
                        except WebSocketDisconnect:
                            disconnected = True
                        except Exception:
                            logger.exception("reset_state failed")
                if disconnected:
                    # The peer went away while resetting state; stop looping
                    # instead of waiting for another session on a dead socket.
                    break
            if proxy.receive_task and proxy.receive_task.done():
                break

    async def _list_stackchan_infos(self) -> list[StackChanInfo]:
        """Return a ``StackChanInfo`` for every registered proxy that is not closed."""
        async with self._proxies_lock:
            return [
                StackChanInfo(ip=ip, state=proxy.current_state.name.lower())
                for ip, proxy in self._proxies.items()
                if not proxy.closed
            ]

    async def _get_proxy(self, client_ip: str) -> WsProxy | None:
        """Return the open proxy registered for ``client_ip``, or ``None``.

        ``None`` is returned both when nothing is registered and when the
        registered proxy reports itself closed.
        """
        async with self._proxies_lock:
            proxy = self._proxies.get(client_ip)
        if proxy is None or proxy.closed:
            return None
        return proxy

    async def _require_proxy(self, client_ip: str) -> WsProxy:
        """Return the open proxy for ``client_ip``.

        Raises:
            HTTPException: 404 when no open connection exists for that IP.
        """
        proxy = await self._get_proxy(client_ip)
        if proxy is None:
            raise HTTPException(status_code=404, detail="stackchan not connected")
        return proxy

    async def _register_proxy(self, client_ip: str, proxy: WsProxy) -> WsProxy | None:
        """Store ``proxy`` under ``client_ip`` and return the entry it replaced, if any."""
        async with self._proxies_lock:
            existing = self._proxies.get(client_ip)
            self._proxies[client_ip] = proxy
        return existing

    async def _unregister_proxy(self, client_ip: str, proxy: WsProxy) -> None:
        """Remove ``proxy`` from the registry if it is still the entry for ``client_ip``.

        The identity check keeps a newer connection from the same IP registered
        when an older connection's handler finishes.
        """
        async with self._proxies_lock:
            if self._proxies.get(client_ip) is proxy:
                self._proxies.pop(client_ip, None)

    def run(self, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
        """Serve the FastAPI app with uvicorn (blocking).

        Args:
            host: Interface to bind.
            port: TCP port to bind.
            reload: Accepted for API compatibility only and never forwarded.
                uvicorn supports ``reload`` only when the app is given as an
                import string, and refuses to start when it is requested
                together with an app instance.
        """
        import uvicorn

        if reload:
            logger.debug("reload is ignored: the app is passed to uvicorn as an instance")
        uvicorn.run(self.fastapi, host=host, port=port)
# Public API of this module: only the application wrapper is exported.
__all__ = ["StackChanApp"]