diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index bbf403de6..1a66eb4fe 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -3790,10 +3790,6 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, /* Reset request state for connect response message */ prepare_send_conn_resp_req(l_comm); - l_comm->stage = COMM_SEND_CONN; - - case COMM_SEND_CONN: - /* Initialize connect response message */ ret = prepare_conn_resp(ep, l_comm, dev_id); if (ret != 0) { @@ -3806,6 +3802,10 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, /* Send r_comm's remote comm ID */ conn_msg->remote_comm_id = r_comm->remote_comm_id; + l_comm->stage = COMM_SEND_CONN; + + case COMM_SEND_CONN: + /* COMM_SEND_CONN: Send connect response message to remote */ ret = post_send_conn_resp(r_comm, conn_msg, device, ep, req); if (ret == -FI_EAGAIN) { @@ -5159,10 +5159,6 @@ static int connect(nccl_net_ofi_ep_t *base_ep, } comm_state->req = &req->base; - comm_state->stage = COMM_SEND_CONN; - - case COMM_SEND_CONN: - /* Prepare request to receive connect response message */ s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm); if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) { @@ -5170,6 +5166,10 @@ static int connect(nccl_net_ofi_ep_t *base_ep, return -EINVAL; } + comm_state->stage = COMM_SEND_CONN; + + case COMM_SEND_CONN: + /* COMM_SEND_CONN: Post a connect message to send peer connections */ ret = post_send_conn(s_comm, device, ep, req); if (ret == -FI_EAGAIN) {