Skip to content

Commit

Permalink
Merge pull request #287 from talex5/buf-write
Browse files Browse the repository at this point in the history
Add buffering of outgoing messages
  • Loading branch information
talex5 authored Nov 17, 2024
2 parents d8743c1 + dd0c602 commit c95f619
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 163 deletions.
155 changes: 37 additions & 118 deletions capnp-rpc-net/capTP_capnp.ml
Original file line number Diff line number Diff line change
@@ -1,33 +1,5 @@
open Eio.Std

(* Prometheus metrics for the network layer. Every metric declared here uses
   the [namespace] and [subsystem] values defined at the top of the module. *)
module Metrics = struct
  let namespace = "capnp"

  let subsystem = "net"

  (* Gauge: live connections. *)
  let connections =
    Prometheus.Gauge.v
      ~help:"Number of live capnp-rpc connections"
      ~namespace ~subsystem "connections"

  (* Counter: messages received from peers. *)
  let messages_inbound_received_total =
    Prometheus.Counter.v
      ~help:"Total number of messages received"
      ~namespace ~subsystem "messages_inbound_received_total"

  (* Counter: messages placed on a transmit queue. *)
  let messages_outbound_enqueued_total =
    Prometheus.Counter.v
      ~help:"Total number of messages enqueued to be transmitted"
      ~namespace ~subsystem "messages_outbound_enqueued_total"

  (* Counter: messages actually written to the peer. *)
  let messages_outbound_sent_total =
    Prometheus.Counter.v
      ~help:"Total number of messages transmitted"
      ~namespace ~subsystem "messages_outbound_sent_total"

  (* Counter: messages discarded because the connection went away. *)
  let messages_outbound_dropped_total =
    Prometheus.Counter.v
      ~help:"Total number of messages lost due to disconnections"
      ~namespace ~subsystem "messages_outbound_dropped_total"
end

module Log = Capnp_rpc.Debug.Log

module Builder = Capnp_rpc.Private.Schema.Builder
Expand All @@ -42,10 +14,8 @@ module Make (Network : S.NETWORK) = struct
module Serialise = Serialise.Make(Endpoint_types)

type t = {
sw : Switch.t;
endpoint : Endpoint.t;
conn : Conn.t;
xmit_queue : Capnp.Message.rw Capnp.BytesMessage.Message.t Queue.t;
mutable disconnecting : bool;
}

Expand All @@ -60,94 +30,49 @@ module Make (Network : S.NETWORK) = struct

let tags t = Conn.tags t.conn

(* Discard every message still waiting in [q], adding the number of discarded
   messages to the dropped-messages counter first. *)
let drop_queue q =
  let dropped = float_of_int (Queue.length q) in
  Prometheus.Counter.inc Metrics.messages_outbound_dropped_total dropped;
  Queue.clear q

(* [flush ~xmit_queue endpoint] writes each message in the queue until it is empty.
Invariant:
Whenever Eio blocks or switches threads, a flush thread is running iff the
queue is non-empty.
Outcomes:
- If [Endpoint.send] returns an error, the endpoint is disconnected and all
remaining queued messages are dropped (and counted as dropped).
- If [Endpoint.send] raises, the queue is dropped and the exception is re-raised.
Precondition: the queue must be non-empty on entry, since [Queue.peek] raises
[Queue.Empty] on an empty queue. *)
let rec flush ~xmit_queue endpoint =
(* We keep the item on the queue until it is transmitted, as the queue state
tells us whether there is a [flush] currently running. *)
let next = Queue.peek xmit_queue in
match Endpoint.send endpoint next with
| Error `Closed ->
(* Peer closed the connection: no log message, just shut down and drop. *)
Endpoint.disconnect endpoint; (* We'll read a close soon *)
drop_queue xmit_queue
| Error (`Msg msg) ->
(* Any other send failure: log a warning, then shut down and drop. *)
Log.warn (fun f -> f "Error sending messages: %s (will shutdown connection)" msg);
Endpoint.disconnect endpoint;
drop_queue xmit_queue
| Ok () ->
Prometheus.Counter.inc_one Metrics.messages_outbound_sent_total;
(* Only now remove the message that was peeked above. *)
ignore (Queue.pop xmit_queue);
if not (Queue.is_empty xmit_queue) then
flush ~xmit_queue endpoint
(* else queue is empty and flush thread is done *)
| exception ex ->
(* Empty the queue before propagating, so the queue does not stay non-empty
with no flush running. *)
drop_queue xmit_queue;
raise ex

