use crate::{
carets::*,
db::*,
graphql::sync_graphql_server,
log::log_sync_init,
state::*,
};
use crossbeam_channel::{
unbounded,
Receiver as CCReceiver,
Sender as CCSender,
};
use edit_common::commands::*;
use edit_common::simple_ws;
use edit_common::simple_ws::*;
use failure::Error;
use oatie::doc::*;
use oatie::rtf::*;
use rand::{
thread_rng,
Rng,
};
use serde_json;
use std::env;
use std::{
collections::HashMap,
thread,
time::Duration,
};
use url::Url;
use ws;
/// Reads the optional artificial sync delay from the environment.
///
/// Returns `Some(n)` when `EDIT_DEBUG_SYNC_DELAY` is set to text that parses
/// as a `u64`, and `None` when the variable is missing, unreadable, or not a
/// number. The commit handler in this file sleeps for `n` milliseconds.
fn debug_sync_delay() -> Option<u64> {
    match env::var("EDIT_DEBUG_SYNC_DELAY") {
        Ok(raw) => raw.parse::<u64>().ok(),
        Err(_) => None,
    }
}
/// Value passed as the starting version whenever this file builds a `SyncState`
/// (on page load and on document overwrite).
const INITIAL_SYNC_VERSION: usize = 100;
/// Maximum length, in bytes, that `valid_page_id` accepts for a page id.
const PAGE_TITLE_LEN: usize = 100;
/// Builds the document used when a page has no stored content.
///
/// The result is a single `DocGroup` with `Attrs::Header(1)` whose only child
/// is a `DocText` holding `id`, so the page id doubles as its heading text.
pub fn default_new_doc(id: &str) -> Doc<RtfSchema> {
doc![DocGroup(Attrs::Header(1), [DocText(id),])]
}
/// Checks whether `input` is an acceptable page id.
///
/// An id is accepted when it is non-empty, at most `PAGE_TITLE_LEN` bytes
/// long, and made up solely of ASCII letters, ASCII digits, `_` and `-`.
pub fn valid_page_id(input: &str) -> bool {
    let within_length = !input.is_empty() && input.len() <= PAGE_TITLE_LEN;
    within_length
        && input.chars().all(|ch| match ch {
            '_' | '-' => true,
            other => other.is_ascii_alphanumeric(),
        })
}
/// Produces a random six-character ASCII identifier.
///
/// Despite the name, the only caller in this file uses the result as the
/// client id of a newly accepted WebSocket connection.
fn generate_random_page_id() -> String {
    let mut rng = thread_rng();
    let id: String = rng.gen_ascii_chars().take(6).collect();
    id
}
/// Message sent to the page master thread: the target page id (field 0)
/// paired with the update to deliver to that page's sync thread (field 1).
pub struct ClientNotify(pub String, pub ClientUpdate);
/// Events handled by a page's `PageController`.
pub enum ClientUpdate {
// A client opened a socket; `out` is the handle used to send it messages.
Connect {
client_id: String,
out: simple_ws::Sender,
},
// A client submitted operation `op`, stated against document `version`.
Commit {
client_id: String,
op: Op<RtfSchema>,
version: usize,
},
// A client's socket closed.
Disconnect {
client_id: String,
},
// Replace the page's document wholesale; connected clients are told to
// restart. Not constructed in this file — presumably sent by the GraphQL
// server, which receives a clone of the master sender; confirm.
Overwrite {
doc: Doc<RtfSchema>,
},
}
/// Per-connection state for one WebSocket client.
struct ClientSocket {
// Page this connection is attached to, derived from the request path.
page_id: String,
// Identifier assigned to this connection when it was accepted.
client_id: String,
// Channel used to forward this client's events to the page master thread.
tx_master: CCSender<ClientNotify>,
}
// WebSocket lifecycle hooks for one client connection. Each hook translates a
// socket event into a `ClientNotify` and forwards it to the page master.
impl SimpleSocket for ClientSocket {
    /// `(client_id, tx_master)`: the id assigned to this connection and the
    /// channel leading to the page master thread.
    type Args = (String, CCSender<ClientNotify>);

