Skip to content

Commit

Permalink
feat: initial aggregation tree support
Browse files Browse the repository at this point in the history
use `--upstream` to specify a parent in the tree

Signed-off-by: Wataru Ishida <wataru.ishid@gmail.com>
  • Loading branch information
ishidawataru committed Mar 26, 2024
1 parent bb3c3f8 commit 3da62b9
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 3 deletions.
288 changes: 285 additions & 3 deletions reduction_server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use std::collections::HashMap;
use std::hint;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

Expand All @@ -18,7 +18,7 @@ use crate::reduce::{Reduce, WorkingMemory};
use crate::utils::*;

use crate::nccl_net;
use crate::nccl_net::Comm;
use crate::nccl_net::{Comm, Request};

use crate::partitioned_vec::PartitionedVec;

Expand Down Expand Up @@ -446,6 +446,181 @@ fn recv_loop<T: Float>(
}
}

fn upstream_loop<T: Float>(
args: &Args,
rank: &AtomicUsize,
mut jobs: Vec<(
Arc<AtomicUsize>,
Vec<Arc<AtomicUsize>>,
Arc<PartitionedVec<T>>,
)>,
) {
let nrank = args.nrank;

info!("connecting to upstream {}", args.upstream);
let mut stream = loop {
let res = TcpStream::connect(&args.upstream);
if res.is_ok() {
break res.unwrap();
}
// sleep 1s
std::thread::sleep(std::time::Duration::from_secs(1));
};

let mut buffer = [0u8; 4];
stream.read(buffer.as_mut()).unwrap();
let size = u32::from_le_bytes(buffer);
let mut handle = vec![0u8; size as usize];
stream.read(handle.as_mut()).unwrap();

let (lcomm, lhandle) = nccl_net::listen().unwrap();

// send size of handle
let size = lhandle.len() as u32;
stream.write_all(&size.to_le_bytes()).unwrap();
// send handle
stream.write_all(&lhandle).unwrap();

let mut scomm: Option<Comm> = None;
let mut rcomm: Option<Comm> = None;

loop {
if scomm.is_none() {
scomm = nccl_net::connect(handle.as_slice()).unwrap();
}
if rcomm.is_none() {
rcomm = nccl_net::accept(&lcomm).unwrap();
}
if scomm.is_some() && rcomm.is_some() {
break;
}
}

let scomm = scomm.unwrap();
let rcomm = rcomm.unwrap();

let mhs = jobs
.iter()
.map(|(_, _, buf)| {
let send_mh = nccl_net::reg_mr(&scomm, &buf.lock()).unwrap();
let recv_mh = nccl_net::reg_mr(&rcomm, &buf.lock()).unwrap();
(send_mh, recv_mh)
})
.collect::<Vec<_>>();

loop {
if rank.load(std::sync::atomic::Ordering::Relaxed) == args.nrank {
break;
}
std::thread::sleep(std::time::Duration::from_millis(100));
}

info!("upstream connected");
let tag = 0x69;
let size = args.count * std::mem::size_of::<T>();
let dummy_buf =
PartitionedVec::<T>::new(alignment(size), args.count, args.reduce_threads).unwrap();

loop {
for (idx, (send_ready, reduce_readys, buf)) in jobs.iter_mut().enumerate() {
for reduce_ready in reduce_readys.iter() {
loop {
hint::spin_loop();
let reduce_ready = reduce_ready.load(std::sync::atomic::Ordering::Relaxed);
if reduce_ready == 0 {
break;
}
if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank {
warn!("rank != nrank");
warn!("upstream thread({}) exit.", 0);
return;
}
}
}

loop {
hint::spin_loop();
let send_ready = send_ready.load(std::sync::atomic::Ordering::Relaxed);
let send_expect = (1 << args.send_threads) - 1;
if send_ready == send_expect {
break;
}
if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank {
warn!("rank != nrank");
warn!("upstream thread({}) exit.", 0);
return;
}
}

let (send_mh, recv_mh) = &mhs[idx];
let mut srequest: Option<Request> = None;
let mut rrequest: Option<Request> = None;

loop {
if srequest.is_none() {
srequest = nccl_net::isend(&scomm, send_mh, buf.lock().as_ref(), tag).unwrap();
if srequest.is_some() {
trace!("upstream send : idx: {} start", idx);
}
}
if rrequest.is_none() {
rrequest =
nccl_net::irecv(&rcomm, recv_mh, dummy_buf.lock().as_mut(), tag).unwrap();
if srequest.is_some() {
trace!("upstream recv : idx: {} start", idx);
}
}
if srequest.is_some() && rrequest.is_some() {
break;
}
}

loop {
if srequest.is_some() {
match nccl_net::test(&srequest.as_ref().unwrap()) {
Ok((send_done, _)) => {
if send_done {
trace!("upstream send : idx: {} done", idx);
srequest = None;
}
}
Err(e) => {
error!("upstream send : idx: {} error: {:?}", idx, e);
return
}
}
}
if rrequest.is_some() {
match nccl_net::test(&rrequest.as_ref().unwrap()) {
Ok((recv_done, _)) => {
if recv_done {
trace!("upstream recv : idx: {} done", idx);
rrequest = None;
}
}
Err(e) => {
error!("upstream recv : idx: {} error: {:?}", idx, e);
return
}
}
}
if srequest.is_none() && rrequest.is_none() {
break;
}
}

for reduce_ready in reduce_readys.iter_mut() {
reduce_ready.store(
(1 << args.send_threads) - 1,
std::sync::atomic::Ordering::Relaxed,
);
}

send_ready.store(0, std::sync::atomic::Ordering::Relaxed);
}
}
}

