diff --git a/Dockerfile b/Dockerfile index 7544194..c554b6a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,6 +48,7 @@ FROM optcast AS unittest ENV RUST_LOG=info ENV NCCL_SOCKET_IFNAME=lo +ENV RUSTFLAGS="--cfg no_spinloop" RUN cd reduction_server && cargo test --all -- --nocapture --test-threads=1 FROM nvcr.io/nvidia/cuda:12.3.1-devel-ubuntu22.04 AS final diff --git a/reduction_server/src/client.rs b/reduction_server/src/client.rs index c24b22d..e3d3b68 100644 --- a/reduction_server/src/client.rs +++ b/reduction_server/src/client.rs @@ -4,6 +4,7 @@ * See LICENSE for license information */ +use std::hint; use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; use std::sync::Arc; @@ -60,6 +61,12 @@ fn do_client(args: &Args, comms: Vec<(Comm, Comm)>) { let start = std::time::Instant::now(); loop { + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + for (i, req, sbuf, rbuf, mhs) in reqs.iter_mut() { if req.is_none() && reqed < args.try_count { *req = Some( @@ -72,6 +79,12 @@ fn do_client(args: &Args, comms: Vec<(Comm, Comm)>) { let mut rrequest: Option = None; loop { + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + if srequest.is_none() { srequest = nccl_net::isend( scomm, diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs index 654e407..ba59864 100644 --- a/reduction_server/src/server.rs +++ b/reduction_server/src/server.rs @@ -12,7 +12,7 @@ use std::sync::atomic::AtomicUsize; use std::sync::Arc; use half::{bf16, f16}; -use log::{info, trace, warn, error}; +use log::{error, info, trace, warn}; use crate::reduce::{Reduce, WorkingMemory}; use crate::utils::*; @@ -103,7 +103,11 @@ fn reduce_loop( trace!("rank({})/job({}) reduce wait recv", i, job_idx); loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } let send_ready = send_ready.load(std::sync::atomic::Ordering::Relaxed); let send_expect = (1 << args.send_threads) - 1; let recv_ready = recv_ready.load(std::sync::atomic::Ordering::Relaxed); @@ -213,7 +217,11 @@ fn send_loop( for (idx, (readys, send)) in sends.iter().enumerate().cycle() { for ready in readys.iter() { loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } let ready = ready.load(std::sync::atomic::Ordering::Relaxed); // trace!( // "[send] rank({})/job({}) send ready: 0b{:016b}", @@ -235,7 +243,11 @@ fn send_loop( let mut reqs = vec_of_none(send.len()); loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank { warn!("rank != nrank"); warn!("send thread({}) exit.", i); @@ -260,7 +272,12 @@ fn send_loop( let start = std::time::Instant::now(); loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank { warn!("rank != nrank"); warn!("send thread({}) exit.", i); @@ -360,7 +377,11 @@ fn recv_loop( for (job_idx, (readys, recv)) in recvs.iter_mut().enumerate() { for ready in readys.iter() { loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } let ready = ready.load(std::sync::atomic::Ordering::Relaxed); // trace!( // "[recv] rank({})/job({}) recv ready: 0b{:016b}", @@ -382,7 +403,11 @@ fn recv_loop( let mut reqs = vec_of_none(recv.len()); loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank { warn!("rank != nrank"); warn!("recv thread({}) exit.", i); @@ -408,7 +433,12 @@ fn recv_loop( let start = std::time::Instant::now(); loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + if rank.load(std::sync::atomic::Ordering::Relaxed) != nrank { warn!("rank != nrank"); warn!("recv thread({}) exit.", i); @@ -522,7 +552,12 @@ fn upstream_loop( for (idx, (send_ready, reduce_readys, buf)) in jobs.iter_mut().enumerate() { for reduce_ready in reduce_readys.iter() { loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + let reduce_ready = reduce_ready.load(std::sync::atomic::Ordering::Relaxed); if reduce_ready == 0 { break; @@ -536,7 +571,11 @@ fn upstream_loop( } loop { - hint::spin_loop(); + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } let send_ready = send_ready.load(std::sync::atomic::Ordering::Relaxed); let send_expect = (1 << args.send_threads) - 1; if send_ready == send_expect { @@ -554,6 +593,12 @@ fn upstream_loop( let mut rrequest: Option = None; loop { + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + if srequest.is_none() { srequest = nccl_net::isend(&scomm, send_mh, buf.lock().as_ref(), tag).unwrap(); if srequest.is_some() { @@ -561,8 +606,7 @@ fn upstream_loop( } } if rrequest.is_none() { - rrequest = - nccl_net::irecv(&rcomm, recv_mh, buf.lock().as_mut(), tag).unwrap(); + rrequest = nccl_net::irecv(&rcomm, recv_mh, buf.lock().as_mut(), tag).unwrap(); if srequest.is_some() { trace!("upstream recv : idx: {} start", idx); } @@ -573,6 +617,12 @@ fn upstream_loop( } loop { + if cfg!(no_spinloop) { + std::thread::sleep(NO_SPINLOOP_INTERVAL); + } else { + hint::spin_loop(); + } + if srequest.is_some() { match nccl_net::test(&srequest.as_ref().unwrap()) { Ok((send_done, _)) => { @@ -583,7 +633,7 @@ fn upstream_loop( } Err(e) => { error!("upstream send : idx: {} error: {:?}", idx, e); - return + return; } } } @@ -597,7 +647,7 @@ fn upstream_loop( } Err(e) => { error!("upstream recv : idx: {} error: {:?}", idx, e); - return + return; } } } diff --git a/reduction_server/src/utils.rs b/reduction_server/src/utils.rs index dc4ab7d..bf068fd 100644 --- a/reduction_server/src/utils.rs +++ b/reduction_server/src/utils.rs @@ -8,10 +8,12 @@ use std::fmt::Debug; use std::time::Duration; use clap::{Parser, ValueEnum}; -use half::{f16, bf16}; +use half::{bf16, f16}; use log::info; use num_traits::FromPrimitive; +pub(crate) const NO_SPINLOOP_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100); + pub(crate) fn transpose(v: Vec>) -> Vec> { assert!(!v.is_empty()); let len = v[0].len(); @@ -140,8 +142,8 @@ pub(crate) fn vec_of_none(n: usize) -> Vec> { #[cfg(test)] pub mod tests { - use std::sync::Once; use crate::nccl_net; + use std::sync::Once; static INIT: Once = Once::new();