diff --git a/packages/loro-websocket/src/client/index.ts b/packages/loro-websocket/src/client/index.ts index de9f11e..a60b9bc 100644 --- a/packages/loro-websocket/src/client/index.ts +++ b/packages/loro-websocket/src/client/index.ts @@ -717,8 +717,8 @@ export class LoroWebsocketClient { void this.sendRejoinRequest(roomId, msg.roomId, adaptor, active.room, auth); } else { // Remove local room state so client does not auto-retry unless requested - this.cleanupRoom(msg.roomId, msg.crdt); this.emitRoomStatus(roomId, RoomJoinStatus.Error); + this.cleanupRoom(msg.roomId, msg.crdt); } break; } diff --git a/packages/loro-websocket/src/server/simple-server.ts b/packages/loro-websocket/src/server/simple-server.ts index 4a7c619..f4a6827 100644 --- a/packages/loro-websocket/src/server/simple-server.ts +++ b/packages/loro-websocket/src/server/simple-server.ts @@ -52,9 +52,7 @@ export interface SimpleServerConfig { * Optional handshake auth: called during WS HTTP upgrade. * Return true to accept, false to reject. */ - handshakeAuth?: ( - req: IncomingMessage - ) => boolean | Promise; + handshakeAuth?: (req: IncomingMessage) => boolean | Promise; } interface RoomDocument { @@ -191,7 +189,7 @@ export class SimpleServer { ); void closers - .catch(() => { }) + .catch(() => {}) .finally(() => { try { wss.close(() => { @@ -212,16 +210,16 @@ export class SimpleServer { private async gracefulCloseWebSocket(ws: WebSocket): Promise { try { await this.waitForSocketDrain(ws); - } catch { } + } catch {} try { ws.close(1001, "Server stopping"); - } catch { } + } catch {} setTimeout(() => { try { if (ws.readyState !== WebSocket.CLOSED) ws.terminate(); - } catch { } + } catch {} }, 50); } @@ -243,7 +241,11 @@ export class SimpleServer { } const buffered = readBufferedAmount(); - if (buffered == null || buffered <= 0 || Date.now() - start >= timeoutMs) { + if ( + buffered == null || + buffered <= 0 || + Date.now() - start >= timeoutMs + ) { resolve(); return; } @@ -346,7 +348,7 @@ export class SimpleServer { const joinResult = roomDoc.descriptor.adaptor.handleJoinRequest( roomDoc.data, - message.version, + message.version ); // Send join response with current document version @@ -368,8 +370,7 @@ export class SimpleServer { client ); const shouldBackfill = - (hasOthers || - roomDoc.descriptor.allowBackfillWhenNoOtherClients) && + (hasOthers || roomDoc.descriptor.allowBackfillWhenNoOtherClients) && joinResult.updates && joinResult.updates.length; @@ -406,7 +407,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PayloadTooLarge, message.crdt, - message.roomId, + message.roomId ); return; } @@ -420,7 +421,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PermissionDenied, message.crdt, - message.roomId, + message.roomId ); client.fragments.delete(message.batchId); return; @@ -435,7 +436,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PermissionDenied, message.crdt, - message.roomId, + message.roomId ); return; } @@ -448,7 +449,7 @@ export class SimpleServer { try { const newDocumentData = roomDoc.descriptor.adaptor.applyUpdates( roomDoc.data, - message.updates, + message.updates ); roomDoc.data = newDocumentData; } catch (error) { @@ -458,12 +459,11 @@ export class SimpleServer { message.batchId, UpdateStatusCode.InvalidUpdate, message.crdt, - message.roomId, + message.roomId ); return; } - if (roomDoc.descriptor.shouldPersist) { roomDoc.dirty = true; } @@ -475,7 +475,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.Ok, message.crdt, - message.roomId, + 
message.roomId ); if (updatesForBroadcast.length > 0) { @@ -486,12 +486,7 @@ export class SimpleServer { updates: updatesForBroadcast, batchId: message.batchId, }; - this.broadcastToRoom( - message.roomId, - message.crdt, - outgoing, - client - ); + this.broadcastToRoom(message.roomId, message.crdt, outgoing, client); } } catch (error) { console.error(error); @@ -500,7 +495,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.Unknown, message.crdt, - message.roomId, + message.roomId ); } } @@ -518,13 +513,16 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PermissionDenied, message.crdt, - message.roomId, + message.roomId ); return; } const batch = { - data: Array.from({ length: message.fragmentCount }, () => new Uint8Array()), + data: Array.from( + { length: message.fragmentCount }, + () => new Uint8Array() + ), totalSize: message.totalSizeBytes, received: 0, header: message, @@ -538,7 +536,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.FragmentTimeout, message.crdt, - message.roomId, + message.roomId ); }, 10000); @@ -556,7 +554,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.FragmentTimeout, message.crdt, - message.roomId, + message.roomId ); return; } @@ -586,7 +584,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PermissionDenied, message.crdt, - message.roomId, + message.roomId ); return; } @@ -598,7 +596,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.PermissionDenied, message.crdt, - message.roomId, + message.roomId ); client.fragments.delete(message.batchId); return; @@ -612,7 +610,7 @@ export class SimpleServer { try { const newDocumentData = roomDoc.descriptor.adaptor.applyUpdates( roomDoc.data, - [totalData], + [totalData] ); roomDoc.data = newDocumentData; } catch (error) { @@ -622,7 +620,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.InvalidUpdate, message.crdt, - message.roomId, + message.roomId ); client.fragments.delete(message.batchId); return; @@ -637,7 +635,7 @@ export class SimpleServer { message.batchId, UpdateStatusCode.Ok, message.crdt, - message.roomId, + message.roomId ); // Broadcast original fragments to other clients in the room @@ -695,10 +693,7 @@ export class SimpleServer { if (descriptor.shouldPersist && this.config.onLoadDocument) { try { - const loaded = await this.config.onLoadDocument( - roomId, - crdtType - ); + const loaded = await this.config.onLoadDocument(roomId, crdtType); if (loaded) { data = loaded; } @@ -799,9 +794,10 @@ export class SimpleServer { return `${roomId}:${crdtType}`; } - private parseRoomKey( - roomKey: string - ): { roomId: string; crdtType: CrdtType } { + private parseRoomKey(roomKey: string): { + roomId: string; + crdtType: CrdtType; + } { const sep = roomKey.lastIndexOf(":"); if (sep === -1) { return { roomId: roomKey, crdtType: CrdtType.Loro }; diff --git a/rust/loro-websocket-server/src/lib.rs b/rust/loro-websocket-server/src/lib.rs index 02fc505..1afac7e 100644 --- a/rust/loro-websocket-server/src/lib.rs +++ b/rust/loro-websocket-server/src/lib.rs @@ -44,7 +44,7 @@ use loro::awareness::EphemeralStore; use loro::{ExportMode, LoroDoc}; pub use loro_protocol as protocol; use protocol::{ - try_decode, CrdtType, JoinErrorCode, Permission, ProtocolMessage, UpdateStatusCode, + try_decode, CrdtType, JoinErrorCode, Permission, ProtocolMessage, RoomErrorCode, UpdateStatusCode, }; use tracing::{debug, error, info, warn}; @@ -52,10 +52,13 @@ use tracing::{debug, error, info, warn}; const MAX_FRAGMENTS: u64 = 
4096; // hard cap on number of fragments per batch const MAX_BATCH_BYTES: u64 = 64 * 1024 * 1024; // 64 MiB per batch +/// Key identifying a room by its CRDT type and room ID. #[derive(Clone, Debug, PartialEq, Eq)] -struct RoomKey { - crdt: CrdtType, - room: String, +pub struct RoomKey { + /// The CRDT type of the room. + pub crdt: CrdtType, + /// The room identifier. + pub room: String, } impl Hash for RoomKey { fn hash(&self, state: &mut H) { @@ -105,6 +108,27 @@ type SaveFuture = Pin> + Send + 'stat type LoadFn = Arc LoadFuture + Send + Sync>; type SaveFn = Arc) -> SaveFuture + Send + Sync>; +/// Arguments provided to `on_update`. +pub struct UpdateArgs { + pub workspace: String, + pub room: String, + pub crdt: CrdtType, + pub conn_id: u64, + pub updates: Vec>, + pub doc: Option, + pub ctx: Option, +} + +pub struct UpdatedDoc { + pub status: UpdateStatusCode, + pub ctx: Option, + pub doc: Option, +} + +type UpdateFuture = + Pin> + Send + 'static>>; +type UpdateFn = Arc) -> UpdateFuture + Send + Sync>; + /// Arguments provided to `authenticate`. pub struct AuthArgs { pub room: String, @@ -143,6 +167,7 @@ type CloseConnectionFn = pub struct ServerConfig { pub on_load_document: Option>, pub on_save_document: Option>, + pub on_update: Option>, pub save_interval_ms: Option, pub default_permission: Permission, pub authenticate: Option, @@ -161,8 +186,12 @@ pub struct ServerConfig { pub on_close_connection: Option, } -// CRDT document abstraction to reduce match-based branching -trait CrdtDoc: Send { +/// CRDT document abstraction to reduce match-based branching. +/// +/// This trait is implemented by different document types (Loro, Ephemeral, Elo). +/// You can use it to query document state through the `RoomDocState.doc` field. +pub trait CrdtDoc: Send { + /// Get the current version vector as bytes. fn get_version(&self) -> Vec { Vec::new() } @@ -185,6 +214,15 @@ trait CrdtDoc: Send { fn remove_when_last_subscriber_leaves(&self) -> bool { false } + fn get_loro_doc(&self) -> Option { + None + } + fn set_loro_doc(&mut self, _doc: LoroDoc) -> bool { + false + } + fn as_loro_doc_mut(&mut self) -> Option<&mut LoroDoc> { + None + } } struct LoroRoomDoc { @@ -213,6 +251,16 @@ impl CrdtDoc for LoroRoomDoc { fn import_snapshot(&mut self, data: &[u8]) { let _ = self.doc.import(data); } + fn get_loro_doc(&self) -> Option { + Some(self.doc.clone()) + } + fn set_loro_doc(&mut self, doc: LoroDoc) -> bool { + self.doc = doc; + true + } + fn as_loro_doc_mut(&mut self) -> Option<&mut LoroDoc> { + Some(&mut self.doc) + } } struct EphemeralRoomDoc { @@ -470,6 +518,7 @@ impl Default for ServerConfig { Self { on_load_document: None, on_save_document: None, + on_update: None, save_interval_ms: None, default_permission: Permission::Write, authenticate: None, @@ -479,21 +528,27 @@ impl Default for ServerConfig { } } -struct RoomDocState { - doc: Box, - dirty: bool, - ctx: Option, +/// State of a document in a room. +pub struct RoomDocState { + /// The underlying CRDT document (trait object). + pub doc: Box, + /// Whether the document has unsaved changes. + pub dirty: bool, + /// Optional application-specific context. + pub ctx: Option, } -struct Hub { - // room -> vec of (conn_id, sender) - subs: HashMap>, - // room -> document state (Loro persistent, Ephemeral in-memory, Elo index) - docs: HashMap>, +/// A hub managing subscriptions and documents for a single workspace. +pub struct Hub { + /// Room -> vec of (conn_id, sender). Use this to inspect connected clients per room. 
+ pub subs: HashMap>, + /// Room -> document state. Use this to inspect loaded documents. + pub docs: HashMap>, config: ServerConfig, - // (conn_id, room) -> permission - perms: HashMap<(u64, RoomKey), Permission>, - workspace: String, + /// (conn_id, room) -> permission. Use this to inspect client permissions. + pub perms: HashMap<(u64, RoomKey), Permission>, + /// The workspace identifier. + pub workspace: String, // Fragment reassembly state: per room + batch id fragments: HashMap<(RoomKey, protocol::BatchId), FragmentBatch>, } @@ -701,6 +756,120 @@ where Some(data) } } + + /// Apply updates to a room's document and broadcast to all subscribers. + /// + /// This is useful for server-initiated updates (e.g., from HTTP endpoints). + /// Calls the `on_update` hook if configured, and respects its result. + /// Returns the number of subscribers the update was sent to, or an error. + /// + /// # Example + /// ```ignore + /// let hubs = registry.hubs().lock().await; + /// if let Some(hub) = hubs.get("my-workspace") { + /// let mut h = hub.lock().await; + /// let room = RoomKey { crdt: CrdtType::Loro, room: "my-room".into() }; + /// match h.push_update(&room, vec![update_bytes]).await { + /// Ok(n) => println!("Broadcasted to {} subscribers", n), + /// Err(e) => eprintln!("Failed: {}", e), + /// } + /// } + /// ``` + pub async fn push_update(&mut self, room: &RoomKey, updates: Vec>) -> Result { + // Check room exists + if !self.docs.contains_key(room) { + return Err("room not found".into()); + } + + // Call on_update hook if configured + let mut skip_apply = false; + if let Some(update_hook) = self.config.on_update.clone() { + let ctx = self.docs.get(room).and_then(|s| s.ctx.clone()); + let doc = self.docs.get(room).and_then(|s| s.doc.get_loro_doc()); + let args = UpdateArgs { + workspace: self.workspace.clone(), + room: room.room.clone(), + crdt: room.crdt, + conn_id: 0, // Server-initiated, no connection + updates: updates.clone(), + doc, + ctx, + }; + let mut result = (update_hook)(args).await; + + if result.status != UpdateStatusCode::Ok { + return Err(format!("on_update hook rejected: {:?}", result.status)); + } + + skip_apply = self.process_update_hook_result(room, &mut result); + } + + // Apply to doc (unless hook already did) + if !skip_apply { + if let Some(state) = self.docs.get_mut(room) { + state.doc.apply_updates(&updates)?; + if state.doc.should_persist() { + state.dirty = true; + } + } + } + + // Broadcast to all subscribers + let batch_id = next_batch_id(); + let msg = ProtocolMessage::DocUpdate { + crdt: room.crdt, + room_id: room.room.clone(), + updates, + batch_id, + }; + let encoded = match loro_protocol::encode(&msg) { + Ok(b) => b, + Err(e) => return Err(format!("encode failed: {:?}", e)), + }; + + let mut sent = 0usize; + if let Some(list) = self.subs.get_mut(room) { + let mut dead: HashSet = HashSet::new(); + for (id, tx) in list.iter() { + if tx.send(Message::Binary(encoded.clone().into())).is_err() { + dead.insert(*id); + } else { + sent += 1; + } + } + if !dead.is_empty() { + list.retain(|(id, _)| !dead.contains(id)); + } + } + + Ok(sent) + } + + fn process_update_hook_result( + &mut self, + room: &RoomKey, + result: &mut UpdatedDoc, + ) -> bool { + let mut replaced_doc = false; + if result.ctx.is_some() || result.doc.is_some() { + if let Some(state) = self.docs.get_mut(room) { + if let Some(new_ctx) = result.ctx.take() { + state.ctx = Some(new_ctx); + } + if result.status == UpdateStatusCode::Ok { + if let Some(new_doc) = result.doc.take() { + if 
state.doc.set_loro_doc(new_doc) { + state.dirty = true; + replaced_doc = true; + } + } + } else { + result.doc = None; + } + } + } + replaced_doc + } } struct FragmentBatch { @@ -796,7 +965,11 @@ fn send_ack( } } -struct HubRegistry { +/// Registry that manages all workspace hubs. +/// +/// This can be shared between your WebSocket server and HTTP endpoints +/// to expose information about connected clients and rooms. +pub struct HubRegistry { config: ServerConfig, hubs: tokio::sync::Mutex>>>>, } @@ -805,13 +978,366 @@ impl HubRegistry where DocCtx: Clone + Send + Sync + 'static, { - fn new(config: ServerConfig) -> Self { + /// Create a new hub registry with the given configuration. + pub fn new(config: ServerConfig) -> Self { Self { config, hubs: tokio::sync::Mutex::new(HashMap::new()), } } + /// Access the underlying hubs map. + /// + /// Returns a reference to the mutex-protected map of workspace ID -> Hub. + /// Use this to implement your own inspection logic in HTTP endpoints. + /// + /// # Example + /// ```ignore + /// let hubs = registry.hubs().lock().await; + /// for (workspace_id, hub) in hubs.iter() { + /// let h = hub.lock().await; + /// for (room_key, subscribers) in h.subs.iter() { + /// println!("Room {} has {} subscribers", room_key.room, subscribers.len()); + /// } + /// } + /// ``` + pub fn hubs(&self) -> &tokio::sync::Mutex>>>> { + &self.hubs + } + + /// Open a room, creating the hub and loading the document if needed. + /// + /// This is idempotent - if the room already exists, nothing happens. + /// Useful for pre-creating rooms before any WebSocket clients connect. + /// + /// # Example + /// ```ignore + /// // Pre-create a room so it's ready when clients connect + /// registry.open_room("my-workspace", CrdtType::Loro, "my-room").await; + /// ``` + pub async fn open_room(&self, workspace: &str, crdt: CrdtType, room_id: &str) { + let hub = self.get_or_create(workspace).await; + let mut h = hub.lock().await; + let room = RoomKey { + crdt, + room: room_id.to_string(), + }; + h.ensure_room_loaded(&room).await; + } + + /// Edit a Loro document directly on the server and notify subscribers. + /// + /// This method loads the room (if needed), runs the provided callback with a + /// reference to the underlying `LoroDoc`, exports a snapshot, and broadcasts + /// it to all subscribers. The callback should perform mutations and call + /// `commit()` when done so the snapshot captures the new state. + /// + /// After the edit, if the room has no subscribers, it will be saved (if dirty) + /// and closed to avoid leaving orphan rooms. If `force_close` is true, the room + /// will be closed even if it has subscribers. 
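+    ///
+    /// # Example
+    /// A minimal sketch, assuming `registry` is a running `Arc<HubRegistry<()>>`
+    /// and the workspace/room names are illustrative:
+    /// ```ignore
+    /// registry
+    ///     .edit_loro_doc("my-workspace", "my-room", |doc| {
+    ///         let text = doc.get_text("text");
+    ///         text.insert(0, "from-server").unwrap();
+    ///         doc.commit();
+    ///         Ok(())
+    ///     }, false) // force_close = false: keep the room open while clients are subscribed
+    ///     .await
+    ///     .expect("server-side edit failed");
+    /// ```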
+ pub async fn edit_loro_doc( + &self, + workspace: &str, + room_id: &str, + edit: F, + force_close: bool, + ) -> Result<(), String> + where + F: FnOnce(&LoroDoc) -> Result<(), String> + Send, + { + let hub = self.get_or_create(workspace).await; + let room = RoomKey { + crdt: CrdtType::Loro, + room: room_id.to_string(), + }; + + // Do the work, capturing whether we should close and the result + let (result, should_close) = { + let mut h = hub.lock().await; + h.ensure_room_loaded(&room).await; + + let Some(state) = h.docs.get_mut(&room) else { + // Room not found after ensure_room_loaded - shouldn't happen + return Err("room not found".into()); + }; + + let edit_result = { + let Some(doc) = state.doc.as_loro_doc_mut() else { + return Err("room is not a Loro document".into()); + }; + edit(doc) + }; + + if let Err(e) = edit_result { + let has_subs = h.subs.get(&room).map(|v| !v.is_empty()).unwrap_or(false); + (Err(e), force_close || !has_subs) + } else { + let state = h.docs.get_mut(&room).unwrap(); // safe: we just checked above + if state.doc.should_persist() { + state.dirty = true; + } + + let snapshot = state.doc.export_snapshot(); + let has_subs = h.subs.get(&room).map(|v| !v.is_empty()).unwrap_or(false); + + if let Some(snap) = snapshot { + if !snap.is_empty() { + let batch_id = next_batch_id(); + let msg = ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: room.room.clone(), + updates: vec![snap], + batch_id, + }; + match loro_protocol::encode(&msg) { + Ok(encoded) => { + h.broadcast(&room, 0, Message::Binary(encoded.into())); + (Ok(()), force_close || !has_subs) + } + Err(e) => { + (Err(format!("encode failed: {:?}", e)), force_close || !has_subs) + } + } + } else { + (Ok(()), force_close || !has_subs) + } + } else { + (Ok(()), force_close || !has_subs) + } + } + }; + + // Close room if no subscribers or force_close requested (deferred until lock released) + if should_close { + self.close_room(workspace, CrdtType::Loro, room_id, force_close).await; + } + + result + } + + /// Close a room if it has no subscribers (or forcefully). + /// + /// Returns `true` if the room was closed, `false` if it has active subscribers + /// (when `force` is false) or didn't exist. Saves dirty documents before closing + /// if `on_save_document` is configured. + /// + /// When `force` is true, the room is closed even if there are active subscribers. + /// Their sender channels will be dropped (they won't receive further updates). 
+ /// + /// # Example + /// ```ignore + /// // Close only if no subscribers + /// if registry.close_room("my-workspace", CrdtType::Loro, "my-room", false).await { + /// println!("Room closed"); + /// } + /// + /// // Force close regardless of subscribers + /// registry.close_room("my-workspace", CrdtType::Loro, "my-room", true).await; + /// ``` + pub async fn close_room(&self, workspace: &str, crdt: CrdtType, room_id: &str, force: bool) -> bool { + let hubs = self.hubs.lock().await; + let Some(hub) = hubs.get(workspace) else { + return false; + }; + let mut h = hub.lock().await; + let room = RoomKey { + crdt, + room: room_id.to_string(), + }; + + // Check if room has subscribers (unless forcing) + if !force { + if let Some(subs) = h.subs.get(&room) { + if !subs.is_empty() { + return false; + } + } + } + + // Notify subscribers before closing + if let Some(subs) = h.subs.get(&room) { + if !subs.is_empty() { + let err_msg = ProtocolMessage::RoomError { + crdt, + room_id: room_id.to_string(), + code: RoomErrorCode::Evicted, + message: "Room closed by server".to_string(), + }; + if let Ok(bytes) = loro_protocol::encode(&err_msg) { + for (_, tx) in subs.iter() { + let _ = tx.send(Message::Binary(bytes.clone().into())); + } + } + } + } + + // Save dirty document before closing if on_save_document is configured + if let Some(saver) = &self.config.on_save_document { + if let Some(state) = h.docs.get_mut(&room) { + if state.dirty && state.doc.should_persist() { + if let Some(snapshot) = state.doc.export_snapshot() { + let args = SaveDocArgs { + workspace: workspace.to_string(), + room: room_id.to_string(), + crdt, + data: snapshot, + ctx: state.ctx.clone(), + }; + match (saver)(args).await { + Ok(()) => { + state.dirty = false; + debug!(workspace=%workspace, room=%room_id, "saved room before closing"); + } + Err(e) => { + warn!(workspace=%workspace, room=%room_id, %e, "failed to save room before closing"); + } + } + } + } + } + } + + // Remove room state + h.docs.remove(&room); + h.subs.remove(&room); + h.perms.retain(|(_, k), _| k != &room); + h.fragments.retain(|(k, _), _| k != &room); + true + } + + /// Save a room's document if it has unsaved changes. + /// + /// Returns `Ok(true)` if the document was saved, `Ok(false)` if there was + /// nothing to save (not dirty or doesn't support persistence), or an error + /// if saving failed or `on_save_document` is not configured. 
+ /// + /// # Example + /// ```ignore + /// match registry.save_room("my-workspace", CrdtType::Loro, "my-room").await { + /// Ok(true) => println!("Saved"), + /// Ok(false) => println!("Nothing to save"), + /// Err(e) => eprintln!("Save failed: {}", e), + /// } + /// ``` + pub async fn save_room(&self, workspace: &str, crdt: CrdtType, room_id: &str) -> Result { + let Some(saver) = &self.config.on_save_document else { + return Err("on_save_document not configured".into()); + }; + + let hubs = self.hubs.lock().await; + let Some(hub) = hubs.get(workspace) else { + return Err("workspace not found".into()); + }; + let mut h = hub.lock().await; + let room = RoomKey { + crdt, + room: room_id.to_string(), + }; + + let Some(state) = h.docs.get_mut(&room) else { + return Err("room not found".into()); + }; + + if !state.dirty || !state.doc.should_persist() { + return Ok(false); + } + + let Some(snapshot) = state.doc.export_snapshot() else { + return Ok(false); + }; + + let args = SaveDocArgs { + workspace: workspace.to_string(), + room: room_id.to_string(), + crdt, + data: snapshot, + ctx: state.ctx.clone(), + }; + + (saver)(args).await.map_err(|e| e)?; + state.dirty = false; + debug!(workspace=%workspace, room=%room_id, "room saved"); + Ok(true) + } + + /// Close a hub (workspace) if all its rooms have no subscribers (or forcefully). + /// + /// Returns `true` if the hub was closed, `false` if any room has active + /// subscribers (when `force` is false). Saves all dirty rooms before closing + /// if `on_save_document` is configured. + /// + /// When `force` is true, the hub is closed even if rooms have active subscribers. + /// Their sender channels will be dropped (they won't receive further updates). + /// + /// Note: The hub's saver task will stop when the hub is dropped (after + /// all Arc references are released). 
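+    ///
+    /// # Example
+    /// A sketch mirroring `close_room` above; the workspace name is illustrative:
+    /// ```ignore
+    /// // Close the workspace only if none of its rooms have subscribers
+    /// if registry.close_hub("my-workspace", false).await {
+    ///     println!("Hub closed");
+    /// }
+    ///
+    /// // Force close regardless of subscribers
+    /// registry.close_hub("my-workspace", true).await;
+    /// ```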
+ pub async fn close_hub(&self, workspace: &str, force: bool) -> bool { + let mut hubs = self.hubs.lock().await; + let Some(hub) = hubs.get(workspace) else { + return false; + }; + + let mut h = hub.lock().await; + // Check if any room has subscribers (unless forcing) + if !force { + for subs in h.subs.values() { + if !subs.is_empty() { + return false; + } + } + } + + // Notify all subscribers in all rooms before closing + for (room, subs) in h.subs.iter() { + if !subs.is_empty() { + let err_msg = ProtocolMessage::RoomError { + crdt: room.crdt, + room_id: room.room.clone(), + code: RoomErrorCode::Evicted, + message: "Hub closed by server".to_string(), + }; + if let Ok(bytes) = loro_protocol::encode(&err_msg) { + for (_, tx) in subs.iter() { + let _ = tx.send(Message::Binary(bytes.clone().into())); + } + } + } + } + + // Save all dirty rooms before closing if on_save_document is configured + if let Some(saver) = &self.config.on_save_document { + let rooms: Vec = h.docs.keys().cloned().collect(); + for room in rooms { + if let Some(state) = h.docs.get_mut(&room) { + if state.dirty && state.doc.should_persist() { + if let Some(snapshot) = state.doc.export_snapshot() { + let args = SaveDocArgs { + workspace: workspace.to_string(), + room: room.room.clone(), + crdt: room.crdt, + data: snapshot, + ctx: state.ctx.clone(), + }; + match (saver)(args).await { + Ok(()) => { + state.dirty = false; + debug!(workspace=%workspace, room=%room.room, "saved room before closing hub"); + } + Err(e) => { + warn!(workspace=%workspace, room=%room.room, %e, "failed to save room before closing hub"); + } + } + } + } + } + } + } + drop(h); + + hubs.remove(workspace); + true + } + async fn get_or_create(&self, workspace: &str) -> Arc>> { let mut map = self.hubs.lock().await; if let Some(h) = map.get(workspace) { @@ -891,8 +1417,22 @@ pub async fn serve_incoming_with_config( where DocCtx: Clone + Send + Sync + 'static, { - let registry = Arc::new(HubRegistry::new(config.clone())); + let registry = Arc::new(HubRegistry::new(config)); + serve_incoming_with_registry(listener, registry).await +} +/// Serve a pre-bound listener using an existing registry. +/// +/// This allows you to share the registry with other parts of your application, +/// for example to expose HTTP endpoints that query the registry state. 
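+///
+/// # Example
+/// A minimal sketch; the bind address and the `()` context type are placeholders:
+/// ```ignore
+/// let cfg: ServerConfig<()> = ServerConfig::default();
+/// let registry = Arc::new(HubRegistry::new(cfg));
+/// let listener = TcpListener::bind("127.0.0.1:9000").await?;
+/// let ws_registry = registry.clone();
+/// tokio::spawn(async move {
+///     serve_incoming_with_registry(listener, ws_registry)
+///         .await
+///         .expect("websocket server exited");
+/// });
+/// // `registry` remains available here, e.g. for HTTP handlers that call
+/// // `registry.hubs()`, `open_room`, or `edit_loro_doc`.
+/// ```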
+/// +pub async fn serve_incoming_with_registry( + listener: TcpListener, + registry: Arc>, +) -> Result<(), Box> +where + DocCtx: Clone + Send + Sync + 'static, +{ loop { match listener.accept().await { Ok((stream, peer)) => { @@ -1297,24 +1837,52 @@ where if let Some(buf) = h.add_fragment_and_maybe_finish(&room, batch_id, index, fragment) { + + let mut skip_apply = false; + if let Some(update_hook) = h.config.on_update.clone() { + let ctx = h.docs.get(&room).and_then(|s| s.ctx.clone()); + let doc = h.docs.get(&room).and_then(|s| s.doc.get_loro_doc()); + drop(h); + let args = UpdateArgs { + workspace: workspace_id.clone(), + room: room.room.clone(), + crdt, + conn_id, + updates: vec![buf.clone()], + doc, + ctx, + }; + let mut update_hook_result = (update_hook)(args).await; + h = hub.lock().await; + skip_apply = h.process_update_hook_result(&room, &mut update_hook_result); + if update_hook_result.status != UpdateStatusCode::Ok { + send_ack(&tx, crdt, &room.room, batch_id, update_hook_result.status); + continue; + } + } + // On completion: parse and apply to stored doc state if applicable - let apply_result = match crdt { - CrdtType::Loro - | CrdtType::LoroEphemeralStore - | CrdtType::LoroEphemeralStorePersisted => { - let start = std::time::Instant::now(); - let res = h.apply_updates(&room, &[buf.clone()]); - let elapsed_ms = start.elapsed().as_millis(); - if res.is_ok() { - debug!(room=?room.room, updates=1, ms=%elapsed_ms, "applied reassembled updates"); + let apply_result = if skip_apply { + Ok(()) + } else { + match crdt { + CrdtType::Loro + | CrdtType::LoroEphemeralStore + | CrdtType::LoroEphemeralStorePersisted => { + let start = std::time::Instant::now(); + let res = h.apply_updates(&room, &[buf.clone()]); + let elapsed_ms = start.elapsed().as_millis(); + if res.is_ok() { + debug!(room=?room.room, updates=1, ms=%elapsed_ms, "applied reassembled updates"); + } + res } - res - } - CrdtType::Elo => { - // Apply as indexing-only - h.apply_updates(&room, &[buf.clone()]) + CrdtType::Elo => { + // Apply as indexing-only + h.apply_updates(&room, &[buf.clone()]) + } + _ => Ok(()), } - _ => Ok(()), }; if apply_result.is_ok() { @@ -1374,23 +1942,51 @@ where continue; } let mut h = hub.lock().await; - let apply_result = match crdt { - CrdtType::Loro - | CrdtType::LoroEphemeralStore - | CrdtType::LoroEphemeralStorePersisted => { - let start = std::time::Instant::now(); - let res = h.apply_updates(&room, &updates); - let elapsed_ms = start.elapsed().as_millis(); - if res.is_ok() { - debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied and broadcast updates"); - } - res + + let mut skip_apply = false; + if let Some(update_hook) = h.config.on_update.clone() { + let ctx = h.docs.get(&room).and_then(|s| s.ctx.clone()); + let doc = h.docs.get(&room).and_then(|s| s.doc.get_loro_doc()); + drop(h); + let args = UpdateArgs { + workspace: workspace_id.clone(), + room: room.room.clone(), + crdt, + conn_id, + updates: updates.clone(), + doc, + ctx, + }; + let mut update_hook_result = (update_hook)(args).await; + h = hub.lock().await; + skip_apply = h.process_update_hook_result(&room, &mut update_hook_result); + if update_hook_result.status != UpdateStatusCode::Ok { + send_ack(&tx, crdt, &room.room, batch_id, update_hook_result.status); + continue; } - CrdtType::Elo => { - // Index headers only; payload remains opaque to server. 
- h.apply_updates(&room, &updates) + } + + let apply_result = if skip_apply { + Ok(()) + } else { + match crdt { + CrdtType::Loro + | CrdtType::LoroEphemeralStore + | CrdtType::LoroEphemeralStorePersisted => { + let start = std::time::Instant::now(); + let res = h.apply_updates(&room, &updates); + let elapsed_ms = start.elapsed().as_millis(); + if res.is_ok() { + debug!(room=?room.room, updates=%updates.len(), ms=%elapsed_ms, "applied and broadcast updates"); + } + res + } + CrdtType::Elo => { + // Index headers only; payload remains opaque to server. + h.apply_updates(&room, &updates) + } + _ => Ok(()), } - _ => Ok(()), }; if apply_result.is_ok() { diff --git a/rust/loro-websocket-server/tests/close_hook.rs b/rust/loro-websocket-server/tests/close_hook.rs new file mode 100644 index 0000000..e2f3675 --- /dev/null +++ b/rust/loro-websocket-server/tests/close_hook.rs @@ -0,0 +1,90 @@ +use loro_websocket_client::Client; +use loro_websocket_server as server; +use server::protocol::{CrdtType, ProtocolMessage}; +use std::sync::Arc; +use tokio::sync::{Mutex, Notify}; + +type Cfg = server::ServerConfig<()>; + +#[derive(Clone, Debug)] +struct CloseRecord { + workspace: String, + conn_id: u64, + rooms: Vec<(CrdtType, String)>, +} + +#[tokio::test(flavor = "current_thread")] +async fn on_close_connection_receives_workspace_and_rooms() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind tcp listener"); + let addr = listener.local_addr().expect("local addr"); + + let close_calls: Arc>> = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + + let close_calls_cfg = close_calls.clone(); + let notify_cfg = notify.clone(); + + let server_task = tokio::spawn(async move { + let cfg: Cfg = server::ServerConfig { + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), + on_close_connection: Some(Arc::new(move |args: server::CloseConnectionArgs| { + let close_calls = close_calls_cfg.clone(); + let notify = notify_cfg.clone(); + Box::pin(async move { + let server::CloseConnectionArgs { + workspace, + conn_id, + rooms, + } = args; + close_calls.lock().await.push(CloseRecord { + workspace, + conn_id, + rooms, + }); + notify.notify_waiters(); + Ok(()) + }) + })), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .expect("serve incoming"); + }); + + let url = format!("ws://{}/ws-close?token=secret", addr); + let mut client = Client::connect(&url).await.expect("connect client"); + let room_id = "close-room"; + let join = ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: room_id.to_string(), + auth: Vec::new(), + version: Vec::new(), + }; + client.send(&join).await.expect("send join"); + match client.next().await.expect("join response") { + Some(ProtocolMessage::JoinResponseOk { .. 
}) => {} + other => panic!("unexpected response: {:?}", other), + } + + let notified = notify.notified(); + client.close().await.expect("close client"); + + tokio::time::timeout(std::time::Duration::from_secs(2), notified) + .await + .expect("close hook not called in time"); + + let calls = close_calls.lock().await; + assert_eq!(calls.len(), 1, "expected exactly one close hook call"); + let record = &calls[0]; + assert_eq!(record.workspace, "ws-close"); + assert!(record.conn_id > 0); + assert_eq!(record.rooms.len(), 1); + let (crdt, room) = &record.rooms[0]; + assert_eq!(*crdt, CrdtType::Loro); + assert_eq!(room, room_id); + + server_task.abort(); +} diff --git a/rust/loro-websocket-server/tests/edit_loro_doc.rs b/rust/loro-websocket-server/tests/edit_loro_doc.rs new file mode 100644 index 0000000..9a31a75 --- /dev/null +++ b/rust/loro-websocket-server/tests/edit_loro_doc.rs @@ -0,0 +1,92 @@ +use loro as loro_crdt; +use loro_websocket_client::Client; +use loro_websocket_server as server; +use loro_websocket_server::protocol::{self as proto, CrdtType}; +use std::sync::Arc; + +#[tokio::test(flavor = "current_thread")] +async fn edit_loro_doc_notifies_subscribers() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind listener"); + let addr = listener.local_addr().expect("local addr"); + let cfg: server::ServerConfig<()> = server::ServerConfig { + handshake_auth: Some(Arc::new(|_| true)), + ..Default::default() + }; + let registry: Arc> = Arc::new(server::HubRegistry::new(cfg)); + let server_registry = registry.clone(); + let server_task = tokio::spawn(async move { + server::serve_incoming_with_registry(listener, server_registry) + .await + .expect("server exited"); + }); + + let url = format!("ws://{}/workspace", addr); + let mut client = Client::connect(&url).await.expect("client connect"); + let room = "edit-room".to_string(); + let join = proto::ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: room.clone(), + auth: Vec::new(), + version: Vec::new(), + }; + client.send(&join).await.expect("send join"); + + // Drain the join response before editing. + loop { + match client.next().await.expect("next message") { + Some(proto::ProtocolMessage::JoinResponseOk { .. }) => break, + Some(_) => continue, + None => panic!("connection closed while joining"), + } + } + + // Verify that the client is registered as a subscriber. + let hub_arc = { + let hubs_guard = registry.hubs().lock().await; + hubs_guard + .get("workspace") + .cloned() + .expect("workspace hub not found") + }; + let subscriber_count = { + let hub_guard = hub_arc.lock().await; + let key = server::RoomKey { + crdt: CrdtType::Loro, + room: room.clone(), + }; + hub_guard.subs.get(&key).map(|v| v.len()).unwrap_or(0) + }; + assert_eq!(subscriber_count, 1, "expected the client to be subscribed"); + + registry + .edit_loro_doc("workspace", &room, |doc| { + let text = doc.get_text("text"); + text.insert(0, "from-server").unwrap(); + doc.commit(); + Ok(()) + }, false) // force_close = false, room will stay open since it has a subscriber + .await + .expect("edit succeeded"); + + let mut got_update = false; + for _ in 0..4 { + if let Some(proto::ProtocolMessage::DocUpdate { updates, .. 
}) = + client.next().await.expect("next message after edit") + { + let doc = loro_crdt::LoroDoc::new(); + for data in updates { + let _ = doc.import(&data); + } + if doc.get_text("text").to_string() == "from-server" { + got_update = true; + break; + } + } + } + assert!(got_update, "client did not receive server edit"); + + drop(client); + server_task.abort(); +} diff --git a/rust/loro-websocket-server/tests/update_hook.rs b/rust/loro-websocket-server/tests/update_hook.rs new file mode 100644 index 0000000..40aa7b2 --- /dev/null +++ b/rust/loro-websocket-server/tests/update_hook.rs @@ -0,0 +1,462 @@ +use loro as loro_crdt; +use loro_websocket_client::Client; +use loro_websocket_server as server; +use server::protocol::{CrdtType, ProtocolMessage, UpdateStatusCode}; +use std::sync::Arc; +use tokio::sync::{Mutex, Notify}; +use tokio::time::{timeout, Duration}; + +type Cfg = server::ServerConfig<()>; + +#[derive(Clone, Debug)] +struct UpdateRecord { + workspace: String, + room: String, + crdt: CrdtType, + conn_id: u64, + updates_len: usize, +} + +#[tokio::test(flavor = "current_thread")] +async fn on_update_hook_called() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind tcp listener"); + let addr = listener.local_addr().expect("local addr"); + + let update_calls: Arc>> = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + + let update_calls_cfg = update_calls.clone(); + let notify_cfg = notify.clone(); + + let server_task = tokio::spawn(async move { + let cfg: Cfg = server::ServerConfig { + on_update: Some(Arc::new(move |args: server::UpdateArgs<()>| { + let update_calls = update_calls_cfg.clone(); + let notify = notify_cfg.clone(); + Box::pin(async move { + let server::UpdateArgs { + workspace, + room, + crdt, + conn_id, + updates, + doc: _, + ctx: _, + } = args; + update_calls.lock().await.push(UpdateRecord { + workspace, + room, + crdt, + conn_id, + updates_len: updates.len(), + }); + notify.notify_waiters(); + server::UpdatedDoc { + status: UpdateStatusCode::Ok, + ctx: None, + doc: None, + } + }) + })), + // Use handshake auth to ensure workspace_id is captured + handshake_auth: Some(Arc::new(|_| true)), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .unwrap(); + }); + + // Connect client + let url = format!("ws://{}/my-workspace", addr); + let mut client = Client::connect(&url).await.expect("connect"); + + // Join room + client + .send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "room1".to_string(), + auth: vec![], + version: vec![], + }) + .await + .expect("send join"); + + // Wait for join response + match client.next().await.expect("recv") { + Some(ProtocolMessage::JoinResponseOk { .. 
}) => {} + msg => panic!("unexpected msg: {:?}", msg), + } + + // Send update + let update_payload = vec![1, 2, 3, 4]; + client + .send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "room1".to_string(), + updates: vec![update_payload.clone()], + batch_id: server::protocol::BatchId([0; 8]), // dummy batch id + }) + .await + .expect("send update"); + + // Wait for hook to be called + notify.notified().await; + + let calls = update_calls.lock().await; + assert_eq!(calls.len(), 1); + let record = &calls[0]; + assert_eq!(record.workspace, "my-workspace"); + assert_eq!(record.room, "room1"); + assert_eq!(record.crdt, CrdtType::Loro); + assert_eq!(record.updates_len, 1); + // conn_id is dynamic, just check it's non-zero + assert!(record.conn_id > 0); + + server_task.abort(); +} + +#[tokio::test(flavor = "current_thread")] +async fn on_update_hook_can_reject() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind tcp listener"); + let addr = listener.local_addr().expect("local addr"); + + let server_task = tokio::spawn(async move { + let cfg: Cfg = server::ServerConfig { + on_update: Some(Arc::new(move |args: server::UpdateArgs<()>| { + Box::pin(async move { + let status = if args.room == "rejected" { + UpdateStatusCode::PermissionDenied + } else { + UpdateStatusCode::Ok + }; + server::UpdatedDoc { + status, + ctx: None, + doc: None, + } + }) + })), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .unwrap(); + }); + + let url = format!("ws://{}/workspace", addr); + let mut c1 = Client::connect(&url).await.expect("c1 connect"); + let mut c2 = Client::connect(&url).await.expect("c2 connect"); + + // 1. Both join "rejected" + for c in [&mut c1, &mut c2] { + c.send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "rejected".to_string(), + auth: vec![], + version: vec![], + }).await.expect("join rejected"); + match c.next().await.expect("recv") { + Some(ProtocolMessage::JoinResponseOk { .. }) => {}, + m => panic!("unexpected join response: {:?}", m), + } + // Consume snapshot + match c.next().await.expect("recv") { + Some(ProtocolMessage::DocUpdate { .. }) => {}, + m => panic!("expected initial snapshot, got: {:?}", m), + } + } + + // 2. C1 sends update to "rejected" -> Should be rejected + let batch_id_1 = server::protocol::BatchId([1; 8]); + c1.send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "rejected".to_string(), + updates: vec![vec![1, 2, 3]], + batch_id: batch_id_1, + }).await.expect("send update 1"); + + // C1 gets Ack(PermissionDenied) + match c1.next().await.expect("recv") { + Some(ProtocolMessage::Ack { ref_id, status, .. }) => { + assert_eq!(ref_id, batch_id_1); + assert_eq!(status, UpdateStatusCode::PermissionDenied); + } + m => panic!("unexpected msg c1: {:?}", m), + } + + // 3. Both join "accepted" + for c in [&mut c1, &mut c2] { + c.send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "accepted".to_string(), + auth: vec![], + version: vec![], + }).await.expect("join accepted"); + match c.next().await.expect("recv") { + Some(ProtocolMessage::JoinResponseOk { .. }) => {}, + m => panic!("unexpected join response: {:?}", m), + } + // Consume snapshot + match c.next().await.expect("recv") { + Some(ProtocolMessage::DocUpdate { .. }) => {}, + m => panic!("expected initial snapshot, got: {:?}", m), + } + } + + // 4. 
C1 sends update to "accepted" -> Should be accepted + let batch_id_2 = server::protocol::BatchId([2; 8]); + c1.send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "accepted".to_string(), + updates: vec![vec![4, 5, 6]], + batch_id: batch_id_2, + }).await.expect("send update 2"); + + // C1 gets Ack(Ok) + match c1.next().await.expect("recv") { + Some(ProtocolMessage::Ack { ref_id, status, .. }) => { + assert_eq!(ref_id, batch_id_2); + assert_eq!(status, UpdateStatusCode::Ok); + } + m => panic!("unexpected msg c1: {:?}", m), + } + + // 5. C2 should receive the update for "accepted". + // Crucially, it should NOT have received the update for "rejected" before this. + match c2.next().await.expect("recv") { + Some(ProtocolMessage::DocUpdate { room_id, batch_id, .. }) => { + assert_eq!(room_id, "accepted"); + assert_eq!(batch_id, batch_id_2); + } + m => panic!("unexpected msg c2: {:?}", m), + } + + server_task.abort(); +} + +#[tokio::test(flavor = "current_thread")] +async fn on_update_hook_persists_ctx_between_calls() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind tcp listener"); + let addr = listener.local_addr().expect("local addr"); + + let seen_ctx: Arc>>> = Arc::new(Mutex::new(Vec::new())); + let notify = Arc::new(Notify::new()); + + let server_task = tokio::spawn({ + let seen_ctx = seen_ctx.clone(); + let notify = notify.clone(); + async move { + let cfg: server::ServerConfig = server::ServerConfig { + on_update: Some(Arc::new(move |args: server::UpdateArgs| { + let seen_ctx = seen_ctx.clone(); + let notify = notify.clone(); + Box::pin(async move { + { + let mut guard = seen_ctx.lock().await; + guard.push(args.ctx.clone()); + if guard.len() >= 2 { + notify.notify_waiters(); + } + } + server::UpdatedDoc { + status: UpdateStatusCode::Ok, + ctx: Some("persisted".to_string()), + doc: None, + } + }) + })), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .unwrap(); + } + }); + + let url = format!("ws://{}/ctx", addr); + let mut client = Client::connect(&url).await.expect("connect"); + client + .send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "room-ctx".to_string(), + auth: vec![], + version: vec![], + }) + .await + .expect("send join"); + + match client.next().await.expect("join response") { + Some(ProtocolMessage::JoinResponseOk { .. }) => {} + other => panic!("unexpected join response: {:?}", other), + } + match client.next().await.expect("snapshot") { + Some(ProtocolMessage::DocUpdate { .. }) => {} + other => panic!("expected initial snapshot, got {:?}", other), + } + + let first_batch = server::protocol::BatchId([3; 8]); + client + .send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "room-ctx".to_string(), + updates: vec![vec![1]], + batch_id: first_batch, + }) + .await + .expect("send first update"); + match client.next().await.expect("first ack") { + Some(ProtocolMessage::Ack { ref_id, status, .. }) => { + assert_eq!(ref_id, first_batch); + assert_eq!(status, UpdateStatusCode::Ok); + } + other => panic!("expected ack, got {:?}", other), + } + + let second_batch = server::protocol::BatchId([4; 8]); + client + .send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "room-ctx".to_string(), + updates: vec![vec![2]], + batch_id: second_batch, + }) + .await + .expect("send second update"); + match client.next().await.expect("second ack") { + Some(ProtocolMessage::Ack { ref_id, status, .. 
}) => { + assert_eq!(ref_id, second_batch); + assert_eq!(status, UpdateStatusCode::Ok); + } + other => panic!("expected ack 2, got {:?}", other), + } + let ctxs = timeout(Duration::from_secs(5), async { + loop { + { + let guard = seen_ctx.lock().await; + if guard.len() >= 2 { + return guard.clone(); + } + } + notify.notified().await; + } + }) + .await + .expect("timed out waiting for second hook invocation"); + assert_eq!(ctxs.len(), 2, "hook should run twice"); + assert!(ctxs[0].is_none(), "first call should have no ctx"); + assert_eq!(ctxs[1].as_deref(), Some("persisted")); + + server_task.abort(); +} + +#[tokio::test(flavor = "current_thread")] +async fn on_update_hook_can_supply_doc() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind tcp listener"); + let addr = listener.local_addr().expect("local addr"); + + let notify = Arc::new(Notify::new()); + + let server_task = tokio::spawn({ + let notify = notify.clone(); + async move { + let cfg: Cfg = server::ServerConfig { + on_update: Some(Arc::new(move |_args: server::UpdateArgs<()>| { + let notify = notify.clone(); + Box::pin(async move { + let doc = { + let doc = loro_crdt::LoroDoc::new(); + let text = doc.get_text("shared"); + text.insert(0, "from-hook").unwrap(); + doc + }; + notify.notify_waiters(); + server::UpdatedDoc { + status: UpdateStatusCode::Ok, + ctx: None, + doc: Some(doc), + } + }) + })), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .unwrap(); + } + }); + + let url = format!("ws://{}/doc", addr); + let mut c1 = Client::connect(&url).await.expect("c1 connect"); + c1.send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "room-doc".to_string(), + auth: vec![], + version: vec![], + }) + .await + .expect("join doc room"); + match c1.next().await.expect("join response") { + Some(ProtocolMessage::JoinResponseOk { .. }) => {} + other => panic!("unexpected join response: {:?}", other), + } + match c1.next().await.expect("initial snapshot") { + Some(ProtocolMessage::DocUpdate { .. }) => {} + other => panic!("expected initial snapshot, got {:?}", other), + } + + let notify_wait = notify.notified(); + let batch_id = server::protocol::BatchId([5; 8]); + c1.send(&ProtocolMessage::DocUpdate { + crdt: CrdtType::Loro, + room_id: "room-doc".to_string(), + updates: vec![vec![9]], + batch_id, + }) + .await + .expect("send update to trigger hook doc"); + match c1.next().await.expect("ack") { + Some(ProtocolMessage::Ack { ref_id, status, .. }) => { + assert_eq!(ref_id, batch_id); + assert_eq!(status, UpdateStatusCode::Ok); + } + other => panic!("expected ack, got {:?}", other), + } + + notify_wait.await; + drop(c1); + + let mut c2 = Client::connect(&url).await.expect("c2 connect"); + c2.send(&ProtocolMessage::JoinRequest { + crdt: CrdtType::Loro, + room_id: "room-doc".to_string(), + auth: vec![], + version: vec![], + }) + .await + .expect("join after hook doc"); + match c2.next().await.expect("join response") { + Some(ProtocolMessage::JoinResponseOk { .. }) => {} + other => panic!("unexpected join response: {:?}", other), + } + + let snapshot = match c2.next().await.expect("snapshot after hook doc") { + Some(ProtocolMessage::DocUpdate { updates, .. }) => updates, + other => panic!("expected doc snapshot, got {:?}", other), + }; + let doc = loro_crdt::LoroDoc::new(); + for data in snapshot { + let _ = doc.import(&data); + } + assert_eq!(doc.get_text("shared").to_string(), "from-hook"); + + server_task.abort(); +}