fix(stream): let PushReadableStream handle cancelled streams.

Also add tests

Relates to #648
This commit is contained in:
Simon Chan 2024-07-11 17:25:39 +08:00
parent a1c6450b2f
commit fb4507ddc5
No known key found for this signature in database
GPG key ID: A8B69F750B9BCEDD
10 changed files with 1267 additions and 100 deletions

View file

@ -0,0 +1,151 @@
import { describe, expect, it, jest } from "@jest/globals";
import { PromiseResolver } from "@yume-chan/async";
import { ReadableStream, WritableStream } from "@yume-chan/stream-extra";
import type { AdbSocket } from "../../../adb.js";
import { AdbSubprocessNoneProtocol } from "./none.js";
describe("AdbSubprocessNoneProtocol", () => {
describe("stdout", () => {
it("should pipe data from `socket`", async () => {
const closed = new PromiseResolver<void>();
const socket: AdbSocket = {
service: "",
close: jest.fn(() => {}),
closed: closed.promise,
readable: new ReadableStream({
async start(controller) {
controller.enqueue(new Uint8Array([1, 2, 3]));
controller.enqueue(new Uint8Array([4, 5, 6]));
await closed.promise;
controller.close();
},
}),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
const reader = process.stdout.getReader();
await expect(reader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(reader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
});
it("should close when `socket` is closed", async () => {
const closed = new PromiseResolver<void>();
const socket: AdbSocket = {
service: "",
close: jest.fn(() => {}),
closed: closed.promise,
readable: new ReadableStream({
async start(controller) {
controller.enqueue(new Uint8Array([1, 2, 3]));
controller.enqueue(new Uint8Array([4, 5, 6]));
await closed.promise;
controller.close();
},
}),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
const reader = process.stdout.getReader();
await expect(reader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(reader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
closed.resolve();
await expect(reader.read()).resolves.toEqual({
done: true,
});
});
});
describe("stderr", () => {
it("should be empty", async () => {
const closed = new PromiseResolver<void>();
const socket: AdbSocket = {
service: "",
close: jest.fn(() => {}),
closed: closed.promise,
readable: new ReadableStream({
async start(controller) {
controller.enqueue(new Uint8Array([1, 2, 3]));
controller.enqueue(new Uint8Array([4, 5, 6]));
await closed.promise;
controller.close();
},
}),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
const reader = process.stderr.getReader();
closed.resolve();
await expect(reader.read()).resolves.toEqual({ done: true });
});
});
describe("exit", () => {
it("should resolve when `socket` closes", async () => {
const closed = new PromiseResolver<void>();
const socket: AdbSocket = {
service: "",
close: jest.fn(() => {}),
closed: closed.promise,
readable: new ReadableStream(),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
closed.resolve();
await expect(process.exit).resolves.toBe(0);
});
});
it("`resize` shouldn't throw any error", () => {
const socket: AdbSocket = {
service: "",
close: jest.fn(() => {}),
closed: new PromiseResolver<void>().promise,
readable: new ReadableStream(),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
expect(() => process.resize()).not.toThrow();
});
it("`kill` should close `socket`", async () => {
const close = jest.fn(() => {});
const socket: AdbSocket = {
service: "",
close,
closed: new PromiseResolver<void>().promise,
readable: new ReadableStream(),
writable: new WritableStream(),
};
const process = new AdbSubprocessNoneProtocol(socket);
await process.kill();
expect(close).toHaveBeenCalledTimes(1);
});
});

View file

@ -0,0 +1,205 @@
import { describe, expect, it } from "@jest/globals";
import { PromiseResolver } from "@yume-chan/async";
import {
ReadableStream,
type ReadableStreamDefaultController,
WritableStream,
} from "@yume-chan/stream-extra";
import { type AdbSocket } from "../../../adb.js";
import {
AdbShellProtocolId,
AdbShellProtocolPacket,
AdbSubprocessShellProtocol,
} from "./shell.js";
function createMockSocket(
readable: (controller: ReadableStreamDefaultController<Uint8Array>) => void,
): [AdbSocket, PromiseResolver<void>] {
const closed = new PromiseResolver<void>();
const socket: AdbSocket = {
service: "",
close() {},
closed: closed.promise,
readable: new ReadableStream({
async start(controller) {
controller.enqueue(
AdbShellProtocolPacket.serialize({
id: AdbShellProtocolId.Stdout,
data: new Uint8Array([1, 2, 3]),
}),
);
controller.enqueue(
AdbShellProtocolPacket.serialize({
id: AdbShellProtocolId.Stderr,
data: new Uint8Array([4, 5, 6]),
}),
);
await closed.promise;
readable(controller);
},
}),
writable: new WritableStream(),
};
return [socket, closed];
}
describe("AdbSubprocessShellProtocol", () => {
describe("`stdout` and `stderr`", () => {
it("should parse data from `socket", async () => {
const [socket] = createMockSocket(() => {});
const process = new AdbSubprocessShellProtocol(socket);
const stdoutReader = process.stdout.getReader();
const stderrReader = process.stderr.getReader();
await expect(stdoutReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
});
it("should be able to be cancelled", async () => {
const [socket, closed] = createMockSocket((controller) => {
controller.enqueue(
AdbShellProtocolPacket.serialize({
id: AdbShellProtocolId.Stdout,
data: new Uint8Array([7, 8, 9]),
}),
);
controller.enqueue(
AdbShellProtocolPacket.serialize({
id: AdbShellProtocolId.Stderr,
data: new Uint8Array([10, 11, 12]),
}),
);
});
const process = new AdbSubprocessShellProtocol(socket);
const stdoutReader = process.stdout.getReader();
const stderrReader = process.stderr.getReader();
stdoutReader.cancel();
closed.resolve();
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([10, 11, 12]),
});
});
});
describe("`socket` close", () => {
describe("with `exit` message", () => {
it("should close `stdout`, `stderr` and resolve `exit`", async () => {
const [socket, closed] = createMockSocket(
async (controller) => {
controller.enqueue(
AdbShellProtocolPacket.serialize({
id: AdbShellProtocolId.Exit,
data: new Uint8Array([42]),
}),
);
controller.close();
},
);
const process = new AdbSubprocessShellProtocol(socket);
const stdoutReader = process.stdout.getReader();
const stderrReader = process.stderr.getReader();
await expect(stdoutReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
closed.resolve();
await expect(stdoutReader.read()).resolves.toEqual({
done: true,
});
await expect(stderrReader.read()).resolves.toEqual({
done: true,
});
await expect(process.exit).resolves.toBe(42);
});
});
describe("with no `exit` message", () => {
it("should close `stdout`, `stderr` and reject `exit`", async () => {
const [socket, closed] = createMockSocket((controller) => {
controller.close();
});
const process = new AdbSubprocessShellProtocol(socket);
const stdoutReader = process.stdout.getReader();
const stderrReader = process.stderr.getReader();
await expect(stdoutReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
closed.resolve();
await Promise.all([
expect(stdoutReader.read()).resolves.toEqual({
done: true,
}),
expect(stderrReader.read()).resolves.toEqual({
done: true,
}),
expect(process.exit).rejects.toThrow(),
]);
});
});
});
describe("`socket.readable` invalid data", () => {
it("should error `stdout`, `stderr` and reject `exit`", async () => {
const [socket, closed] = createMockSocket(async (controller) => {
controller.enqueue(new Uint8Array([7, 8, 9]));
controller.close();
});
const process = new AdbSubprocessShellProtocol(socket);
const stdoutReader = process.stdout.getReader();
const stderrReader = process.stderr.getReader();
await expect(stdoutReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([1, 2, 3]),
});
await expect(stderrReader.read()).resolves.toEqual({
done: false,
value: new Uint8Array([4, 5, 6]),
});
closed.resolve();
await Promise.all([
expect(stdoutReader.read()).rejects.toThrow(),
expect(stderrReader.read()).rejects.toThrow(),
expect(process.exit).rejects.toThrow(),
]);
});
});
});

View file

@ -28,8 +28,8 @@ export enum AdbShellProtocolId {
WindowSizeChange, WindowSizeChange,
} }
// This packet format is used in both direction. // This packet format is used in both directions.
const AdbShellProtocolPacket = new Struct({ littleEndian: true }) export const AdbShellProtocolPacket = new Struct({ littleEndian: true })
.uint8("id", placeholder<AdbShellProtocolId>()) .uint8("id", placeholder<AdbShellProtocolId>())
.uint32("length") .uint32("length")
.uint8Array("data", { lengthField: "length" }); .uint8Array("data", { lengthField: "length" });
@ -107,14 +107,10 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
this.#exit.resolve(chunk.data[0]!); this.#exit.resolve(chunk.data[0]!);
break; break;
case AdbShellProtocolId.Stdout: case AdbShellProtocolId.Stdout:
if (!stdoutController.abortSignal.aborted) { await stdoutController.enqueue(chunk.data);
await stdoutController.enqueue(chunk.data);
}
break; break;
case AdbShellProtocolId.Stderr: case AdbShellProtocolId.Stderr:
if (!stderrController.abortSignal.aborted) { await stderrController.enqueue(chunk.data);
await stderrController.enqueue(chunk.data);
}
break; break;
} }
}, },

View file

@ -138,20 +138,7 @@ export class AdbDaemonSocketController
} }
async enqueue(data: Uint8Array) { async enqueue(data: Uint8Array) {
// Consumers can `cancel` the `readable` if they are not interested in future data. await this.#readableController.enqueue(data);
// Throw away the data if that happens.
if (this.#readableController.abortSignal.aborted) {
return;
}
try {
await this.#readableController.enqueue(data);
} catch (e) {
if (this.#readableController.abortSignal.aborted) {
return;
}
throw e;
}
} }
public ack(bytes: number) { public ack(bytes: number) {
@ -182,23 +169,13 @@ export class AdbDaemonSocketController
} }
dispose() { dispose() {
try { this.#readableController.close();
this.#readableController.close();
} catch {
// ignore
}
this.#closedPromise.resolve(); this.#closedPromise.resolve();
} }
} }
/** /**
* A duplex stream representing a socket to ADB daemon. * A duplex stream representing a socket to ADB daemon.
*
* To close it, call either `socket.close()`,
* `socket.readable.cancel()`, `socket.readable.getReader().cancel()`,
* `socket.writable.abort()`, `socket.writable.getWriter().abort()`,
* `socket.writable.close()` or `socket.writable.getWriter().close()`.
*/ */
export class AdbDaemonSocket implements AdbDaemonSocketInfo, AdbSocket { export class AdbDaemonSocket implements AdbDaemonSocketInfo, AdbSocket {
#controller: AdbDaemonSocketController; #controller: AdbDaemonSocketController;

View file

@ -125,11 +125,7 @@ class AdbServerStream {
async dispose() { async dispose() {
await this.#buffered.cancel().catch(NOOP); await this.#buffered.cancel().catch(NOOP);
await this.#writer.close().catch(NOOP); await this.#writer.close().catch(NOOP);
try { await this.#connection.close();
await this.#connection.close();
} catch {
// ignore
}
} }
} }

View file

@ -28,8 +28,12 @@ import type {
ScrcpyVideoStreamMetadata, ScrcpyVideoStreamMetadata,
} from "../codec.js"; } from "../codec.js";
import { ScrcpyVideoCodecId } from "../codec.js"; import { ScrcpyVideoCodecId } from "../codec.js";
import type { ScrcpyDisplay, ScrcpyEncoder } from "../types.js"; import type { ScrcpyDisplay } from "../types.js";
import { ScrcpyOptions, toScrcpyOptionValue } from "../types.js"; import {
ScrcpyOptions,
ScrcpyOptions0_00,
toScrcpyOptionValue,
} from "../types.js";
import { CodecOptions } from "./codec-options.js"; import { CodecOptions } from "./codec-options.js";
import type { ScrcpyOptionsInit1_16 } from "./init.js"; import type { ScrcpyOptionsInit1_16 } from "./init.js";
@ -123,7 +127,7 @@ export class ScrcpyOptions1_16 extends ScrcpyOptions<ScrcpyOptionsInit1_16> {
} }
constructor(init: ScrcpyOptionsInit1_16) { constructor(init: ScrcpyOptionsInit1_16) {
super(undefined, init, ScrcpyOptions1_16.DEFAULTS); super(ScrcpyOptions0_00, init, ScrcpyOptions1_16.DEFAULTS);
this.#clipboard = new PushReadableStream<string>((controller) => { this.#clipboard = new PushReadableStream<string>((controller) => {
this.#clipboardController = controller; this.#clipboardController = controller;
}); });
@ -136,10 +140,6 @@ export class ScrcpyOptions1_16 extends ScrcpyOptions<ScrcpyOptionsInit1_16> {
); );
} }
override setListEncoders(): void {
throw new Error("Not supported");
}
override setListDisplays(): void { override setListDisplays(): void {
// Set to an invalid value // Set to an invalid value
// Server will print valid values before crashing // Server will print valid values before crashing
@ -147,10 +147,6 @@ export class ScrcpyOptions1_16 extends ScrcpyOptions<ScrcpyOptionsInit1_16> {
this.value.displayId = -1; this.value.displayId = -1;
} }
override parseEncoder(): ScrcpyEncoder | undefined {
throw new Error("Not supported");
}
override parseDisplay(line: string): ScrcpyDisplay | undefined { override parseDisplay(line: string): ScrcpyDisplay | undefined {
const match = line.match(/^\s+scrcpy --display (\d+)$/); const match = line.match(/^\s+scrcpy --display (\d+)$/);
if (match) { if (match) {
@ -185,43 +181,28 @@ export class ScrcpyOptions1_16 extends ScrcpyOptions<ScrcpyOptionsInit1_16> {
async #parseClipboardMessage(stream: AsyncExactReadable) { async #parseClipboardMessage(stream: AsyncExactReadable) {
const message = await ScrcpyClipboardDeviceMessage.deserialize(stream); const message = await ScrcpyClipboardDeviceMessage.deserialize(stream);
// Allow `clipboard.cancel()` to discard messages await this.#clipboardController.enqueue(message.content);
if (!this.#clipboardController.abortSignal.aborted) {
await this.#clipboardController.enqueue(message.content);
}
} }
override async parseDeviceMessage( override async parseDeviceMessage(
id: number, id: number,
stream: AsyncExactReadable, stream: AsyncExactReadable,
): Promise<void> { ): Promise<void> {
try { switch (id) {
switch (id) { case 0:
case 0: await this.#parseClipboardMessage(stream);
await this.#parseClipboardMessage(stream); break;
break; default:
default: await super.parseDeviceMessage(id, stream);
throw new Error(`Unknown device message type ${id}`); break;
}
} catch (e) {
try {
this.#clipboardController.error(e);
} catch {
// The stream is already errored
}
throw e;
} }
} }
override endDeviceMessageStream(e?: unknown): void { override endDeviceMessageStream(e?: unknown): void {
try { if (e) {
if (e) { this.#clipboardController.error(e);
this.#clipboardController.error(e); } else {
} else { this.#clipboardController.close();
this.#clipboardController.close();
}
} catch {
// The stream is already errored
} }
} }

View file

@ -115,9 +115,10 @@ export abstract class ScrcpyOptions<T extends object> {
this.value = value as Required<T>; this.value = value as Required<T>;
if (Base !== undefined) { if (Base !== undefined) {
// `value` might be incompatible with `Base`, // `value` can be incompatible with `Base`,
// but the derive class must ensure the incompatible values are not used by base class, // as long as the derived class handles the incompatibility,
// and only the `setListXXX` methods in base class will modify the value, // (and ensure the incompatible values are not used in `Base`).
// On other hand, only the `setListXXX` methods in `Base` will modify `value`,
// which is common to all versions. // which is common to all versions.
// //
// `Base` is a derived class of `ScrcpyOptions`, its constructor will call // `Base` is a derived class of `ScrcpyOptions`, its constructor will call
@ -221,3 +222,32 @@ export abstract class ScrcpyOptions<T extends object> {
return this.#base.createScrollController(); return this.#base.createScrollController();
} }
} }
/**
* Blanket implementation of unsupported features in ScrcpyOptions1_16
*/
export class ScrcpyOptions0_00 extends ScrcpyOptions<never> {
get defaults(): Required<never> {
throw new Error("Not supported");
}
serialize(): string[] {
throw new Error("Not supported");
}
constructor(init: never) {
super(undefined, init, {} as never);
}
override setListEncoders(): void {
throw new Error("Not supported");
}
override parseEncoder(): ScrcpyEncoder | undefined {
throw new Error("Not supported");
}
override parseDeviceMessage(id: number): Promise<void> {
throw new Error(`Unknown device message type ${id}`);
}
}

View file

@ -142,9 +142,9 @@ export class BufferedReadableStream implements AsyncExactReadable {
const { done, value } = await this.reader.read(); const { done, value } = await this.reader.read();
if (done) { if (done) {
return; return;
} else {
await controller.enqueue(value);
} }
await controller.enqueue(value);
} }
}); });
} else { } else {

View file

@ -0,0 +1,637 @@
import { describe, expect, it, jest } from "@jest/globals";
import { delay } from "@yume-chan/async";
import type { PushReadableStreamController } from "./push-readable.js";
import { PushReadableStream } from "./push-readable.js";
describe("PushReadableStream", () => {
describe(".cancel", () => {
it("should abort the `AbortSignal`", async () => {
const abortHandler = jest.fn();
const stream = new PushReadableStream((controller) => {
controller.abortSignal.addEventListener("abort", abortHandler);
});
await stream.cancel("reason");
expect(abortHandler).toHaveBeenCalledTimes(1);
});
it("should ignore pending `enqueue`", async () => {
const log = jest.fn();
const stream = new PushReadableStream(
async (controller) => {
await controller.enqueue(1);
await controller.enqueue(2);
},
undefined,
log,
);
const reader = stream.getReader();
await delay(0);
await reader.cancel("reason");
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "enqueue",
"phase": "waiting",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "cancel",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "cancel",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "ignored",
"source": "producer",
"value": 2,
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "ignored",
"source": "producer",
},
],
]
`);
});
it("should ignore future `enqueue`", async () => {
const log = jest.fn();
const stream = new PushReadableStream(
async (controller) => {
await controller.enqueue(1);
await controller.enqueue(2);
await controller.enqueue(3);
},
undefined,
log,
);
const reader = stream.getReader();
await delay(1);
await reader.cancel("reason");
// Add extra microtasks to allow all operations to complete
await delay(1);
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "enqueue",
"phase": "waiting",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "cancel",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "cancel",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "ignored",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 3,
},
],
[
{
"operation": "enqueue",
"phase": "ignored",
"source": "producer",
"value": 3,
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "ignored",
"source": "producer",
},
],
]
`);
});
it("should allow explicit `close` call", async () => {
const log = jest.fn();
const stream = new PushReadableStream(
async (controller) => {
await controller.enqueue(1);
await controller.enqueue(2);
controller.close();
},
undefined,
log,
);
const reader = stream.getReader();
await delay(1);
await reader.cancel("reason");
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "enqueue",
"phase": "waiting",
"source": "producer",
"value": 2,
},
],
[
{
"operation": "cancel",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "cancel",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "ignored",
"source": "producer",
"value": 2,
},
],
[
{
"explicit": true,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": true,
"operation": "close",
"phase": "ignored",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "ignored",
"source": "producer",
},
],
]
`);
});
});
describe(".error", () => {
it("should reject future `enqueue`", async () => {
let controller!: PushReadableStreamController<unknown>;
new PushReadableStream((controller_) => {
controller = controller_;
});
controller.error(new Error("error"));
await expect(controller.enqueue(1)).rejects.toThrow();
});
it("should reject future `close`", () => {
let controller!: PushReadableStreamController<unknown>;
new PushReadableStream((controller_) => {
controller = controller_;
});
controller.error(new Error("error"));
expect(() => controller.close()).toThrow();
});
});
describe("0 high water mark", () => {
it("should allow `read` before `enqueue`", async () => {
const log = jest.fn();
let controller!: PushReadableStreamController<unknown>;
const stream = new PushReadableStream(
(controller_) => {
controller = controller_;
},
{ highWaterMark: 0 },
log,
);
const reader = stream.getReader();
const promise = reader.read();
await delay(1);
await controller.enqueue(1);
await expect(promise).resolves.toEqual({ done: false, value: 1 });
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "pull",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "pull",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
]
`);
});
it("should allow `enqueue` before `read`", async () => {
const log = jest.fn();
const stream = new PushReadableStream(
async (controller) => {
await controller.enqueue(1);
},
{ highWaterMark: 0 },
log,
);
const reader = stream.getReader();
await expect(reader.read()).resolves.toEqual({
done: false,
value: 1,
});
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "waiting",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "pull",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "pull",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "complete",
"source": "producer",
},
],
]
`);
});
});
describe("non 0 high water mark", () => {
it("should allow `read` before `enqueue`", async () => {
const log = jest.fn();
let controller!: PushReadableStreamController<unknown>;
const stream = new PushReadableStream(
(controller_) => {
controller = controller_;
},
{ highWaterMark: 1 },
log,
);
const reader = stream.getReader();
const promise = reader.read();
await delay(1);
await controller.enqueue(1);
await expect(promise).resolves.toEqual({ done: false, value: 1 });
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "pull",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "pull",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "pull",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "pull",
"phase": "complete",
"source": "consumer",
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
]
`);
});
it("should allow `enqueue` before `read`", async () => {
const log = jest.fn();
const stream = new PushReadableStream(
async (controller) => {
await controller.enqueue(1);
},
{ highWaterMark: 1 },
log,
);
const reader = stream.getReader();
await expect(reader.read()).resolves.toEqual({
done: false,
value: 1,
});
expect(log.mock.calls).toMatchInlineSnapshot(`
[
[
{
"operation": "enqueue",
"phase": "start",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "enqueue",
"phase": "complete",
"source": "producer",
"value": 1,
},
],
[
{
"operation": "pull",
"phase": "start",
"source": "consumer",
},
],
[
{
"operation": "pull",
"phase": "complete",
"source": "consumer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "start",
"source": "producer",
},
],
[
{
"explicit": false,
"operation": "close",
"phase": "complete",
"source": "producer",
},
],
]
`);
});
});
describe("async `source`", () => {
it("resolved Promise should close the stream", async () => {
const stream = new PushReadableStream(async () => {});
const reader = stream.getReader();
await reader.closed;
});
it("reject Promise should error the stream", async () => {
const stream = new PushReadableStream(async () => {
await delay(1);
throw new Error("error");
});
const reader = stream.getReader();
await expect(reader.closed).rejects.toThrow("error");
});
});
describe(".close", () => {
it("should close the stream", async () => {
const stream = new PushReadableStream((controller) => {
controller.close();
});
const reader = stream.getReader();
await expect(reader.closed).resolves.toBeUndefined();
});
it("should work with async `source`", () => {
const stream = new PushReadableStream(async (controller) => {
await delay(1);
controller.close();
});
const reader = stream.getReader();
return expect(reader.closed).resolves.toBeUndefined();
});
});
});