fn do_server<T: Float + 'static>(args: Args) {
let mut args = args;

Expand Down Expand Up @@ -514,7 +689,7 @@ fn do_server<T: Float + 'static>(args: Args) {
})
.collect::<Vec<_>>();

// transpose readys[job][thread]
// transpose readys[reduce_threads][reduce_jobs] to readys[reduce_jobs][reduce_threads]
let (send_readys, recv_readys): (Vec<_>, Vec<_>) = (0..args.reduce_jobs)
.map(|i| {
(0..args.reduce_threads)
Expand All @@ -523,6 +698,31 @@ fn do_server<T: Float + 'static>(args: Args) {
})
.unzip();

let send_readys = if args.upstream.is_empty() {
send_readys
} else {
// launch upstream thread
let args = Arc::clone(&args);
let jobs = send_readys
.into_iter()
.enumerate()
.map(|(i, send_ready)| {
let ready = Arc::new(AtomicUsize::new((1 << args.send_threads) - 1));
let (sbuf, _) = &bufs[i];
(ready, send_ready, Arc::clone(sbuf))
})
.collect::<Vec<_>>();

let readys = jobs
.iter()
.map(|(ready, _, _)| vec![Arc::clone(ready)])
.collect::<Vec<_>>();

let rank = Arc::clone(&rank);
std::thread::spawn(move || upstream_loop(&args, &rank, jobs));
readys
};

// launch send threads
let send_chs = (0..args.send_threads)
.map(|send_idx| {
Expand Down Expand Up @@ -646,6 +846,88 @@ mod tests {
server.join().unwrap();
}

/// Spawns a small tree of servers and clients on localhost and waits for all of them.
///
/// Layout, with `nrank = 2`:
/// * one root server on port 8080 (`--target root`),
/// * `nrank` parent servers on ports 8081, 8082, ... that pass
///   `--upstream localhost:8080`,
/// * `nrank` clients per parent, each connecting to its parent's port.
///
/// `dt` is forwarded as the `--data-type` value of every process. Every thread is
/// spawned before any of them is joined; the root is joined last.
fn do_test_upstream(dt: &str) {
    initialize();
    let nrank = 2;

    let root_dt = dt.to_string();
    let root = std::thread::spawn(move || {
        let nrank = nrank.to_string();
        let args = Args::parse_from([
            // Kept first: parsing reportedly fails unless the list starts with a flag
            // that takes no argument.
            // NOTE(review): clap's `parse_from` treats the first item as the program
            // name, which would explain this - confirm.
            "--verbose",
            "--port",
            "8080",
            "--data-type",
            &root_dt,
            "--nrank",
            &nrank,
            "--target",
            "root",
        ]);
        server(args);
    });

    // Collect eagerly so that every thread is running before the first join.
    let handles: Vec<_> = (0..nrank)
        .flat_map(|i| {
            let port = (8081 + i).to_string();

            let parent_dt = dt.to_string();
            let parent_port = port.clone();
            let parent_target = format!("parent{}", i);
            let parent = std::thread::spawn(move || {
                let nrank = nrank.to_string();
                let args = Args::parse_from([
                    // Same leading placeholder flag as for the root server.
                    "--verbose",
                    "--upstream",
                    "localhost:8080",
                    "--port",
                    &parent_port,
                    "--data-type",
                    &parent_dt,
                    "--nrank",
                    &nrank,
                    "--target",
                    &parent_target,
                ]);
                server(args);
            });

            let children = (0..nrank).map(move |j| {
                let child_dt = dt.to_string();
                let address = format!("127.0.0.1:{}", port);
                let child_target = format!("child{}.{}", i, j);
                std::thread::spawn(move || {
                    // Short delay before the client starts connecting to its parent.
                    std::thread::sleep(std::time::Duration::from_millis(100));
                    let args = Args::parse_from([
                        "--client",
                        "--address",
                        &address,
                        "--data-type",
                        &child_dt,
                        "--nreq",
                        // With the socket plugin, concurrent recv/send requests
                        // do not work, so issue a single request.
                        "1",
                        "--target",
                        &child_target,
                    ]);
                    client(args);
                })
            });

            vec![parent].into_iter().chain(children)
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    root.join().unwrap();
}

/// Runs the upstream scenario of `do_test_upstream` with `"f32"` as the data type.
#[test]
fn test_server_with_upstream_f32() {
do_test_upstream("f32");
}

#[test]
fn test_server_f32() {
do_test("f32");
Expand Down
3 changes: 3 additions & 0 deletions reduction_server/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub(crate) struct Args {
#[arg(short, long, default_value = "0.0.0.0")]
pub address: String,

#[arg(long, default_value = "")]
pub upstream: String,

#[arg(long, default_value = "1048576")]
pub count: usize,

Expand Down

0 comments on commit 3da62b9

Please sign in to comment.