(* Append [message] to [xmit_queue] and count it as enqueued. If the queue was
   empty before the append then no flush fiber is running, so one is forked
   on [sw] to drain the queue. *)
let queue_send ~sw ~xmit_queue endpoint message =
  Log.debug (fun f ->
      let module Msg = Capnp_rpc.Private.Schema.MessageWrapper.Message in
      let used = Msg.total_size message in
      let allocated = Msg.total_alloc_size message in
      let segments = Msg.num_segments message in
      f "queue_send: %d/%d allocated bytes in %d segs" used allocated segments);
  (* Must be sampled before the message is added. *)
  let needs_flusher = Queue.is_empty xmit_queue in
  Queue.push message xmit_queue;
  Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total;
  if needs_flusher then
    Eio.Fiber.fork ~sw (fun () -> flush ~xmit_queue endpoint)

let return_not_implemented t x =
Log.debug (fun f -> f ~tags:(tags t) "Returning Unimplemented");
let open Builder in
let m = Message.init_root () in
let _ : Builder.Message.t = Message.unimplemented_set_reader m x in
queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m)

let listen t =
let rec loop () =
match Endpoint.recv t.endpoint with
| Error e -> e
| Ok msg ->
let open Reader.Message in
let msg = of_message msg in
Prometheus.Counter.inc_one Metrics.messages_inbound_received_total;
match Parse.message msg with
| #Endpoint_types.In.t as msg ->
Log.debug (fun f ->
let tags = Endpoint_types.In.with_qid_tag (Conn.tags t.conn) msg in
f ~tags "<- %a" (Endpoint_types.In.pp_recv pp_msg) msg);
begin match msg with
| `Abort _ ->
t.disconnecting <- true;
Conn.handle_msg t.conn msg;
Endpoint.disconnect t.endpoint;
`Aborted
| _ ->
Conn.handle_msg t.conn msg;
loop ()
end
| `Unimplemented x as msg ->
Log.info (fun f ->
let tags = Endpoint_types.Out.with_qid_tag (Conn.tags t.conn) x in
f ~tags "<- Unimplemented(%a)" (Endpoint_types.Out.pp_recv pp_msg) x);
Conn.handle_msg t.conn msg;
loop ()
| `Not_implemented ->
Log.info (fun f -> f "<- unsupported message type");
return_not_implemented t msg;
loop ()
in
loop ()
Endpoint.send t.endpoint (Message.to_message m)

