Skip to content

Commit e6d6bd0

Browse files
committed
fix: Keep a WebSocket subscription active when another subscription with the same parameters ends
1 parent d9b3539 commit e6d6bd0

10 files changed (+208, −8 lines)

src/adapters/action/dispatcher-ws.spec.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import { MastoUnexpectedError } from "../errors";
22
import { createLogger } from "../logger";
33
import { SerializerNativeImpl } from "../serializers";
4-
import { WebSocketConnectorImpl } from "../ws";
4+
import {
5+
WebSocketConnectorImpl,
6+
WebSocketSubscriptionCounterImpl,
7+
} from "../ws";
58
import { WebSocketActionDispatcher } from "./dispatcher-ws";
69

710
describe("DispatcherWs", () => {
@@ -10,6 +13,7 @@ describe("DispatcherWs", () => {
1013
new WebSocketConnectorImpl({
1114
constructorParameters: ["wss://example.com"],
1215
}),
16+
new WebSocketSubscriptionCounterImpl(),
1317
new SerializerNativeImpl(),
1418
createLogger("error"),
1519
);
@@ -31,6 +35,7 @@ describe("DispatcherWs", () => {
3135
});
3236
const dispatcher = new WebSocketActionDispatcher(
3337
connector,
38+
new WebSocketSubscriptionCounterImpl(),
3439
new SerializerNativeImpl(),
3540
createLogger("error"),
3641
);

src/adapters/action/dispatcher-ws.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
type Logger,
55
type Serializer,
66
type WebSocketConnector,
7+
type WebSocketSubscriptionCounter,
78
} from "../../interfaces";
89
import { MastoUnexpectedError } from "../errors";
910
import { WebSocketSubscription } from "../ws";
@@ -16,6 +17,7 @@ export class WebSocketActionDispatcher
1617
{
1718
constructor(
1819
private readonly connector: WebSocketConnector,
20+
private readonly counter: WebSocketSubscriptionCounter,
1921
private readonly serializer: Serializer,
2022
private readonly logger?: Logger,
2123
) {}
@@ -39,6 +41,7 @@ export class WebSocketActionDispatcher
3941

4042
return new WebSocketSubscription(
4143
this.connector,
44+
this.counter,
4245
this.serializer,
4346
stream,
4447
this.logger,

src/adapters/clients.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
import { HttpNativeImpl } from "./http";
1616
import { createLogger } from "./logger";
1717
import { SerializerNativeImpl } from "./serializers";
18-
import { WebSocketConnectorImpl } from "./ws";
18+
import { WebSocketConnectorImpl, WebSocketSubscriptionCounterImpl } from "./ws";
1919

2020
interface LogConfigProps {
2121
/**
@@ -87,8 +87,10 @@ export function createStreamingAPIClient(
8787
},
8888
logger,
8989
);
90+
const counter = new WebSocketSubscriptionCounterImpl();
9091
const actionDispatcher = new WebSocketActionDispatcher(
9192
connector,
93+
counter,
9294
serializer,
9395
logger,
9496
);

src/adapters/ws/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export * from "./web-socket-connector";
22
export * from "./web-socket-subscription";
3+
export * from "./web-socket-subscription-counter";
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { WebSocketSubscriptionCounterImpl } from "./web-socket-subscription-counter";
2+
3+
describe("WebSocketSubscriptionCounter", () => {
  // Every case works on its own brand-new counter so no state leaks between cases.
  const freshCounter = (): WebSocketSubscriptionCounterImpl =>
    new WebSocketSubscriptionCounterImpl();

  it("counts", () => {
    const subject = freshCounter();
    expect(subject.count("stream")).toBe(0);
  });

  it("increments", () => {
    const subject = freshCounter();

    subject.increment("stream");

    expect(subject.count("stream")).toBe(1);
  });

  it("decrements", () => {
    const subject = freshCounter();

    subject.increment("stream");
    expect(subject.count("stream")).toBe(1);

    subject.decrement("stream");
    expect(subject.count("stream")).toBe(0);
  });

  it("count differently for different params", () => {
    const subject = freshCounter();
    const tracked = { foo: "bar" };
    const untracked = { foo: "baz" };

    subject.increment("stream", tracked);

    expect(subject.count("stream", tracked)).toBe(1);
    expect(subject.count("stream", untracked)).toBe(0);
  });

  it("does not decrement non-existing stream", () => {
    const subject = freshCounter();
    const decrementUnknown = (): void => subject.decrement("stream");

    expect(decrementUnknown).toThrow();
  });
});
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { type WebSocketSubscriptionCounter } from "../../interfaces";
2+
3+
/**
 * Reference-counts live subscriptions per `(stream, params)` pair.
 *
 * Several subscriptions may share one WebSocket connection and subscribe to the
 * same stream with the same parameters. The caller uses this count to decide
 * whether releasing one subscription should actually send `unsubscribe` to the
 * server (only when the last one for that pair goes away).
 */
export class WebSocketSubscriptionCounterImpl
  implements WebSocketSubscriptionCounter
{
  private readonly counts = new Map<string, number>();

  /**
   * @param stream - Stream name, e.g. `"hashtag"`.
   * @param params - Optional subscription parameters, e.g. `{ tag: "test" }`.
   * @returns The number of live subscriptions for the pair, or 0 if none.
   */
  count(stream: string, params?: Record<string, unknown>): number {
    return this.counts.get(this.hash(stream, params)) ?? 0;
  }

  /**
   * Records one more live subscription for `stream` + `params`.
   */
  increment(stream: string, params?: Record<string, unknown>): void {
    const key = this.hash(stream, params);
    this.counts.set(key, (this.counts.get(key) ?? 0) + 1);
  }

  /**
   * Releases one live subscription for `stream` + `params`.
   *
   * @throws Error when no live subscription is counted for the pair. Letting the
   *   count go negative would make a later subscriber's release hit zero while
   *   another subscriber is still active, and unsubscribe it prematurely.
   */
  decrement(stream: string, params?: Record<string, unknown>): void {
    const key = this.hash(stream, params);
    const current = this.counts.get(key);

    if (current === undefined || current <= 0) {
      throw new Error("Cannot decrement non-existent count");
    }

    if (current === 1) {
      // Drop the entry so the map does not grow with every distinct
      // (stream, params) pair that was ever subscribed.
      this.counts.delete(key);
    } else {
      this.counts.set(key, current - 1);
    }
  }

  /**
   * Builds a stable map key for `stream` + `params`.
   *
   * `JSON.stringify` serializes properties in insertion order, so the top-level
   * param keys are sorted first; otherwise `{ a, b }` and `{ b, a }` would be
   * counted as different subscriptions.
   */
  private hash(stream: string, params?: Record<string, unknown>): string {
    const normalized =
      params === undefined
        ? undefined
        : Object.fromEntries(
            Object.entries(params).sort(([a], [b]) =>
              a < b ? -1 : a > b ? 1 : 0,
            ),
          );

    return JSON.stringify({ stream, params: normalized });
  }
}

src/adapters/ws/web-socket-subscription.spec.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { createLogger } from "../logger";
55
import { SerializerNativeImpl } from "../serializers";
66
import { WebSocketConnectorImpl } from "./web-socket-connector";
77
import { WebSocketSubscription } from "./web-socket-subscription";
8+
import { WebSocketSubscriptionCounterImpl } from "./web-socket-subscription-counter";
89

910
describe("WebSocketSubscription", () => {
1011
it("doesn't do anything if no connection was established", async () => {
@@ -15,6 +16,7 @@ describe("WebSocketSubscription", () => {
1516
{ constructorParameters: ["ws://localhost:0"] },
1617
logger,
1718
),
19+
new WebSocketSubscriptionCounterImpl(),
1820
new SerializerNativeImpl(),
1921
"public",
2022
logger,
@@ -47,6 +49,7 @@ describe("WebSocketSubscription", () => {
4749
);
4850
const subscription = new WebSocketSubscription(
4951
connection,
52+
new WebSocketSubscriptionCounterImpl(),
5053
new SerializerNativeImpl(),
5154
"public",
5255
logger,

src/adapters/ws/web-socket-subscription.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
type Logger,
55
type Serializer,
66
type WebSocketConnector,
7+
type WebSocketSubscriptionCounter,
78
} from "../../interfaces";
89
import { type mastodon } from "../../mastodon";
910
import { MastoUnexpectedError } from "../errors";
@@ -14,6 +15,7 @@ export class WebSocketSubscription implements mastodon.streaming.Subscription {
1415

1516
constructor(
1617
private readonly connector: WebSocketConnector,
18+
private readonly counter: WebSocketSubscriptionCounter,
1719
private readonly serializer: Serializer,
1820
private readonly stream: string,
1921
private readonly logger?: Logger,
@@ -34,6 +36,7 @@ export class WebSocketSubscription implements mastodon.streaming.Subscription {
3436

3537
this.logger?.log("debug", "↑ WEBSOCKET", data);
3638
this.connection.send(data);
39+
this.counter.increment(this.stream, this.params);
3740

3841
const messages = toAsyncIterable(this.connection);
3942

@@ -55,13 +58,19 @@ export class WebSocketSubscription implements mastodon.streaming.Subscription {
5558
return;
5659
}
5760

58-
const data = this.serializer.serialize("json", {
59-
type: "unsubscribe",
60-
stream: this.stream,
61-
...this.params,
62-
});
61+
this.counter.decrement(this.stream, this.params);
6362

64-
this.connection.send(data);
63+
if (this.counter.count(this.stream, this.params) <= 0) {
64+
const data = this.serializer.serialize("json", {
65+
type: "unsubscribe",
66+
stream: this.stream,
67+
...this.params,
68+
});
69+
70+
this.connection.send(data);
71+
}
72+
73+
this.connection = undefined;
6574
}
6675

6776
[Symbol.asyncIterator](): AsyncIterableIterator<mastodon.streaming.Event> {

src/interfaces/ws.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ export interface WebSocketConnector {
55
close(): void;
66
canAcquire(): boolean;
77
}
8+
9+
/**
 * Counts live WebSocket subscriptions, keyed by the stream name together with
 * its (optional) subscription parameters.
 */
export interface WebSocketSubscriptionCounter {
  /** Returns how many subscriptions are currently counted for `stream` + `params`. */
  count(stream: string, params?: Record<string, unknown>): number;

  /** Records one more subscription for `stream` + `params`. */
  increment(stream: string, params?: Record<string, unknown>): void;

  /** Releases one previously recorded subscription for `stream` + `params`. */
  decrement(stream: string, params?: Record<string, unknown>): void;
}

tests/streaming/connections.spec.ts

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import assert from "node:assert";
2+
3+
it("maintains connections for the event even if other handlers closed it", async () => {
  await using alice = await sessions.acquire({ waitForWs: true });

  using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
  using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });

  const promise1 = subscription1.values().take(1).toArray();
  const promise2 = subscription2.values().take(2).toArray();

  // Track every status created so the `finally` removes all of them, even if a
  // later step (second create, waiting for events) throws.
  const createdIds: string[] = [];

  try {
    // Dispatch event for subscription1 to establish connection
    const status1 = await alice.rest.v1.statuses.create({
      status: "#test",
      visibility: "public",
    });
    createdIds.push(status1.id);
    await promise1;
    subscription1.unsubscribe();

    // subscription1 is now closed, so status2 will only be dispatched to subscription2
    const status2 = await alice.rest.v1.statuses.create({
      status: "#test",
      visibility: "public",
    });
    createdIds.push(status2.id);

    const [e1, e2] = await promise2;
    assert(e1.event === "update");
    expect(e1.payload.id).toBe(status1.id);

    assert(e2.event === "update");
    expect(e2.payload.id).toBe(status2.id);
  } finally {
    for (const id of createdIds) {
      await alice.rest.v1.statuses.$select(id).remove();
    }
  }
});
38+
39+
it("maintains connections for the event if unsubscribe called twice", async () => {
  await using alice = await sessions.acquire({ waitForWs: true });

  using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
  using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });

  const promise1 = subscription1.values().take(1).toArray();
  const promise2 = subscription2.values().take(2).toArray();

  // Track every status created so the `finally` removes all of them, even if a
  // later step (second create, waiting for events) throws.
  const createdIds: string[] = [];

  try {
    const status1 = await alice.rest.v1.statuses.create({
      status: "#test",
      visibility: "public",
    });
    createdIds.push(status1.id);
    await promise1;

    // Repeated unsubscribes of the same subscription must release it only once,
    // so subscription2 keeps receiving events.
    subscription1.unsubscribe();
    subscription1.unsubscribe();
    subscription1.unsubscribe();
    subscription1.unsubscribe();

    const status2 = await alice.rest.v1.statuses.create({
      status: "#test",
      visibility: "public",
    });
    createdIds.push(status2.id);

    const [e1, e2] = await promise2;
    assert(e1.event === "update");
    expect(e1.payload.id).toBe(status1.id);

    assert(e2.event === "update");
    expect(e2.payload.id).toBe(status2.id);
  } finally {
    for (const id of createdIds) {
      await alice.rest.v1.statuses.$select(id).remove();
    }
  }
});
75+
76+
it("maintains connections for the event if another handler called unsubscribe before connection established", async () => {
  await using alice = await sessions.acquire({ waitForWs: true });

  using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
  using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });

  // Close subscription1 before any connection has been established;
  // subscription2 must still receive the event.
  subscription1.unsubscribe();

  const received = subscription2.values().take(1).toArray();

  const status = await alice.rest.v1.statuses.create({
    status: "#test",
    visibility: "public",
  });

  try {
    const [event] = await received;
    assert(event.event === "update");
    expect(event.payload.id).toBe(status.id);
  } finally {
    await alice.rest.v1.statuses.$select(status.id).remove();
  }
});

0 commit comments

Comments
 (0)