feat(adb): don't kill subprocess on stream close

This commit is contained in:
Simon Chan 2023-10-17 21:13:36 +08:00
parent e45fb2ed55
commit fe5cb5a176
No known key found for this signature in database
GPG key ID: A8B69F750B9BCEDD
6 changed files with 102 additions and 77 deletions

View file

@ -5,8 +5,8 @@ import type {
AdbIncomingSocketHandler,
AdbServerConnection,
AdbServerConnectionOptions,
AdbServerConnector,
} from "@yume-chan/adb";
import type { ReadableWritablePair } from "@yume-chan/stream-extra";
import {
PushReadableStream,
UnwrapConsumableStream,
@ -15,12 +15,21 @@ import {
} from "@yume-chan/stream-extra";
import type { ValueOrPromise } from "@yume-chan/struct";
function nodeSocketToStreamPair(socket: Socket) {
function nodeSocketToConnection(socket: Socket): AdbServerConnection {
socket.setNoDelay(true);
const closed = new Promise<void>((resolve) => {
socket.on("close", resolve);
});
return {
readable: new PushReadableStream<Uint8Array>((controller) => {
// eslint-disable-next-line @typescript-eslint/no-misused-promises
socket.on("data", async (data) => {
if (controller.abortSignal.aborted) {
return;
}
socket.pause();
await controller.enqueue(data);
socket.resume();
@ -32,9 +41,6 @@ function nodeSocketToStreamPair(socket: Socket) {
// controller already closed
}
});
controller.abortSignal.addEventListener("abort", () => {
socket.end();
});
}),
writable: new WritableStream<Uint8Array>({
write: async (chunk) => {
@ -48,16 +54,17 @@ function nodeSocketToStreamPair(socket: Socket) {
});
});
},
close() {
return new Promise<void>((resolve) => {
socket.end(resolve);
});
},
}),
get closed() {
return closed;
},
close() {
socket.end();
},
};
}
export class AdbServerNodeTcpConnection implements AdbServerConnection {
export class AdbServerNodeTcpConnector implements AdbServerConnector {
readonly spec: SocketConnectOpts;
readonly #listeners = new Map<string, Server>();
@ -68,7 +75,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
async connect(
{ unref }: AdbServerConnectionOptions = { unref: false },
): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> {
): Promise<AdbServerConnection> {
const socket = new Socket();
if (unref) {
socket.unref();
@ -78,7 +85,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
socket.once("connect", resolve);
socket.once("error", reject);
});
return nodeSocketToStreamPair(socket);
return nodeSocketToConnection(socket);
}
async addReverseTunnel(
@ -87,16 +94,19 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
): Promise<string> {
// eslint-disable-next-line @typescript-eslint/no-misused-promises
const server = new Server(async (socket) => {
const stream = nodeSocketToStreamPair(socket);
const connection = nodeSocketToConnection(socket);
try {
await handler({
service: address!,
readable: stream.readable,
readable: connection.readable,
writable: new WrapWritableStream(
stream.writable,
connection.writable,
).bePipedThroughFrom(new UnwrapConsumableStream()),
close() {
socket.end();
get closed() {
return connection.closed;
},
async close() {
await connection.close();
},
});
} catch {

View file

@ -22,7 +22,9 @@ export interface Closeable {
export interface AdbSocket
extends ReadableWritablePair<Uint8Array, Consumable<Uint8Array>>,
Closeable {
readonly service: string;
get service(): string;
get closed(): Promise<void>;
}
export type AdbIncomingSocketHandler = (

View file

@ -1,7 +1,8 @@
import type { Consumable, WritableStream } from "@yume-chan/stream-extra";
import { DuplexStreamFactory, ReadableStream } from "@yume-chan/stream-extra";
import { ReadableStream } from "@yume-chan/stream-extra";
import type { Adb, AdbSocket } from "../../../adb.js";
import { unreachable } from "../../../utils/index.js";
import type { AdbSubprocessProtocol } from "./types.js";
@ -34,19 +35,16 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {
readonly #socket: AdbSocket;
readonly #duplex: DuplexStreamFactory<Uint8Array, Uint8Array>;
// Legacy shell forwards all data to stdin.
get stdin(): WritableStream<Consumable<Uint8Array>> {
return this.#socket.writable;
}
#stdout: ReadableStream<Uint8Array>;
/**
* Legacy shell mixes stdout and stderr.
*/
get stdout(): ReadableStream<Uint8Array> {
return this.#stdout;
return this.#socket.readable;
}
#stderr: ReadableStream<Uint8Array>;
@ -65,24 +63,21 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {
constructor(socket: AdbSocket) {
this.#socket = socket;
// Link `stdout`, `stderr` and `stdin` together,
// so closing any of them will close the others.
this.#duplex = new DuplexStreamFactory<Uint8Array, Uint8Array>({
close: async () => {
await this.#socket.close();
this.#stderr = new ReadableStream({
start: (controller) => {
this.#socket.closed
.then(() => controller.close())
.catch(unreachable);
},
});
this.#stdout = this.#duplex.wrapReadable(this.#socket.readable);
this.#stderr = this.#duplex.wrapReadable(new ReadableStream());
this.#exit = this.#duplex.closed.then(() => 0);
this.#exit = socket.closed.then(() => 0);
}
resize() {
// Not supported, but don't throw.
}
kill() {
return this.#duplex.close();
async kill() {
await this.#socket.close();
}
}

View file

@ -159,6 +159,7 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
let stdoutController!: PushReadableStreamController<Uint8Array>;
let stderrController!: PushReadableStreamController<Uint8Array>;
this.#stdout = new PushReadableStream<Uint8Array>((controller) => {
stdoutController = controller;
});
@ -176,10 +177,14 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
this.#exit.resolve(chunk.data[0]!);
break;
case AdbShellProtocolId.Stdout:
await stdoutController.enqueue(chunk.data);
if (!stdoutController.abortSignal.aborted) {
await stdoutController.enqueue(chunk.data);
}
break;
case AdbShellProtocolId.Stderr:
await stderrController.enqueue(chunk.data);
if (!stderrController.abortSignal.aborted) {
await stderrController.enqueue(chunk.data);
}
break;
}
},

View file

@ -3,13 +3,11 @@
import { PromiseResolver } from "@yume-chan/async";
import type {
AbortSignal,
Consumable,
ReadableWritablePair,
WritableStreamDefaultWriter,
} from "@yume-chan/stream-extra";
import {
BufferedReadableStream,
DuplexStreamFactory,
UnwrapConsumableStream,
WrapWritableStream,
} from "@yume-chan/stream-extra";
@ -25,7 +23,7 @@ import {
encodeUtf8,
} from "@yume-chan/struct";
import type { AdbIncomingSocketHandler, AdbSocket } from "../adb.js";
import type { AdbIncomingSocketHandler, AdbSocket, Closeable } from "../adb.js";
import { AdbBanner } from "../banner.js";
import type { AdbFeature } from "../features.js";
import { NOOP, hexToNumber, numberToHex } from "../utils/index.js";
@ -37,10 +35,16 @@ export interface AdbServerConnectionOptions {
signal?: AbortSignal | undefined;
}
export interface AdbServerConnection {
export interface AdbServerConnection
extends ReadableWritablePair<Uint8Array, Uint8Array>,
Closeable {
get closed(): Promise<void>;
}
export interface AdbServerConnector {
connect(
options?: AdbServerConnectionOptions,
): ValueOrPromise<ReadableWritablePair<Uint8Array, Uint8Array>>;
): ValueOrPromise<AdbServerConnection>;
addReverseTunnel(
handler: AdbIncomingSocketHandler,
@ -74,9 +78,9 @@ export interface AdbServerDevice {
export class AdbServerClient {
static readonly VERSION = 41;
readonly connection: AdbServerConnection;
readonly connection: AdbServerConnector;
constructor(connection: AdbServerConnection) {
constructor(connection: AdbServerConnector) {
this.connection = connection;
}
@ -126,30 +130,41 @@ export class AdbServerClient {
async connect(
request: string,
options?: AdbServerConnectionOptions,
): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> {
): Promise<AdbServerConnection> {
const connection = await this.connection.connect(options);
const writer = connection.writable.getWriter();
await AdbServerClient.writeString(writer, request);
try {
const writer = connection.writable.getWriter();
await AdbServerClient.writeString(writer, request);
writer.releaseLock();
} catch (e) {
await connection.readable.cancel();
await connection.close();
throw e;
}
const readable = new BufferedReadableStream(connection.readable);
try {
// `raceSignal` throws if the signal is aborted,
// `raceSignal` throws when the signal is aborted,
// so the `catch` block can close the connection.
await raceSignal(
() => AdbServerClient.readOkay(readable),
options?.signal,
);
writer.releaseLock();
return {
readable: readable.release(),
writable: connection.writable,
get closed() {
return connection.closed;
},
async close() {
await connection.close();
},
};
} catch (e) {
writer.close().catch(NOOP);
readable.cancel().catch(NOOP);
await readable.cancel().catch(NOOP);
await connection.close();
throw e;
}
}
@ -328,8 +343,18 @@ export class AdbServerClient {
}
const connection = await this.connect(switchService);
try {
const writer = connection.writable.getWriter();
await AdbServerClient.writeString(writer, service);
writer.releaseLock();
} catch (e) {
await connection.readable.cancel();
await connection.close();
throw e;
}
const readable = new BufferedReadableStream(connection.readable);
const writer = connection.writable.getWriter();
try {
if (transportId === undefined) {
const array = await readable.readExactly(8);
@ -342,34 +367,25 @@ export class AdbServerClient {
transportId = BigIntFieldType.Uint64.getter(dataView, 0, true);
}
await AdbServerClient.writeString(writer, service);
await AdbServerClient.readOkay(readable);
writer.releaseLock();
const duplex = new DuplexStreamFactory<
Uint8Array,
Consumable<Uint8Array>
>();
const wrapReadable = duplex.wrapReadable(readable.release());
const wrapWritable = duplex.createWritable(
new WrapWritableStream(connection.writable).bePipedThroughFrom(
new UnwrapConsumableStream(),
),
);
return {
transportId,
service,
readable: wrapReadable,
writable: wrapWritable,
close() {
return duplex.close();
readable: readable.release(),
writable: new WrapWritableStream(
connection.writable,
).bePipedThroughFrom(new UnwrapConsumableStream()),
get closed() {
return connection.closed;
},
async close() {
await connection.close();
},
};
} catch (e) {
writer.close().catch(NOOP);
readable.cancel().catch(NOOP);
await readable.cancel().catch(NOOP);
await connection.close();
throw e;
}
}

View file

@ -45,10 +45,7 @@ export class PushReadableStream<T> extends ReadableStream<T> {
if (abortController.signal.aborted) {
// If the stream is already cancelled,
// throw immediately.
throw (
abortController.signal.reason ??
new Error("Aborted")
);
throw abortController.signal.reason;
}
if (controller.desiredSize === null) {