    /// Resolves the page id from the request `url`, announces the connection
    /// to the page master, and returns the per-connection state.
    ///
    /// A leading `/$/ws` segment is stripped from the path. If what remains
    /// (minus the leading `/`) is not a valid page id, the client is attached
    /// to the `"home"` page.
    ///
    /// # Errors
    ///
    /// Returns an error when `url` cannot be resolved against the placeholder
    /// base URL. `url` comes from the client's request, so this is reported
    /// to the caller instead of panicking.
    fn initialize(
        (client_id, tx_master): Self::Args,
        url: &str,
        out: simple_ws::Sender,
    ) -> Result<ClientSocket, Error> {
        // The base only exists so that a bare request path can be parsed.
        let base = Url::parse("http://localhost/")?;
        let url = base.join(url)?;

        let mut path = url.path().to_owned();
        if path.starts_with("/$/ws/") {
            path = path["/$/ws".len()..].to_string();
        }

        let page_id = {
            // Drop the leading '/'; an empty remainder falls back to "home".
            let candidate = path.get(1..).unwrap_or("");
            if valid_page_id(candidate) {
                candidate.to_string()
            } else {
                "home".to_string()
            }
        };

        eprintln!("(!) Client {:?} connected to {:?}", client_id, page_id);

        let _ = tx_master.send(ClientNotify(
            page_id.clone(),
            ClientUpdate::Connect {
                client_id: client_id.clone(),
                out,
            },
        ));

        Ok(ClientSocket {
            page_id,
            client_id,
            tx_master,
        })
    }

    /// Decodes one JSON-encoded `ServerCommand` and acts on it.
    ///
    /// # Errors
    ///
    /// Returns the deserialization error when `data` is not a valid command.
    fn handle_message(&mut self, data: &[u8]) -> Result<(), Error> {
        let command: ServerCommand = serde_json::from_slice(data)?;
        match command {
            // NOTE(review): the client id forwarded here is the one carried in
            // the message, not `self.client_id` — confirm this is intended.
            ServerCommand::Commit(client_id, op, version) => {
                let _ = self.tx_master.send(ClientNotify(
                    self.page_id.clone(),
                    ClientUpdate::Commit {
                        client_id,
                        op,
                        version,
                    },
                ));
            }
            // Nothing to do for this command on the sync server.
            ServerCommand::TerminateProxy => {}
            ServerCommand::Log(log) => {
                log_raw!(self.client_id, log);
            }
        }
        Ok(())
    }

    /// Tells the page master that this client has disconnected.
    fn cleanup(&mut self) -> Result<(), Error> {
        let _ = self.tx_master.send(ClientNotify(
            self.page_id.clone(),
            ClientUpdate::Disconnect {
                client_id: self.client_id.clone(),
            },
        ));
        Ok(())
    }
}
/// Owns the live state of a single page; driven by that page's sync thread.
pub struct PageController {
// Id of the page this controller serves; used as the key when persisting.
page_id: String,
// Connection pool used to store the document after each commit.
db_pool: DbPool,
// Current document, version, and per-client versions.
state: SyncState,
// Outgoing socket handles for connected clients, keyed by client id.
clients: HashMap<String, simple_ws::Sender>,
}
// NOTE(review): `send_client_restart` has no caller in the visible part of
// this file and several helper results are discarded, which is presumably why
// the `unused` lint group is silenced for the whole impl — confirm.
#[allow(unused)]
impl PageController {
    /// Commits `op` from `client_id` against `input_version`, persists the
    /// caret-free document, and broadcasts the resulting update to every
    /// connected client.
    ///
    /// # Panics
    ///
    /// Panics when the sync state rejects the operation. Callers in this impl
    /// wrap the call in `catch_unwind`.
    fn sync_commit(&mut self, client_id: &str, op: Op<RtfSchema>, input_version: usize) {
        let op = self
            .state
            .commit(&client_id, op, input_version)
            .expect("Could not commit client operation.");

        if let Ok(doc) = remove_carets(&self.state.doc) {
            // The state has already advanced at this point, so a database
            // outage must not stop the broadcast below; otherwise clients
            // would fall out of sync with the server's version.
            match self.db_pool.get() {
                Ok(conn) => {
                    create_page(&conn, &self.page_id, &doc);
                }
                Err(err) => eprintln!(
                    "warning: could not persist page {:?}: {:?}",
                    self.page_id, err
                ),
            }
        }

        let command = ClientCommand::Update(self.state.version, client_id.to_owned(), op);
        self.broadcast_client_command(&command);
    }

