| | import torch
|
| | import torch.nn as nn
|
| | from typing import Dict, Any, List
|
| | import asyncio
|
| | import websockets
|
| | import json
|
| | from pydantic import BaseModel
|
| |
|
class PeerMessage(BaseModel):
    """Envelope for every message exchanged between peers.

    Instances are sent as JSON text over the websocket and parsed back on
    receipt.
    """

    # Kind of message, e.g. "register" or "state_update".
    message_type: str
    # Message body; its keys depend on message_type.
    payload: Dict[str, Any]
    # Identifier of the peer that sent the message.
    peer_id: str
|
| |
|
class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Subclasses implement ``forward``. Peers exchange model state over a single
    websocket connection as JSON-encoded ``PeerMessage`` objects, and average
    the states they have received with :meth:`aggregate_states`.
    """

    def __init__(self):
        super().__init__()
        self.peer_id = self._generate_peer_id()
        # Peer identifiers; not read or written elsewhere in this class.
        self.peers: List[str] = []
        # Open websocket connection, or None until connect_to_network() runs.
        self.websocket: Any = None
        # Most recent deserialized state received from each peer, keyed by peer_id.
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Return a random UUID4 string used as this peer's identifier."""
        import uuid

        return str(uuid.uuid4())

    def _require_websocket(self) -> Any:
        """Return the open connection.

        Raises:
            RuntimeError: if connect_to_network() has not been awaited yet.
        """
        if self.websocket is None:
            raise RuntimeError(
                "Not connected to a network; await connect_to_network() first"
            )
        return self.websocket

    async def connect_to_network(self, network_url: str):
        """Open a websocket to ``network_url`` and register this peer.

        Any previously opened connection is closed first, so that reconnecting
        does not leak a socket.
        """
        if self.websocket is not None:
            previous, self.websocket = self.websocket, None
            await previous.close()
        self.websocket = await websockets.connect(network_url)
        await self._register_peer()

    async def _register_peer(self):
        """Announce this peer, and its concrete model class name, to the network."""
        websocket = self._require_websocket()
        message = PeerMessage(
            message_type="register",
            payload={"model_type": self.__class__.__name__},
            peer_id=self.peer_id,
        )
        await websocket.send(message.json())

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Send ``state_dict`` to the network as a ``state_update`` message.

        Raises:
            RuntimeError: if not connected.
        """
        websocket = self._require_websocket()
        message = PeerMessage(
            message_type="state_update",
            payload={"state": self._serialize_state_dict(state_dict)},
            peer_id=self.peer_id,
        )
        await websocket.send(message.json())

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Convert each tensor to (nested) Python lists/scalars for JSON transport.

        A 0-d tensor becomes a plain Python scalar.
        """
        # Tensor.tolist() works for bfloat16 and grad-requiring tensors, where
        # the previous .numpy() route raised; the values produced are the same.
        return {k: v.detach().cpu().tolist() for k, v in state_dict.items()}

    async def receive_state_updates(self):
        """Receive messages until the connection closes, storing peer states.

        Only ``state_update`` messages are kept. A newer update from the same
        peer replaces the older one. Iteration ends when the connection is
        closed normally; per the websockets docs, an abnormal close raises
        ``ConnectionClosedError``.

        Raises:
            RuntimeError: if not connected.
        """
        websocket = self._require_websocket()
        async for raw in websocket:
            data = PeerMessage.parse_raw(raw)
            if data.message_type == "state_update":
                self.state_updates[data.peer_id] = self._deserialize_state_dict(
                    data.payload["state"]
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Rebuild tensors from the lists produced by ``_serialize_state_dict``.

        The dtype is inferred by ``torch.tensor`` from the Python values, so
        float data arrives as the default float dtype regardless of the sender's
        original dtype.
        """
        return {k: torch.tensor(v) for k, v in state_dict.items()}

    def aggregate_states(self):
        """Replace this model's state with the element-wise mean of received peer states.

        Only the tensors in ``self.state_updates`` are averaged; the local
        model's current state is not included. The result is loaded with
        ``load_state_dict`` and the buffer is then cleared. Does nothing if no
        updates have been received. Keys are taken from the first received
        update; every other update must contain the same keys and shapes.
        """
        if not self.state_updates:
            return

        peer_states = list(self.state_updates.values())
        aggregated_state: Dict[str, torch.Tensor] = {}
        for key in peer_states[0]:
            stacked = torch.stack([state[key] for state in peer_states])
            if stacked.is_floating_point() or stacked.is_complex():
                aggregated_state[key] = stacked.mean(dim=0)
            else:
                # torch.mean rejects integer/bool dtypes (e.g. BatchNorm's
                # num_batches_tracked buffer), so average in float64 and round
                # back to the original dtype.
                aggregated_state[key] = (
                    stacked.double().mean(dim=0).round().to(stacked.dtype)
                )

        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    def forward(self, *args, **kwargs):
        """Forward pass; must be implemented by subclasses."""
        raise NotImplementedError