View file

@ -17,9 +17,28 @@ export type PushReadableStreamSource<T> = (
controller: PushReadableStreamController<T>, controller: PushReadableStreamController<T>,
) => void | Promise<void>; ) => void | Promise<void>;
export class PushReadableStream<T> extends ReadableStream<T> { export type PushReadableLogger<T> = (
#zeroHighWaterMarkAllowEnqueue = false; event:
| {
source: "producer";
operation: "enqueue";
value: T;
phase: "start" | "waiting" | "ignored" | "complete";
}
| {
source: "producer";
operation: "close" | "error";
explicit: boolean;
phase: "start" | "ignored" | "complete";
}
| {
source: "consumer";
operation: "pull" | "cancel";
phase: "start" | "complete";
},
) => void;
export class PushReadableStream<T> extends ReadableStream<T> {
/** /**
* Create a new `PushReadableStream` from a source. * Create a new `PushReadableStream` from a source.
* *
@ -30,81 +49,256 @@ export class PushReadableStream<T> extends ReadableStream<T> {
constructor( constructor(
source: PushReadableStreamSource<T>, source: PushReadableStreamSource<T>,
strategy?: QueuingStrategy<T>, strategy?: QueuingStrategy<T>,
logger?: PushReadableLogger<T>,
) { ) {
let waterMarkLow: PromiseResolver<void> | undefined; let waterMarkLow: PromiseResolver<void> | undefined;
let zeroHighWaterMarkAllowEnqueue = false;
const abortController = new AbortController(); const abortController = new AbortController();
super( super(
{ {
start: async (controller) => { start: (controller) => {
await Promise.resolve();
const result = source({ const result = source({
abortSignal: abortController.signal, abortSignal: abortController.signal,
enqueue: async (chunk) => { enqueue: async (chunk) => {
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "start",
});
if (abortController.signal.aborted) { if (abortController.signal.aborted) {
// If the stream is already cancelled, // In original `ReadableStream`, calling `enqueue` or `close`
// throw immediately. // on an cancelled stream will throw an error,
throw abortController.signal.reason; //
// But in `PushReadableStream`, `enqueue` is an async function,
// the producer can't just check `abortSignal.aborted`
// before calling `enqueue`, as it might change when waiting
// for the backpressure to be resolved.
//
// So IMO it's better to handle this for the producer
// by simply ignoring the `enqueue` call.
//
// Note that we check `abortSignal.aborted` instead of `stopped`,
// as it's not allowed for the producer to call `enqueue` after
// they called `close` or `error`.
//
// Obviously, the producer should listen to the `abortSignal` and
// stop producing, but most pushing data sources can't be stopped.
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "ignored",
});
return;
} }
if (controller.desiredSize === null) { if (controller.desiredSize === null) {
// `desiredSize` being `null` means the stream is in error state, // `desiredSize` being `null` means the stream is in error state,
// `controller.enqueue` will throw an error for us. // `controller.enqueue` will throw an error for us.
controller.enqueue(chunk); controller.enqueue(chunk);
// istanbul ignore next
return; return;
} }
if (this.#zeroHighWaterMarkAllowEnqueue) { if (zeroHighWaterMarkAllowEnqueue) {
this.#zeroHighWaterMarkAllowEnqueue = false; // When `highWaterMark` is set to `0`,
// `controller.desiredSize` will always be `0`,
// even if the consumer has called `reader.read()`.
// (in this case, each `reader.read()`/`pull`
// should allow one `enqueue` of any size)
//
// If the consumer has already called `reader.read()`,
// before the producer tries to `enqueue`,
// `controller.desiredSize` is `0` and normal `waterMarkLow` signal
// will never trigger,
// (because `ReadableStream` prevents reentrance of `pull`)
// The stream will stuck.
//
// So we need a special signal for this case.
zeroHighWaterMarkAllowEnqueue = false;
controller.enqueue(chunk); controller.enqueue(chunk);
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "complete",
});
return; return;
} }
if (controller.desiredSize <= 0) { if (controller.desiredSize <= 0) {
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "waiting",
});
waterMarkLow = new PromiseResolver<void>(); waterMarkLow = new PromiseResolver<void>();
await waterMarkLow.promise; await waterMarkLow.promise;
// Recheck consumer cancellation after async operations.
if (abortController.signal.aborted) {
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "ignored",
});
return;
}
} }
// `controller.enqueue` will throw error for us
// if the stream is already errored.
controller.enqueue(chunk); controller.enqueue(chunk);
logger?.({
source: "producer",
operation: "enqueue",
value: chunk,
phase: "complete",
});
}, },
close() { close() {
logger?.({
source: "producer",
operation: "close",
explicit: true,
phase: "start",
});
// Since `enqueue` on an cancelled stream won't throw an error,
// so does `close`.
if (abortController.signal.aborted) {
logger?.({
source: "producer",
operation: "close",
explicit: true,
phase: "ignored",
});
return;
}
controller.close(); controller.close();
logger?.({
source: "producer",
operation: "close",
explicit: true,
phase: "complete",
});
}, },
error(e) { error(e) {
logger?.({
source: "producer",
operation: "error",
explicit: true,
phase: "start",
});
// Calling `error` on an already closed or errored stream is a no-op.
controller.error(e); controller.error(e);
logger?.({
source: "producer",
operation: "error",
explicit: true,
phase: "complete",
});
}, },
}); });
if (result && "then" in result) { if (result && "then" in result) {
// If `source` returns a `Promise`,
// close the stream when the `Promise` is resolved,
// and error the stream when the `Promise` is rejected.
// The producer can return a never-settling `Promise`
// to disable this behavior.
result.then( result.then(
() => { () => {
logger?.({
source: "producer",
operation: "close",
explicit: false,
phase: "start",
});
try { try {
controller.close(); controller.close();
} catch (e) {
// controller already closed logger?.({
source: "producer",
operation: "close",
explicit: false,
phase: "complete",
});
} catch {
logger?.({
source: "producer",
operation: "close",
explicit: false,
phase: "ignored",
});
// The stream is already closed by the producer,
// Or cancelled by the consumer.
} }
}, },
(e) => { (e) => {
logger?.({
source: "producer",
operation: "error",
explicit: false,
phase: "start",
});
controller.error(e); controller.error(e);
logger?.({
source: "producer",
operation: "error",
explicit: false,
phase: "complete",
});
}, },
); );
} }
}, },
pull: () => { pull: () => {
logger?.({
source: "consumer",
operation: "pull",
phase: "start",
});
if (waterMarkLow) { if (waterMarkLow) {
waterMarkLow.resolve(); waterMarkLow.resolve();
return; } else if (strategy?.highWaterMark === 0) {
} zeroHighWaterMarkAllowEnqueue = true;
if (strategy?.highWaterMark === 0) {
this.#zeroHighWaterMarkAllowEnqueue = true;
} }
logger?.({
source: "consumer",
operation: "pull",
phase: "complete",
});
}, },
cancel: (reason) => { cancel: (reason) => {
logger?.({
source: "consumer",
operation: "cancel",
phase: "start",
});
abortController.abort(reason); abortController.abort(reason);
waterMarkLow?.reject(reason); // Resolve it on cancellation. `pull` will check `abortSignal.aborted` again.
waterMarkLow?.resolve();
logger?.({
source: "consumer",
operation: "cancel",
phase: "complete",
});
}, },
}, },
strategy, strategy,