(* [listen t] reads messages from [t.endpoint] one at a time and dispatches them,
calling itself again after each message until the loop ends.
Returns the error value from [Endpoint.recv] when receiving fails, or [`Aborted]
once an [`Abort] message from the peer has been processed. *)
let rec listen t =
match Endpoint.recv ~tags:(tags t) t.endpoint with
| Error e -> e
| Ok msg ->
let open Reader.Message in
let msg = of_message msg in
match Parse.message msg with
| #Endpoint_types.In.t as msg ->
(* A message type the connection understands. *)
Log.debug (fun f ->
let tags = Endpoint_types.In.with_qid_tag (Conn.tags t.conn) msg in
f ~tags "<- %a" (Endpoint_types.In.pp_recv pp_msg) msg);
begin match msg with
| `Abort _ ->
(* Set the disconnecting flag before passing the abort on, then shut down
both the endpoint and the connection and stop listening. *)
t.disconnecting <- true;
Conn.handle_msg t.conn msg;
Endpoint.disconnect t.endpoint;
Conn.disconnect t.conn (Capnp_rpc_proto.Exception.v "Received Abort from peer");
`Aborted
| _ ->
Conn.handle_msg t.conn msg;
listen t
end
| `Unimplemented x as msg ->
(* The peer reports that it did not implement something we sent ([x]). *)
Log.info (fun f ->
let tags = Endpoint_types.Out.with_qid_tag (Conn.tags t.conn) x in
f ~tags "<- Unimplemented(%a)" (Endpoint_types.Out.pp_recv pp_msg) x);
Conn.handle_msg t.conn msg;
listen t
| `Not_implemented ->
(* We could not parse this message type: hand the message to
[return_not_implemented] and keep listening. *)
Log.info (fun f -> f "<- unsupported message type");
return_not_implemented t msg;
listen t

let send_abort t ex =
queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex))
Endpoint.send t.endpoint (Serialise.message (`Abort ex));
Endpoint.flush t.endpoint (* We're probably about to disconnect *)

let disconnect t ex =
if not t.disconnecting then (
Expand All @@ -160,21 +85,17 @@ module Make (Network : S.NETWORK) = struct
(* [disconnecting t] is the current value of the mutable [disconnecting] flag of [t]. *)
let disconnecting t = t.disconnecting

let connect ~sw ~restore ?(tags=Logs.Tag.empty) endpoint =
let xmit_queue = Queue.create () in
let queue_send msg = queue_send ~sw ~xmit_queue endpoint (Serialise.message msg) in
let queue_send msg = Endpoint.send endpoint (Serialise.message msg) in
let restore = Restorer.fn restore in
let fork = Fiber.fork ~sw in
let conn = Conn.create ~restore ~tags ~fork ~queue_send in
{
sw;
conn;
endpoint;
xmit_queue;
disconnecting = false;
}

let listen t =
Prometheus.Gauge.inc_one Metrics.connections;
let tags = Conn.tags t.conn in
begin
match listen t with
Expand All @@ -187,8 +108,6 @@ module Make (Network : S.NETWORK) = struct
);
send_abort t (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex))
end;
Log.info (fun f -> f ~tags "Connection closed");
Prometheus.Gauge.dec_one Metrics.connections;
Eio.Cancel.protect (fun () ->
disconnect t (Capnp_rpc.Exception.v ~ty:`Disconnected "Connection closed")
);
Expand Down
112 changes: 87 additions & 25 deletions capnp-rpc-net/endpoint.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
open Eio.Std

(* Prometheus metrics for endpoints. Every metric declared here uses the
   [namespace] and [subsystem] values defined at the top of the module. *)
module Metrics = struct
  let namespace = "capnp"

  let subsystem = "net"

  (* Gauge: live connections. *)
  let connections =
    Prometheus.Gauge.v
      ~help:"Number of live capnp-rpc connections"
      ~namespace ~subsystem "connections"

  (* Counter: messages received from peers. *)
  let messages_inbound_received_total =
    Prometheus.Counter.v
      ~help:"Total number of messages received"
      ~namespace ~subsystem "messages_inbound_received_total"

  (* Counter: messages handed to the outgoing buffer. *)
  let messages_outbound_enqueued_total =
    Prometheus.Counter.v
      ~help:"Total number of messages enqueued to be transmitted"
      ~namespace ~subsystem "messages_outbound_enqueued_total"
end

module Write = Eio.Buf_write

let src = Logs.Src.create "endpoint" ~doc:"Send and receive Cap'n'Proto messages"
module Log = (val Logs.src_log src: Logs.LOG)

Expand All @@ -11,16 +33,20 @@ type flow = Eio.Flow.two_way_ty r