    /// Sends `command`, serialized as JSON, to every connected client.
    /// Per-client send failures are ignored.
    fn broadcast_client_command(&self, command: &ClientCommand) {
        let json = serde_json::to_string(command).unwrap();
        for client in self.clients.values() {
            let _ = client.lock().unwrap().send(json.clone());
        }
    }

    /// Sends `command`, serialized as JSON, to a single client.
    ///
    /// # Errors
    ///
    /// Returns an error when serialization or the socket send fails.
    fn send_client_command(
        &self,
        client: &simple_ws::Sender,
        command: &ClientCommand,
    ) -> Result<(), Error> {
        let json = serde_json::to_string(command)?;
        Ok(client.lock().unwrap().send(json)?)
    }

    /// Closes the socket of `client_id` with a restart close code, if that
    /// client is connected. Close failures are ignored.
    fn send_client_restart(&self, client_id: &str) -> Result<(), Error> {
        let code = ws::CloseCode::Restart;
        let reason = "Server received an updated version of the document.";
        if let Some(client) = self.clients.get(client_id) {
            let _ = client.lock().unwrap().close_with_reason(code, reason);
        }
        Ok(())
    }

    /// Closes every connected client's socket with a restart close code.
    /// Close failures are ignored.
    fn broadcast_restart(&self) -> Result<(), Error> {
        let code = ws::CloseCode::Restart;
        let reason = "Server received an updated version of the document.";
        for client in self.clients.values() {
            let _ = client.lock().unwrap().close_with_reason(code, reason);
        }
        Ok(())
    }

