diff --git a/benchmarks/message-io/incoming-message-stream.js b/benchmarks/message-io/incoming-message-stream.js new file mode 100644 index 000000000..14271a37a --- /dev/null +++ b/benchmarks/message-io/incoming-message-stream.js @@ -0,0 +1,43 @@ +const { createBenchmark } = require('../common'); +const { Readable } = require('stream'); + +const Debug = require('tedious/lib/debug'); +const IncomingMessageStream = require('tedious/lib/incoming-message-stream'); +const { Packet } = require('tedious/lib/packet'); + +const bench = createBenchmark(main, { + n: [100, 1000, 10000, 100000] +}); + +function main({ n }) { + const debug = new Debug(); + + const stream = Readable.from((async function*() { + for (let i = 0; i < n; i++) { + const packet = new Packet(2); + packet.last(true); + packet.addData(Buffer.from([1, 2, 3, 4, 5, 6, 7, 8, 9])); + + yield packet.buffer; + } + })()); + + const incoming = new IncomingMessageStream(debug); + stream.pipe(incoming); + + bench.start(); + console.profile('incoming-message-stream'); + + (async function() { + let total = 0; + + for await (m of incoming) { + for await (const buf of m) { + total += buf.length; + } + } + + console.profileEnd('incoming-message-stream'); + bench.end(n); + })(); +} diff --git a/benchmarks/message-io/outgoing-message-stream.js b/benchmarks/message-io/outgoing-message-stream.js new file mode 100644 index 000000000..7899d1e43 --- /dev/null +++ b/benchmarks/message-io/outgoing-message-stream.js @@ -0,0 +1,72 @@ +const { createBenchmark } = require('../common'); +const { Duplex } = require('stream'); + +const Debug = require('../../lib/debug'); +const OutgoingMessageStream = require('../../lib/outgoing-message-stream'); +const Message = require('../../lib/message'); + +const bench = createBenchmark(main, { + n: [100, 1000, 10000, 100000] +}); + +function main({ n }) { + const debug = new Debug(); + + const stream = new Duplex({ + read() {}, + write(chunk, encoding, callback) { + // Just consume the data + callback(); + } + }); + + const payload = [ + Buffer.alloc(1024), + Buffer.alloc(1024), + Buffer.alloc(1024), + Buffer.alloc(256), + Buffer.alloc(256), + Buffer.alloc(256), + Buffer.alloc(256), + ]; + + const out = new OutgoingMessageStream(debug, { + packetSize: 8 + 1024 + }); + out.pipe(stream); + + bench.start(); + console.profile('write-message'); + + function writeNextMessage(i) { + if (i == n) { + out.end(); + out.once('finish', () => { + console.profileEnd('write-message'); + bench.end(n); + }); + return; + } + + const m = new Message({ type: 2, resetConnection: false }); + out.write(m); + + for (const buf of payload) { + m.write(buf); + } + + m.end(); + + if (out.needsDrain) { + out.once('drain', () => { + writeNextMessage(i + 1); + }); + } else { + process.nextTick(() => { + writeNextMessage(i + 1); + }); + } + } + + writeNextMessage(0); +} diff --git a/benchmarks/message-io/read-message.js b/benchmarks/message-io/read-message.js new file mode 100644 index 000000000..413e6f47c --- /dev/null +++ b/benchmarks/message-io/read-message.js @@ -0,0 +1,39 @@ +const { createBenchmark } = require('../common'); +const { Readable } = require('stream'); + +const Debug = require('tedious/lib/debug'); +const MessageIO = require('tedious/lib/message-io'); +const { Packet } = require('tedious/lib/packet'); + +const bench = createBenchmark(main, { + n: [100, 1000, 10000, 100000] +}); + +function main({ n }) { + const debug = new Debug(); + + const stream = Readable.from((async function*() { + for (let i = 0; i < n; i++) { + const packet = new Packet(2); + packet.last(true); + packet.addData(Buffer.from([1, 2, 3, 4, 5, 6, 7, 8, 9])); + + yield packet.buffer; + } + })()); + + (async function() { + bench.start(); + console.profile('read-message'); + + let total = 0; + for (let i = 0; i < n; i++) { + for await (const chunk of MessageIO.readMessage(stream, debug)) { + total += chunk.length; + } + } + + console.profileEnd('read-message'); + bench.end(n); + })(); +} diff --git a/benchmarks/message-io/write-message.js b/benchmarks/message-io/write-message.js new file mode 100644 index 000000000..114df794f --- /dev/null +++ b/benchmarks/message-io/write-message.js @@ -0,0 +1,43 @@ +const { createBenchmark, createConnection } = require('../common'); +const { Duplex } = require('stream'); + +const Debug = require('tedious/lib/debug'); +const MessageIO = require('tedious/lib/message-io'); + +const bench = createBenchmark(main, { + n: [100, 1000, 10000, 100000] +}); + +function main({ n }) { + const debug = new Debug(); + + const stream = new Duplex({ + read() {}, + write(chunk, encoding, callback) { + // Just consume the data + callback(); + } + }); + + const payload = [ + Buffer.alloc(1024), + Buffer.alloc(1024), + Buffer.alloc(1024), + Buffer.alloc(256), + Buffer.alloc(256), + Buffer.alloc(256), + Buffer.alloc(256), + ]; + + (async function() { + bench.start(); + console.profile('write-message'); + + for (let i = 0; i <= n; i++) { + await MessageIO.writeMessage(stream, debug, 8 + 1024, 2, payload); + } + + console.profileEnd('write-message'); + bench.end(n); + })(); +} diff --git a/src/message-io.ts b/src/message-io.ts index 30a733ab6..28fa89489 100644 --- a/src/message-io.ts +++ b/src/message-io.ts @@ -1,6 +1,6 @@ import DuplexPair from 'native-duplexpair'; -import { Duplex } from 'stream'; +import { Duplex, type Readable, type Writable } from 'stream'; import * as tls from 'tls'; import { Socket } from 'net'; import { EventEmitter } from 'events'; @@ -8,10 +8,24 @@ import { EventEmitter } from 'events'; import Debug from './debug'; import Message from './message'; -import { TYPE } from './packet'; +import { HEADER_LENGTH, Packet, TYPE } from './packet'; import IncomingMessageStream from './incoming-message-stream'; import OutgoingMessageStream from './outgoing-message-stream'; +import { BufferList } from 'bl'; +import { ConnectionError } from './errors'; + +function withResolvers() { + let resolve: (value: T | PromiseLike) => void; + let reject: (reason?: any) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + return { resolve: resolve!, reject: reject!, promise }; +} class MessageIO extends EventEmitter { declare socket: Socket; @@ -182,6 +196,260 @@ class MessageIO extends EventEmitter { return result.value; } + + /** + * Write the given `payload` wrapped in TDS messages to the given `stream`. + * + * @param stream The stream to write the message to. + * @param debug The debug instance to use for logging. + * @param packetSize The maximum packet size to use. + * @param type The type of the message to write. + * @param payload The payload to write. + * @param resetConnection Whether the server should reset the connection after processing the message. + */ + static async writeMessage(stream: Writable, debug: Debug, packetSize: number, type: number, payload: AsyncIterable | Iterable, resetConnection = false) { + if (!stream.writable) { + throw new Error('Premature close'); + } + + let drainResolve: (() => void) | null = null; + let drainReject: ((reason?: any) => void) | null = null; + + function onDrain() { + if (drainResolve) { + const cb = drainResolve; + drainResolve = null; + drainReject = null; + cb(); + } + } + + const waitForDrain = () => { + let promise; + ({ promise, resolve: drainResolve, reject: drainReject } = withResolvers()); + return promise; + }; + + function onError(err: Error) { + if (drainReject) { + const cb = drainReject; + drainResolve = null; + drainReject = null; + cb(err); + } + } + + stream.on('drain', onDrain); + stream.on('close', onDrain); + stream.on('error', onError); + + try { + const bl = new BufferList(); + const length = packetSize - HEADER_LENGTH; + let packetNumber = 0; + + let isAsync; + let iterator; + + if ((payload as AsyncIterable)[Symbol.asyncIterator]) { + isAsync = true; + iterator = (payload as AsyncIterable)[Symbol.asyncIterator](); + } else { + isAsync = false; + iterator = (payload as Iterable)[Symbol.iterator](); + } + + while (true) { + try { + let value, done; + if (isAsync) { + ({ value, done } = await (iterator as AsyncIterator).next()); + } else { + ({ value, done } = (iterator as Iterator).next()); + } + + if (done) { + break; + } + + bl.append(value); + } catch (err) { + // If the stream is still writable, the error came from + // the payload. We will end the message with the ignore flag set. + if (stream.writable) { + const packet = new Packet(type); + packet.packetId(packetNumber += 1); + packet.resetConnection(resetConnection); + packet.last(true); + packet.ignore(true); + + debug.packet('Sent', packet); + debug.data(packet); + + if (stream.write(packet.buffer) === false) { + await waitForDrain(); + } + } + + throw err; + } + + while (bl.length > length) { + const data = bl.slice(0, length); + bl.consume(length); + + // TODO: Get rid of creating `Packet` instances here. + const packet = new Packet(type); + packet.packetId(packetNumber += 1); + packet.resetConnection(resetConnection); + packet.addData(data); + + debug.packet('Sent', packet); + debug.data(packet); + + if (stream.write(packet.buffer) === false) { + await waitForDrain(); + } + } + } + + const data = bl.slice(); + bl.consume(data.length); + + // TODO: Get rid of creating `Packet` instances here. + const packet = new Packet(type); + packet.packetId(packetNumber += 1); + packet.resetConnection(resetConnection); + packet.last(true); + packet.ignore(false); + packet.addData(data); + + debug.packet('Sent', packet); + debug.data(packet); + + if (stream.write(packet.buffer) === false) { + await waitForDrain(); + } + } finally { + stream.removeListener('drain', onDrain); + stream.removeListener('close', onDrain); + stream.removeListener('error', onError); + } + } + + /** + * Read the next TDS message from the given `stream`. + * + * This method returns an async generator that yields the data of the next message. + * The generator will throw an error if the stream is closed before the message is fully read. + * The generator will throw an error if the stream emits an error event. + * + * @param stream The stream to read the message from. + * @param debug The debug instance to use for logging. + * @returns An async generator that yields the data of the next message. + */ + static async *readMessage(stream: Readable, debug: Debug) { + if (!stream.readable) { + throw new Error('Premature close'); + } + + const bl = new BufferList(); + + let resolve: ((value: void | PromiseLike) => void) | null = null; + let reject: ((reason?: any) => void) | null = null; + + const waitForReadable = () => { + let promise; + ({ promise, resolve, reject } = withResolvers()); + return promise; + }; + + const onReadable = () => { + if (resolve) { + const cb = resolve; + resolve = null; + reject = null; + cb(); + } + }; + + const onError = (err: Error) => { + if (reject) { + const cb = reject; + resolve = null; + reject = null; + cb(err); + } + }; + + const onClose = () => { + if (reject) { + const cb = reject; + resolve = null; + reject = null; + cb(new Error('Premature close')); + } + }; + + stream.on('readable', onReadable); + stream.on('error', onError); + stream.on('close', onClose); + + try { + while (true) { + // Wait for the stream to become readable (or error out or close). + await waitForReadable(); + + let chunk: Buffer; + while ((chunk = stream.read()) !== null) { + bl.append(chunk); + + // The packet header is always 8 bytes of length. + while (bl.length >= HEADER_LENGTH) { + // Get the full packet length + const length = bl.readUInt16BE(2); + if (length < HEADER_LENGTH) { + throw new ConnectionError('Unable to process incoming packet'); + } + + if (bl.length >= length) { + const data = bl.slice(0, length); + bl.consume(length); + + // TODO: Get rid of creating `Packet` instances here. + const packet = new Packet(data); + debug.packet('Received', packet); + debug.data(packet); + + yield packet.data(); + + // Did the stream error while we yielded? + // if (error) { + // throw error; + // } + + if (packet.isLast()) { + // This was the last packet. Is there any data left in the buffer? + // If there is, this might be coming from the next message (e.g. a response to a `ATTENTION` + // message sent from the client while reading an incoming response). + // + // Put any remaining bytes back on the stream so we can read them on the next `readMessage` call. + if (bl.length) { + stream.unshift(bl.slice()); + } + + return; + } + } + } + } + } + } finally { + stream.removeListener('readable', onReadable); + stream.removeListener('close', onClose); + stream.removeListener('error', onError); + } + } } export default MessageIO; diff --git a/test/unit/message-io-test.ts b/test/unit/message-io-test.ts index b7f94c632..9b1f9b1b7 100644 --- a/test/unit/message-io-test.ts +++ b/test/unit/message-io-test.ts @@ -5,18 +5,297 @@ import { promisify } from 'util'; import DuplexPair from 'native-duplexpair'; import { TLSSocket } from 'tls'; import { readFileSync } from 'fs'; -import { Duplex } from 'stream'; +import { Duplex, Readable } from 'stream'; import Debug from '../../src/debug'; import MessageIO from '../../src/message-io'; import Message from '../../src/message'; import { Packet, TYPE } from '../../src/packet'; +import { BufferListStream } from 'bl'; const packetType = 2; const packetSize = 8 + 4; const delay = promisify(setTimeout); +function assertNoDanglingEventListeners(stream: Duplex) { + assert.strictEqual(stream.listenerCount('error'), 0); + assert.strictEqual(stream.listenerCount('drain'), 0); +} + +describe('MessageIO.writeMessage', function() { + let debug: Debug; + + beforeEach(function() { + debug = new Debug(); + }); + + it('wraps the given packet contents into a TDS packet and writes it to the given stream', async function() { + const payload = Buffer.from([1, 2, 3]); + const stream = new BufferListStream(); + + await MessageIO.writeMessage(stream, debug, packetSize, packetType, [ payload ]); + + const buf = stream.read(); + assert.instanceOf(buf, Buffer); + + const packet = new Packet(buf); + assert.strictEqual(packet.type(), packetType); + assert.strictEqual(packet.length(), payload.length + 8); + assert.strictEqual(packet.statusAsString(), 'EOM'); + assert.isTrue(packet.isLast()); + assert.deepEqual(packet.data(), payload); + + assert.isNull(stream.read()); + }); + + it('handles errors while iterating over the payload', async function() { + const payload = Buffer.from([1, 2, 3]); + const stream = new BufferListStream(); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, (async function*() { + yield payload; + throw new Error('iteration error'); + })()); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'iteration error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); + + it('handles errors while iterating over the payload, while the stream is waiting for drain', async function() { + const payload = Buffer.from([1, 2, 3, 4]); + + const callbacks: Array<() => void> = []; + const stream = new Duplex({ + write(chunk, encoding, callback) { + // Collect all callbacks so that we can simulate draining the stream later + callbacks.push(callback); + }, + read() {}, + + // instantly return false on write requests to indicate that the stream needs to drain + highWaterMark: 1 + }); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, (async function*() { + yield payload; + + // Simulate draining the stream after the exception was thrown + setTimeout(() => { + let cb; + while (cb = callbacks.shift()) { + cb(); + } + }, 100); + + throw new Error('iteration error'); + })()); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'iteration error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); + + it('handles errors on the stream while handling errors from the payload while waiting for the stream to drain', async function() { + const payload = Buffer.from([1, 2, 3, 4]); + + const stream = new Duplex({ + write(chunk, encoding, callback) { + // never call the callback so that the stream never drains + }, + read() {}, + + // instantly return false on write requests to indicate that the stream needs to drain + highWaterMark: 1 + }); + + setTimeout(() => { + assert(stream.writableNeedDrain); + stream.destroy(new Error('write error')); + }, 100); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, (async function*() { + yield payload; + + // Simulate an error on the stream after an error from the payload + setTimeout(() => { + stream.destroy(new Error('write error')); + }, 100); + + throw new Error('iteration error'); + })()); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'write error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); + + it('handles errors on the stream during writing', async function() { + const payload = Buffer.from([1, 2, 3]); + const stream = new Duplex({ + write(chunk, encoding, callback) { + callback(new Error('write error')); + }, + read() {} + }); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, [ payload ]); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'write error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); + + it('handles errors on the stream while waiting for the stream to drain', async function() { + const payload = Buffer.from([1, 2, 3]); + const stream = new Duplex({ + write(chunk, encoding, callback) { + // never call callback so that the stream never drains + }, + read() {}, + + // instantly return false on write requests to indicate that the stream needs to drain + highWaterMark: 1 + }); + + setTimeout(() => { + assert(stream.writableNeedDrain); + stream.destroy(new Error('write error')); + }, 100); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, [ payload, payload, payload ]); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'write error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); + + it('handles errors on the stream while waiting for more data to be written', async function() { + const payload = Buffer.from([1, 2, 3]); + const stream = new Duplex({ + write(chunk, encoding, callback) { + // never call callback so that the stream never drains + }, + read() {}, + + // instantly return false on write requests to indicate that the stream needs to drain + highWaterMark: 1 + }); + + setTimeout(() => { + assert(stream.writableNeedDrain); + stream.destroy(new Error('write error')); + }, 100); + + let hadError = false; + try { + await MessageIO.writeMessage(stream, debug, packetSize, packetType, (async function*() { + yield payload; + yield payload; + yield payload; + })()); + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'write error'); + } + + assert(hadError); + assertNoDanglingEventListeners(stream); + }); +}); + +describe('MessageIO.readMessage', function() { + let debug: Debug; + + beforeEach(function() { + debug = new Debug(); + }); + + it('reads a TDS packet from the given stream and returns its contents', async function() { + const payload = Buffer.from([1, 2, 3]); + const packet = new Packet(packetType); + packet.last(true); + packet.addData(payload); + + const stream = new BufferListStream(); + stream.write(packet.buffer); + + const message = MessageIO.readMessage(stream, debug); + + const chunks = []; + for await (const chunk of message) { + chunks.push(chunk); + } + + assert.deepEqual(chunks, [ payload ]); + }); + + it('handles errors while reading from the stream', async function() { + const payload = Buffer.from([1, 2, 3]); + const packet = new Packet(packetType); + packet.last(true); + packet.addData(payload); + + const stream = Readable.from((async function*() { + throw new Error('read error'); + })()); + + let hadError = false; + + const chunks = []; + try { + for await (const message of MessageIO.readMessage(stream, debug)) { + chunks.push(message); + } + } catch (err: any) { + hadError = true; + + assert.instanceOf(err, Error); + assert.strictEqual(err.message, 'read error'); + } + + assert(hadError); + }); +}); + describe('MessageIO', function() { let server: Server; let serverConnection: Socket;