(* State of one endpoint: a flow plus the buffers used to send and receive
Cap-n-Proto messages over it. *)
type t = {
flow : flow; (* the underlying two-way byte stream *)
writer : Write.t; (* outgoing data buffer; [send] appends to it and [run_writer] drains it *)
decoder : Capnp.Codecs.FramedStream.t; (* turns received fragments into message frames *)
peer_id : Auth.Digest.t; (* returned by [peer_id] *)
recv_buf : Cstruct.t; (* buffer passed to [Eio.Flow.single_read] by [recv] on every read *)
}

(* [peer_id t] is the digest stored in [t] when it was created. *)
let peer_id t = t.peer_id

let of_flow ~peer_id flow =
let decoder = Capnp.Codecs.FramedStream.empty compression in
let flow = (flow :> flow) in
{ flow; decoder; peer_id }
let writer = Write.create 4096 in
let recv_buf = Cstruct.create 4096 in
{ flow; writer; decoder; peer_id; recv_buf }

let dump_msg =
let next = ref 0 in
Expand All @@ -33,42 +59,78 @@ let dump_msg =
close_out ch

let send t msg =
let data = Capnp.Codecs.serialize ~compression msg in
if record_sent_messages then dump_msg data;
match Eio.Flow.copy_string data t.flow with
| ()
| exception End_of_file -> Ok ()
| exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) ->
Log.info (fun f -> f "%a" Eio.Exn.pp ex);
Error `Closed
| exception ex ->
Eio.Fiber.check ();
Error (`Msg (Printexc.to_string ex))
Log.debug (fun f ->
let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in
f "queue_send: %d/%d allocated bytes in %d segs"
(M.total_size msg)
(M.total_alloc_size msg)
(M.num_segments msg));
Capnp.Codecs.serialize_iter_copyless ~compression msg ~f:(fun x len -> Write.string t.writer x ~len);
Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total;
if record_sent_messages then dump_msg (Capnp.Codecs.serialize ~compression msg)

let rec recv t =
let rec recv ~tags t =
match Capnp.Codecs.FramedStream.get_next_frame t.decoder with
| Ok msg -> Ok (Capnp.BytesMessage.Message.readonly msg)
| Ok msg ->
Prometheus.Counter.inc_one Metrics.messages_inbound_received_total;
(* We often want to send multiple response messages while processing a batch of requests,
so pause the writer to collect them. We'll unpause on the next [single_read]. *)
Write.pause t.writer;
Ok (Capnp.BytesMessage.Message.readonly msg)
| Error Capnp.Codecs.FramingError.Unsupported -> failwith "Unsupported Cap'n'Proto frame received"
| Error Capnp.Codecs.FramingError.Incomplete ->
Log.debug (fun f -> f "Incomplete; waiting for more data...");
let buf = Cstruct.create 4096 in (* TODO: make this efficient *)
match Eio.Flow.single_read t.flow buf with
Log.debug (fun f -> f ~tags "Incomplete; waiting for more data...");
(* We probably scheduled one or more application fibers to run while handling the last
batch of messages. Give them a chance to run now while the writer is paused, because
they might want to send more messages immediately. *)
Fiber.yield ();
Write.unpause t.writer;
match Eio.Flow.single_read t.flow t.recv_buf with
| got ->
Log.debug (fun f -> f "Read %d bytes" got);
Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string buf ~len:got);
recv t
Log.debug (fun f -> f ~tags "Read %d bytes" got);
Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string t.recv_buf ~len:got);
recv ~tags t
| exception End_of_file ->
Log.info (fun f -> f "Connection closed");
Log.info (fun f -> f ~tags "Received end-of-stream");
Error `Closed
| exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) ->
Log.info (fun f -> f "%a" Eio.Exn.pp ex);
Log.info (fun f -> f ~tags "Receive failed: %a" Eio.Exn.pp ex);
Error `Closed

