feat(adb): don't kill subprocess on stream close

This commit is contained in:
Simon Chan 2023-10-17 21:13:36 +08:00
parent e45fb2ed55
commit fe5cb5a176
No known key found for this signature in database
GPG key ID: A8B69F750B9BCEDD
6 changed files with 102 additions and 77 deletions

View file

@ -5,8 +5,8 @@ import type {
AdbIncomingSocketHandler, AdbIncomingSocketHandler,
AdbServerConnection, AdbServerConnection,
AdbServerConnectionOptions, AdbServerConnectionOptions,
AdbServerConnector,
} from "@yume-chan/adb"; } from "@yume-chan/adb";
import type { ReadableWritablePair } from "@yume-chan/stream-extra";
import { import {
PushReadableStream, PushReadableStream,
UnwrapConsumableStream, UnwrapConsumableStream,
@ -15,12 +15,21 @@ import {
} from "@yume-chan/stream-extra"; } from "@yume-chan/stream-extra";
import type { ValueOrPromise } from "@yume-chan/struct"; import type { ValueOrPromise } from "@yume-chan/struct";
function nodeSocketToStreamPair(socket: Socket) { function nodeSocketToConnection(socket: Socket): AdbServerConnection {
socket.setNoDelay(true); socket.setNoDelay(true);
const closed = new Promise<void>((resolve) => {
socket.on("close", resolve);
});
return { return {
readable: new PushReadableStream<Uint8Array>((controller) => { readable: new PushReadableStream<Uint8Array>((controller) => {
// eslint-disable-next-line @typescript-eslint/no-misused-promises // eslint-disable-next-line @typescript-eslint/no-misused-promises
socket.on("data", async (data) => { socket.on("data", async (data) => {
if (controller.abortSignal.aborted) {
return;
}
socket.pause(); socket.pause();
await controller.enqueue(data); await controller.enqueue(data);
socket.resume(); socket.resume();
@ -32,9 +41,6 @@ function nodeSocketToStreamPair(socket: Socket) {
// controller already closed // controller already closed
} }
}); });
controller.abortSignal.addEventListener("abort", () => {
socket.end();
});
}), }),
writable: new WritableStream<Uint8Array>({ writable: new WritableStream<Uint8Array>({
write: async (chunk) => { write: async (chunk) => {
@ -48,16 +54,17 @@ function nodeSocketToStreamPair(socket: Socket) {
}); });
}); });
}, },
close() {
return new Promise<void>((resolve) => {
socket.end(resolve);
});
},
}), }),
get closed() {
return closed;
},
close() {
socket.end();
},
}; };
} }
export class AdbServerNodeTcpConnection implements AdbServerConnection { export class AdbServerNodeTcpConnector implements AdbServerConnector {
readonly spec: SocketConnectOpts; readonly spec: SocketConnectOpts;
readonly #listeners = new Map<string, Server>(); readonly #listeners = new Map<string, Server>();
@ -68,7 +75,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
async connect( async connect(
{ unref }: AdbServerConnectionOptions = { unref: false }, { unref }: AdbServerConnectionOptions = { unref: false },
): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> { ): Promise<AdbServerConnection> {
const socket = new Socket(); const socket = new Socket();
if (unref) { if (unref) {
socket.unref(); socket.unref();
@ -78,7 +85,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
socket.once("connect", resolve); socket.once("connect", resolve);
socket.once("error", reject); socket.once("error", reject);
}); });
return nodeSocketToStreamPair(socket); return nodeSocketToConnection(socket);
} }
async addReverseTunnel( async addReverseTunnel(
@ -87,16 +94,19 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
): Promise<string> { ): Promise<string> {
// eslint-disable-next-line @typescript-eslint/no-misused-promises // eslint-disable-next-line @typescript-eslint/no-misused-promises
const server = new Server(async (socket) => { const server = new Server(async (socket) => {
const stream = nodeSocketToStreamPair(socket); const connection = nodeSocketToConnection(socket);
try { try {
await handler({ await handler({
service: address!, service: address!,
readable: stream.readable, readable: connection.readable,
writable: new WrapWritableStream( writable: new WrapWritableStream(
stream.writable, connection.writable,
).bePipedThroughFrom(new UnwrapConsumableStream()), ).bePipedThroughFrom(new UnwrapConsumableStream()),
close() { get closed() {
socket.end(); return connection.closed;
},
async close() {
await connection.close();
}, },
}); });
} catch { } catch {

View file

@ -22,7 +22,9 @@ export interface Closeable {
export interface AdbSocket export interface AdbSocket
extends ReadableWritablePair<Uint8Array, Consumable<Uint8Array>>, extends ReadableWritablePair<Uint8Array, Consumable<Uint8Array>>,
Closeable { Closeable {
readonly service: string; get service(): string;
get closed(): Promise<void>;
} }
export type AdbIncomingSocketHandler = ( export type AdbIncomingSocketHandler = (

View file

@ -1,7 +1,8 @@
import type { Consumable, WritableStream } from "@yume-chan/stream-extra"; import type { Consumable, WritableStream } from "@yume-chan/stream-extra";
import { DuplexStreamFactory, ReadableStream } from "@yume-chan/stream-extra"; import { ReadableStream } from "@yume-chan/stream-extra";
import type { Adb, AdbSocket } from "../../../adb.js"; import type { Adb, AdbSocket } from "../../../adb.js";
import { unreachable } from "../../../utils/index.js";
import type { AdbSubprocessProtocol } from "./types.js"; import type { AdbSubprocessProtocol } from "./types.js";
@ -34,19 +35,16 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {
readonly #socket: AdbSocket; readonly #socket: AdbSocket;
readonly #duplex: DuplexStreamFactory<Uint8Array, Uint8Array>;
// Legacy shell forwards all data to stdin. // Legacy shell forwards all data to stdin.
get stdin(): WritableStream<Consumable<Uint8Array>> { get stdin(): WritableStream<Consumable<Uint8Array>> {
return this.#socket.writable; return this.#socket.writable;
} }
#stdout: ReadableStream<Uint8Array>;
/** /**
* Legacy shell mixes stdout and stderr. * Legacy shell mixes stdout and stderr.
*/ */
get stdout(): ReadableStream<Uint8Array> { get stdout(): ReadableStream<Uint8Array> {
return this.#stdout; return this.#socket.readable;
} }
#stderr: ReadableStream<Uint8Array>; #stderr: ReadableStream<Uint8Array>;
@ -65,24 +63,21 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {
constructor(socket: AdbSocket) { constructor(socket: AdbSocket) {
this.#socket = socket; this.#socket = socket;
// Link `stdout`, `stderr` and `stdin` together, this.#stderr = new ReadableStream({
// so closing any of them will close the others. start: (controller) => {
this.#duplex = new DuplexStreamFactory<Uint8Array, Uint8Array>({ this.#socket.closed
close: async () => { .then(() => controller.close())
await this.#socket.close(); .catch(unreachable);
}, },
}); });
this.#exit = socket.closed.then(() => 0);
this.#stdout = this.#duplex.wrapReadable(this.#socket.readable);
this.#stderr = this.#duplex.wrapReadable(new ReadableStream());
this.#exit = this.#duplex.closed.then(() => 0);
} }
resize() { resize() {
// Not supported, but don't throw. // Not supported, but don't throw.
} }
kill() { async kill() {
return this.#duplex.close(); await this.#socket.close();
} }
} }

View file

@ -159,6 +159,7 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
let stdoutController!: PushReadableStreamController<Uint8Array>; let stdoutController!: PushReadableStreamController<Uint8Array>;
let stderrController!: PushReadableStreamController<Uint8Array>; let stderrController!: PushReadableStreamController<Uint8Array>;
this.#stdout = new PushReadableStream<Uint8Array>((controller) => { this.#stdout = new PushReadableStream<Uint8Array>((controller) => {
stdoutController = controller; stdoutController = controller;
}); });
@ -176,10 +177,14 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
this.#exit.resolve(chunk.data[0]!); this.#exit.resolve(chunk.data[0]!);
break; break;
case AdbShellProtocolId.Stdout: case AdbShellProtocolId.Stdout:
await stdoutController.enqueue(chunk.data); if (!stdoutController.abortSignal.aborted) {
await stdoutController.enqueue(chunk.data);
}
break; break;
case AdbShellProtocolId.Stderr: case AdbShellProtocolId.Stderr:
await stderrController.enqueue(chunk.data); if (!stderrController.abortSignal.aborted) {
await stderrController.enqueue(chunk.data);
}
break; break;
} }
}, },