    /// Processes one event for this page.
    fn handle(&mut self, notification: ClientUpdate) {
        match notification {
            ClientUpdate::Connect { client_id, out } => {
                // Send the current document first, then start tracking the
                // client at the version it was just given.
                let version = self.state.version;
                let command =
                    ClientCommand::Init(client_id.clone(), self.state.doc.0.clone(), version);
                let _ = self.send_client_command(&out, &command);
                self.state.clients.insert(client_id.clone(), version);
                self.clients.insert(client_id, out);
            }
            ClientUpdate::Disconnect { client_id } => {
                // Removing the departed client's carets is a commit like any
                // other and can fail. Guard it the same way as client commits
                // so one failure cannot take down the page's sync thread, and
                // so the client is always forgotten afterwards.
                let cleanup = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                    let op = remove_carets_op(&self.state.doc, vec![client_id.clone()]).unwrap();
                    let version = self.state.version;
                    self.sync_commit(&client_id, op, version);
                }));
                if let Err(err) = cleanup {
                    eprintln!(
                        "could not remove carets for disconnected client: {:?} - {:?}",
                        client_id, err
                    );
                }
                self.state.clients.remove(&client_id);
                self.clients.remove(&client_id);
            }
            ClientUpdate::Commit {
                client_id,
                op,
                version,
            } => {
                // Optional artificial latency for debugging.
                if let Some(delay) = debug_sync_delay() {
                    thread::sleep(Duration::from_millis(delay));
                }
                // A bad operation panics inside `sync_commit`; contain it so
                // the page keeps serving other clients.
                let sync = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                    self.sync_commit(&client_id, op, version);
                }));
                if let Err(err) = sync {
                    eprintln!(
                        "received invalid packet from client: {:?} - {:?}",
                        client_id, err
                    );
                }
            }
            ClientUpdate::Overwrite { doc } => {
                // Ask every client to reconnect, then start over from `doc`
                // with no tracked clients.
                let _ = self.broadcast_restart();
                self.state = SyncState::new(doc, INITIAL_SYNC_VERSION);
                self.clients = HashMap::new();
            }
        }
    }
}
/// Starts the dedicated thread that owns the state of page `page_id`.
///
/// The thread builds a `PageController` from `inner_doc` and `db_pool`, then
/// handles every update received on `rx_notify` until `recv` yields `None`.
/// This function itself always returns `Ok(())`.
pub fn spawn_sync_thread(
    page_id: String,
    rx_notify: CCReceiver<ClientUpdate>,
    inner_doc: Doc<RtfSchema>,
    db_pool: DbPool,
) -> Result<(), Error> {
    let worker = move || {
        let mut controller = PageController {
            page_id,
            db_pool,
            state: SyncState::new(inner_doc, INITIAL_SYNC_VERSION),
            clients: HashMap::new(),
        };
        loop {
            match rx_notify.recv() {
                Some(update) => controller.handle(update),
                None => break,
            }
        }
    };
    thread::spawn(worker);
    Ok(())
}
/// Registry of running page sync threads, keyed by page id.
struct PageMaster {
// Pool used to load a page's document and handed on to its sync thread.
db_pool: DbPool,
// Sender for each page whose sync thread has been started.
pages: HashMap<String, CCSender<ClientUpdate>>,
}
impl PageMaster {
fn new(db_pool: DbPool) -> PageMaster {
PageMaster {
db_pool,
pages: hashmap![],
}
}
fn acquire_page(&mut self, page_id: &str) -> CCSender<ClientUpdate> {
if self.pages.get(page_id).is_none() {
println!("(%) loading new page for {:?}", page_id);
let conn = self.db_pool.get().unwrap();
let inner_doc = get_single_page(&conn, page_id).unwrap_or_else(|| {
eprintln!("warning: could not find page {:?}, using default.", page_id);
default_new_doc(page_id)
});
let (tx_notify, rx_notify) = unbounded();
self.pages.insert(page_id.to_string(), tx_notify.clone());
let _ = spawn_sync_thread(
page_id.to_owned(),
rx_notify,
inner_doc,
self.db_pool.clone(),
);
tx_notify
} else {
self.pages.get(page_id).map(|x| x.clone()).unwrap()
}
}
}
/// Starts the routing thread: each `ClientNotify` received on `rx_master` is
/// forwarded to the sync thread of the page it names, which is started on
/// demand. The thread ends when `recv` yields `None`.
fn spawn_page_master(db_pool: DbPool, rx_master: CCReceiver<ClientNotify>) {
    thread::spawn(move || {
        let mut master = PageMaster::new(db_pool);
        loop {
            let ClientNotify(page_id, update) = match rx_master.recv() {
                Some(notify) => notify,
                None => break,
            };
            let page_tx = master.acquire_page(&page_id);
            let _ = page_tx.send(update);
        }
    });
}
/// Runs the sync server: sets up the database pool and logging, starts the
/// page master and GraphQL threads, then serves WebSocket clients on
/// `0.0.0.0:<port>`.
///
/// The call to `ws::listen` occupies the calling thread. If the listener
/// fails (for example the port cannot be bound), the error is written to
/// stderr before returning instead of being silently dropped.
pub fn sync_socket_server(port: u16) {
    let db_pool = db_pool_create();
    log_sync_init(db_pool.clone());
    log_sync!("SERVER", Spawn);

    // Every socket and the GraphQL server feed this one channel.
    let (tx_master, rx_master) = unbounded::<ClientNotify>();
    spawn_page_master(db_pool.clone(), rx_master);

    ::std::thread::spawn({
        take!(=db_pool, =tx_master);
        move || {
            sync_graphql_server(db_pool, tx_master);
        }
    });

    let url = format!("0.0.0.0:{}", port);
    eprintln!(
        " Sync server is listening for WebSocket connections on port {}",
        port
    );

    let served = ws::listen(url, {
        take!(=tx_master);
        move |out| {
            log_sync!("SERVER", ClientConnect);
            eprintln!("Client connected.");
            // Each connection gets a fresh random id as its client id.
            SocketHandler::<ClientSocket>::new(
                (generate_random_page_id(), tx_master.clone()),
                out,
            )
        }
    });
    if let Err(err) = served {
        eprintln!(
            "(!) Sync WebSocket server on port {} stopped with error: {:?}",
            port, err
        );
    }
}