Fix decentralized clients connection issue #1110
JulienVig
left a comment
Great work! I have a few refactoring comments but nothing blocking. I'll approve the merge once the next demo is passed :)
| // If this participant was a latest model provider node, | ||
| // replace the provider node to another node | ||
| if (this.#providerNode === peerId) { | ||
| this.#providerNode = this.connections.keySeq().first() |
Maybe ensure that the new providerNode is not already in #synchingNodes here
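A minimal sketch of that check, assuming plain arrays and Sets of peer ids rather than the Immutable.js collections used in the PR. The helper name pickProvider and its parameters are hypothetical, not part of the codebase:

```typescript
type NodeID = string;

// Hypothetical helper: pick the first connected peer that is not itself
// waiting for a model sync. Returns undefined when no peer qualifies.
function pickProvider(
  connected: readonly NodeID[],
  synching: ReadonlySet<NodeID>,
): NodeID | undefined {
  return connected.find((id) => !synching.has(id));
}
```

Returning undefined when every remaining peer is still syncing makes the "no provider available" case explicit instead of silently promoting a node that does not have the latest model.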
| const syncConnection = await this.#pool.getPeers( | ||
| Set([providerInfo.providerNode]), | ||
| this.server, | ||
| ()=>{} | ||
| ) |
Calling getPeers adds the providerNode to the peer pool (this line) which is then never removed and unused in aggregation either.
I don't think that's a big deal as I assume this model sync won't happen that often but worth adding a comment to help debug if it becomes problematic
| } else if (msg.type === type.RetryPeerConnections){ | ||
| debug(`[${shortenId(this.ownId)}] retrying peer connection establishment`) | ||
| // clear the communication round peer pool | ||
| await this.#pool?.shutdown() | ||
| this.#pool = new PeerPool(this.ownId) | ||
| // clear the connections | ||
| this.#connections = Map() | ||
| this.setAggregatorNodes(Set(this.ownId)) | ||
| continue | ||
| } else if (msg.type === type.ConnectionFail){ | ||
| debug(`[${shortenId(this.ownId)}] disconnect from the server`) | ||
| await this.disconnect() | ||
| throw new Error("Client disconnected after connection failure") | ||
| } | ||
| } | ||
| // Exchange weight updates with peers and return aggregated weights | ||
| return await this.exchangeWeightUpdates(weights) | ||
| const aggregatedWeight = await this.exchangeWeightUpdates(weights) |
Is this scenario ever covered in the unit tests?
| // Server signals peers to reestablish peer connections | ||
| export interface RetryPeerConnections { | ||
| type: type.RetryPeerConnections | ||
| aggregationRound: number |
I don't think aggregationRound is ever used with RetryPeerConnections, is it on purpose?
| case type.PeersForRound: | ||
| return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) | ||
| case type.SignalNewPeer: | ||
| return 'newNode' in o && isNodeID(o.newNode) |
If I'm not mistaken, SignalNewPeer is always sent by the server to a client. If so, there's no need to check its properties: you can return true directly (as you're doing for SignalModelProvider).
| this.#connectionRetry += 1; | ||
| // If the number of retries exceeds the threshold, exclude the failed peers from the round | ||
| // and retry peer connection only with the remaining peers | ||
| if (this.#connectionRetry >= this.task.trainingInformation.maxConnectionRetry){ |
With the current logic, peers are excluded after two retries if maxConnectionRetry = 3 while I would expect the logic to try 3 times and exclude on the fourth timeout.
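To make the off-by-one concrete, here is a small simulation, assuming every connection attempt times out. The function retriesBeforeExclusion is hypothetical and only mirrors the increment-then-compare pattern in the diff:

```typescript
// Counts how many retries happen before failed peers are excluded, assuming
// every attempt times out. `excludeWhen` is the comparison under test.
// Assumes a finite, positive maxConnectionRetry (otherwise the loop never ends).
function retriesBeforeExclusion(
  maxConnectionRetry: number,
  excludeWhen: (retryCount: number, max: number) => boolean,
): number {
  let connectionRetry = 0; // mirrors this.#connectionRetry
  let retries = 0;
  for (;;) {
    connectionRetry += 1; // incremented on each timeout, before the check
    if (excludeWhen(connectionRetry, maxConnectionRetry)) return retries;
    retries += 1; // the server asks the peers to retry
  }
}

// Comparison used in the diff: exclusion happens on the third timeout.
const withGreaterOrEqual = retriesBeforeExclusion(3, (n, max) => n >= max); // 2
// Strict comparison: three retries, exclusion on the fourth timeout.
const withStrictGreater = retriesBeforeExclusion(3, (n, max) => n > max); // 3
```

Under these assumptions, switching the comparison from >= to > would give the behavior described in the comment.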
| /** | ||
| * Set a timeout to check peer connections establishment | ||
| */ | ||
| private startTimeout(maxTime: number = 60_000): void { |
I wonder if we should make this timeout a task parameter like maxConnectionRetry, what do you think?
| * Receive model from the model provider. | ||
| */ | ||
| private async receiveModel(providerConn: PeerConnection): Promise<WeightsContainer>{ | ||
| const message = await waitMessageWithTimeout(providerConn, type.SharedModel, 30_000, "Timeout while waiting for the latest model") |
I think we should make this timeout value a task parameter: depending on the model size, it can make sense to increase the timeout.
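One possible shape for this, covering both the peer connection timeout and the model sync timeout. The field names peerConnectionTimeoutMs and modelSyncTimeoutMs are hypothetical (only maxConnectionRetry exists in the PR); the defaults are the values currently hard-coded in the diff:

```typescript
// Hypothetical optional task fields, both expressed in milliseconds.
interface ConnectionTimeouts {
  peerConnectionTimeoutMs?: number;
  modelSyncTimeoutMs?: number;
}

// Defaults taken from the hard-coded values in startTimeout and receiveModel.
const DEFAULT_TIMEOUTS: Required<ConnectionTimeouts> = {
  peerConnectionTimeoutMs: 60_000,
  modelSyncTimeoutMs: 30_000,
};

// Fill in any timeout the task does not specify.
function resolveTimeouts(task: ConnectionTimeouts): Required<ConnectionTimeouts> {
  return {
    peerConnectionTimeoutMs:
      task.peerConnectionTimeoutMs ?? DEFAULT_TIMEOUTS.peerConnectionTimeoutMs,
    modelSyncTimeoutMs:
      task.modelSyncTimeoutMs ?? DEFAULT_TIMEOUTS.modelSyncTimeoutMs,
  };
}
```

Keeping the fields optional would let existing tasks keep the current behavior while tasks with larger models raise modelSyncTimeoutMs.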
| // Set ModelWeightAccess of the client | ||
| this.#client.setModelWeightAccess({ | ||
| getModelWeight: () => { | ||
| return new WeightsContainer(this.trainer.model.weights.weights.map(t => t.clone())); | ||
| }, | ||
| setModelWeight: (weights) => { | ||
| this.trainer.model.weights = weights; | ||
| } | ||
| }); | ||
| // Simply propagate the training status events emitted by the client | ||
| this.#client.on("status", (status) => this.emit("status", status)); | ||
| this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); | ||
| this.#client.on("modelSynced", (latestWeights) => this.emit("modelSynced", latestWeights)); |
Maybe you can avoid this new ModelWeightAccess interface (that is only useful for decentralized and needs to be created after initialization) if you do this:
this.#client.on("modelSynced", (latestWeights) => {
  this.trainer.model.weights = latestWeights;
  this.emit("modelSynced", latestWeights)
});
And to access the model from decentralized_client.ts, you can store a copy of the model (e.g. latestModel) as an attribute of the decentralized client and set it in finishRound, by having the trainer pass the model as an argument at l144 of trainer.ts:
this.#client.finishRound(networkWeights);
Let me know what you think or if my description is too vague.
| const msg = await Promise.race([ | ||
| waitMessage(this.server, type.StartWeightSharing), | ||
| waitMessage(this.server, type.RetryPeerConnections), | ||
| waitMessage(this.server, type.ConnectionFail), | ||
| ]) |
This creates two event listeners per round that are never resolved and that accumulate over the rounds. I don't think it's a huge deal, but can you add a comment/TODO to note this?
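If the leak ever needs fixing, one option is to unregister the losing listeners as soon as the first message arrives. The sketch below uses a tiny stand-in emitter (TinyEmitter) and a hypothetical waitFirst helper; neither reflects the actual waitMessage or server connection API:

```typescript
type Listener<T> = (payload: T) => void;

// Tiny stand-in for the server connection's event interface (hypothetical).
class TinyEmitter<T> {
  private listeners = new Map<string, Set<Listener<T>>>();
  on(event: string, l: Listener<T>): void {
    if (!this.listeners.has(event)) this.listeners.set(event, new Set());
    this.listeners.get(event)!.add(l);
  }
  off(event: string, l: Listener<T>): void {
    this.listeners.get(event)?.delete(l);
  }
  emit(event: string, payload: T): void {
    // Copy first so listeners may unregister themselves while we iterate.
    for (const l of Array.from(this.listeners.get(event) ?? [])) l(payload);
  }
  listenerCount(event: string): number {
    return this.listeners.get(event)?.size ?? 0;
  }
}

// Resolve with the first of `events` to fire, then drop every listener
// registered for this wait so nothing is left behind for the next round.
function waitFirst<T>(
  emitter: TinyEmitter<T>,
  events: string[],
): Promise<{ event: string; payload: T }> {
  return new Promise((resolve) => {
    const registered: Array<[string, Listener<T>]> = [];
    for (const event of events) {
      const l: Listener<T> = (payload) => {
        for (const [e, fn] of registered) emitter.off(e, fn);
        resolve({ event, payload });
      };
      registered.push([event, l]);
      emitter.on(event, l);
    }
  });
}
```

Compared with Promise.race over three independent waits, a single wait that owns all of its listeners can remove the two losers itself, so the listener count stays constant across rounds.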
Solves issue: #1109
This pull request tackles peer connection stability and synchronization issues in decentralized learning. It addresses three main issues.
Connection Ready Check and Signaling Weight Sharing
ConnectionsReady and StartWeightSharing messages are newly added to prevent faster peers from proceeding to weight sharing while slower peers are still establishing connections. The detailed process is as follows:
1. […] ConnectionsReady message to the server.
2. […] ConnectionsReady messages.
3. […] ConnectionsReady messages matches the number of expected participants for the round, the server sends a StartWeightSharing message to all round participants.
4. […] StartWeightSharing.

Connection Retries and Failed Client Disconnection
RetryPeerConnections and ConnectionFail messages are newly added. In addition, the webapp shows an error message when a client is disconnected after repeated failures.
A maxConnectionRetry parameter is added to trainingInformation and exposed in the webapp so that users can specify its value during task creation.
The connection failure handling process is as follows:
1. […] maxConnectionRetry, the server sends a RetryPeerConnections message to all peers in the current round.
2. […] RetryPeerConnections, they clean up their peer pool and aggregator nodes, then rerun the peer connection phase.
3. […] maxConnectionRetry attempts, the server removes the failed peers from the round peer list.
4. […] ConnectionFail message to the failed peers.
5. […] ConnectionFail, it disconnects from the server.
6. […] RetryPeerConnections and retry the peer connection phase without the failed peers.

Model Syncing for Participants Joining in the Middle of Training
Model syncing between peers is implemented for participants who are joining in the middle of training.
ModelSyncRequest, SignalModelProvider, SignalNewPeer, and SharedModel messages are newly added. In addition, a ModelWeightAccess interface was added in disco.ts so that the client object can update the model weights when it receives the latest model from the provider peer.
The model syncing process is as follows:
1. […] NewDecentralizedNodeInfo.
2. […] NewDecentralizedNodeInfo, the new client sets a local flag indicating that it needs model syncing.
3. […] ModelSyncRequest message to the server.
4. […] ModelSyncRequest, the server sends messages as in steps 5 and 6, using the selected model provider information from the previous training round.
5. […] SignalModelProvider to the new participant with information about the provider peer.
6. […] SignalNewPeer to the provider peer with information about the newly joined peer.
7. […] SharedModel message.
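
The server-side part of this process (reacting to a ModelSyncRequest) could look roughly like the sketch below. The message shapes are assumptions made for illustration: only the newNode field of SignalNewPeer appears in the diff, and onModelSyncRequest is a hypothetical name.

```typescript
type NodeID = string;

// Hypothetical message shapes; only `newNode` is confirmed by the diff.
type SyncMessage =
  | { type: "SignalModelProvider"; providerNode: NodeID }
  | { type: "SignalNewPeer"; newNode: NodeID };

interface Outgoing {
  to: NodeID;
  message: SyncMessage;
}

// On ModelSyncRequest, tell the requester who the provider is and tell the
// provider who the new peer is. Sends nothing if no usable provider is known.
function onModelSyncRequest(
  requester: NodeID,
  providerNode: NodeID | undefined,
): Outgoing[] {
  if (providerNode === undefined || providerNode === requester) return [];
  return [
    { to: requester, message: { type: "SignalModelProvider", providerNode } },
    { to: providerNode, message: { type: "SignalNewPeer", newNode: requester } },
  ];
}
```

Returning the messages instead of sending them directly keeps the sketch free of any socket API and makes the pairing of the two signals easy to check in a unit test.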