View file

@ -3,13 +3,11 @@
import { PromiseResolver } from "@yume-chan/async"; import { PromiseResolver } from "@yume-chan/async";
import type { import type {
AbortSignal, AbortSignal,
Consumable,
ReadableWritablePair, ReadableWritablePair,
WritableStreamDefaultWriter, WritableStreamDefaultWriter,
} from "@yume-chan/stream-extra"; } from "@yume-chan/stream-extra";
import { import {
BufferedReadableStream, BufferedReadableStream,
DuplexStreamFactory,
UnwrapConsumableStream, UnwrapConsumableStream,
WrapWritableStream, WrapWritableStream,
} from "@yume-chan/stream-extra"; } from "@yume-chan/stream-extra";
@ -25,7 +23,7 @@ import {
encodeUtf8, encodeUtf8,
} from "@yume-chan/struct"; } from "@yume-chan/struct";
import type { AdbIncomingSocketHandler, AdbSocket } from "../adb.js"; import type { AdbIncomingSocketHandler, AdbSocket, Closeable } from "../adb.js";
import { AdbBanner } from "../banner.js"; import { AdbBanner } from "../banner.js";
import type { AdbFeature } from "../features.js"; import type { AdbFeature } from "../features.js";
import { NOOP, hexToNumber, numberToHex } from "../utils/index.js"; import { NOOP, hexToNumber, numberToHex } from "../utils/index.js";
@ -37,10 +35,16 @@ export interface AdbServerConnectionOptions {
signal?: AbortSignal | undefined; signal?: AbortSignal | undefined;
} }
export interface AdbServerConnection { export interface AdbServerConnection
extends ReadableWritablePair<Uint8Array, Uint8Array>,
Closeable {
get closed(): Promise<void>;
}
export interface AdbServerConnector {
connect( connect(
options?: AdbServerConnectionOptions, options?: AdbServerConnectionOptions,
): ValueOrPromise<ReadableWritablePair<Uint8Array, Uint8Array>>; ): ValueOrPromise<AdbServerConnection>;
addReverseTunnel( addReverseTunnel(
handler: AdbIncomingSocketHandler, handler: AdbIncomingSocketHandler,
@ -74,9 +78,9 @@ export interface AdbServerDevice {
export class AdbServerClient { export class AdbServerClient {
static readonly VERSION = 41; static readonly VERSION = 41;
readonly connection: AdbServerConnection; readonly connection: AdbServerConnector;
constructor(connection: AdbServerConnection) { constructor(connection: AdbServerConnector) {
this.connection = connection; this.connection = connection;
} }
@ -126,30 +130,41 @@ export class AdbServerClient {
async connect( async connect(
request: string, request: string,
options?: AdbServerConnectionOptions, options?: AdbServerConnectionOptions,
): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> { ): Promise<AdbServerConnection> {
const connection = await this.connection.connect(options); const connection = await this.connection.connect(options);
const writer = connection.writable.getWriter(); try {
await AdbServerClient.writeString(writer, request); const writer = connection.writable.getWriter();
await AdbServerClient.writeString(writer, request);
writer.releaseLock();
} catch (e) {
await connection.readable.cancel();
await connection.close();
throw e;
}
const readable = new BufferedReadableStream(connection.readable); const readable = new BufferedReadableStream(connection.readable);
try { try {
// `raceSignal` throws if the signal is aborted, // `raceSignal` throws when the signal is aborted,
// so the `catch` block can close the connection. // so the `catch` block can close the connection.
await raceSignal( await raceSignal(
() => AdbServerClient.readOkay(readable), () => AdbServerClient.readOkay(readable),
options?.signal, options?.signal,
); );
writer.releaseLock();
return { return {
readable: readable.release(), readable: readable.release(),
writable: connection.writable, writable: connection.writable,
get closed() {
return connection.closed;
},
async close() {
await connection.close();
},
}; };
} catch (e) { } catch (e) {
writer.close().catch(NOOP); await readable.cancel().catch(NOOP);
readable.cancel().catch(NOOP); await connection.close();
throw e; throw e;
} }
} }
@ -328,8 +343,18 @@ export class AdbServerClient {
} }
const connection = await this.connect(switchService); const connection = await this.connect(switchService);
try {
const writer = connection.writable.getWriter();
await AdbServerClient.writeString(writer, service);
writer.releaseLock();
} catch (e) {
await connection.readable.cancel();
await connection.close();
throw e;
}
const readable = new BufferedReadableStream(connection.readable); const readable = new BufferedReadableStream(connection.readable);
const writer = connection.writable.getWriter();
try { try {
if (transportId === undefined) { if (transportId === undefined) {
const array = await readable.readExactly(8); const array = await readable.readExactly(8);
@ -342,34 +367,25 @@ export class AdbServerClient {
transportId = BigIntFieldType.Uint64.getter(dataView, 0, true); transportId = BigIntFieldType.Uint64.getter(dataView, 0, true);
} }
await AdbServerClient.writeString(writer, service);
await AdbServerClient.readOkay(readable); await AdbServerClient.readOkay(readable);
writer.releaseLock();
const duplex = new DuplexStreamFactory<
Uint8Array,
Consumable<Uint8Array>
>();
const wrapReadable = duplex.wrapReadable(readable.release());
const wrapWritable = duplex.createWritable(
new WrapWritableStream(connection.writable).bePipedThroughFrom(
new UnwrapConsumableStream(),
),
);
return { return {
transportId, transportId,
service, service,
readable: wrapReadable, readable: readable.release(),
writable: wrapWritable, writable: new WrapWritableStream(
close() { connection.writable,
return duplex.close(); ).bePipedThroughFrom(new UnwrapConsumableStream()),
get closed() {
return connection.closed;
},
async close() {
await connection.close();
}, },
}; };
} catch (e) { } catch (e) {
writer.close().catch(NOOP); await readable.cancel().catch(NOOP);
readable.cancel().catch(NOOP); await connection.close();
throw e; throw e;
} }
} }

View file

@ -45,10 +45,7 @@ export class PushReadableStream<T> extends ReadableStream<T> {
if (abortController.signal.aborted) { if (abortController.signal.aborted) {
// If the stream is already cancelled, // If the stream is already cancelled,
// throw immediately. // throw immediately.
throw ( throw abortController.signal.reason;
abortController.signal.reason ??
new Error("Aborted")
);
} }
if (controller.desiredSize === null) { if (controller.desiredSize === null) {