let disconnect t =
try
Eio.Flow.shutdown t.flow `All
with
| Invalid_argument _
| Eio.Io (Eio.Net.E Connection_reset _, _) ->
with Eio.Io (Eio.Net.E Connection_reset _, _) ->
(* TCP connection already shut down, so TLS shutdown failed. Ignore. *)
()

(* Best-effort flush: resume the writer, then yield once so that it gets a
   chance to send whatever is still buffered.
   [Write.flush] would guarantee that the data was sent, but this function is
   only used for aborts, which are not very important; if a single yield is
   not enough it is probably better to drop the buffered messages. *)
let flush t =
  let writer = t.writer in
  Write.unpause writer;
  Fiber.yield ()

(* Writer loop: wait for a batch of buffered data, write it to the flow with one
[single_write] call, tell the writer how many bytes were consumed, and repeat.
[Write.await_batch] is evaluated outside the [match], so an exception raised by
it is not handled here and propagates to the caller.
The loop returns [()] once a write fails; the failure is only logged. *)
let rec run_writer ~tags t =
let bufs = Write.await_batch t.writer in
match Eio.Flow.single_write t.flow bufs with
(* [n] is the count returned by [single_write]; only that much is removed
from the buffer before looping. *)
| n -> Write.shift t.writer n; run_writer ~tags t
| exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) ->
(* A connection reset is logged at info level only. *)
Log.info (fun f -> f ~tags "Send failed: %a" Eio.Exn.pp ex)
| exception ex ->
(* NOTE(review): [Eio.Fiber.check] presumably re-raises if this fiber has been
cancelled, so that cancellation is not logged as a send error - confirm. *)
Eio.Fiber.check ();
Log.warn (fun f -> f ~tags "Error sending messages: %a (will shutdown connection)" Fmt.exn ex)

(* Public entry point wrapping the writer loop defined just above (which this
   definition shadows). The connections gauge is incremented for the duration
   of the loop. When the loop finishes, normally or by raising, the gauge is
   decremented and the flow is shut down; an exception from the loop is then
   re-raised with its original backtrace. *)
let run_writer ~tags t =
  Prometheus.Gauge.inc_one Metrics.connections;
  (* Capture the outcome first so the cleanup below runs exactly once and
     outside of the exception handler. *)
  let outcome =
    try Ok (run_writer ~tags t)
    with ex -> Error (ex, Printexc.get_raw_backtrace ())
  in
  Prometheus.Gauge.dec_one Metrics.connections;
  disconnect t; (* The listen fiber will read end-of-stream soon *)
  match outcome with
  | Ok () -> ()
  | Error (ex, bt) -> Printexc.raise_with_backtrace ex bt
16 changes: 12 additions & 4 deletions capnp-rpc-net/endpoint.mli
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ val src : Logs.src
type t
(** A wrapper for a byte-stream (flow). *)

val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result
(** [send t msg] transmits [msg]. *)
val send : t -> 'a Capnp.BytesMessage.Message.t -> unit
(** [send t msg] enqueues [msg]. *)

val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result
(** [recv t] reads the next message from the remote peer.
val run_writer : tags:Logs.Tag.set -> t -> unit
(** [run_writer ~tags t] runs a loop that transmits batches of messages from [t].
It returns when the flow is closed. *)

val recv : tags:Logs.Tag.set -> t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result
(** [recv ~tags t] reads the next message from the remote peer.
It returns [Error `Closed] if the connection to the peer is lost. *)

val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t
Expand All @@ -19,6 +23,10 @@ val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t
val peer_id : t -> Auth.Digest.t
(** [peer_id t] is the fingerprint of the peer's public key,
or [Auth.Digest.insecure] if TLS isn't being used. *)

val flush : t -> unit
(** [flush t] is useful to try to send any buffered data before disconnecting.
Otherwise, the final abort message is likely to get lost. *)

val disconnect : t -> unit
(** [disconnect t] shuts down the underlying flow. *)
Loading

0 comments on commit c95f619

Please sign in to comment.